#!/usr/bin/env python3

# Copyright 2025 u-blox AG
# SPDX-License-Identifier: Apache-2.0

"""
This example illustrates the use of AssistNow predictive/live orbits
using Zero Touch Provisioning (ZTP) for authentication.

The goal of this example is to illustrate how to obtain AssistNow data
with a short self-contained script. It is not intended to be a complete
implementation.

Requirements:
  pip install pyserial requests
"""

__version__ = "v1.1.0"  # Reported by the -v/--version command-line option

import argparse
import json
import queue
import re
import threading
import time
import uuid

# pip install pyserial requests
import requests
import serial

# UBX frames below are laid out as: preamble (b5 62), class, ID,
# little-endian payload length, payload, 2-byte checksum.

# Poll UBX-SEC-UNIQID from the receiver (class 0x27, ID 0x03, empty payload)
CMD_POLL_UNIQID = bytes.fromhex("b562 2703 0000 2aa5")
# Poll UBX-MON-VER from the receiver (class 0x0a, ID 0x04, empty payload)
CMD_POLL_MONVER = bytes.fromhex("b562 0a04 0000 0e34")
# Configure CFG-NAVSPG-ACKAIDING in RAM to enable MGA-ACK messages
# (UBX-CFG-VALSET: version 0, layer 0x01, reserved, key 0x10110025, value 1)
CFG_RAM_ACKAIDING = bytes.fromhex("b562 068a 0900 00 01 0000 25001110 01 e13e")
# Poll UBX-NAV-STATUS from the receiver (class 0x01, ID 0x03, empty payload)
CMD_POLL_NAV_STATUS = bytes.fromhex("b562 0103 0000 040d")
# Cold boot the receiver (UBX-CFG-RST: navBbrMask 0xffff, resetMode 0x02)
CMD_RESET = bytes.fromhex("b562 0604 0400 ffff 02 00 0e61")

ZTP_ENDPOINT = "https://api.thingstream.io/ztp/assistnow/credentials"
SERIAL_TIMEOUT = 5.0  # seconds to wait for a UBX response
HTTP_TIMEOUT = 5.0  # seconds to wait for an HTTP response


def validate_ubx_message(msg):
    """
    Validate a UBX message at the start of msg.

    A UBX frame is: 2-byte preamble (0xB5 0x62), class, ID, 2-byte
    little-endian payload length, payload, and a 2-byte 8-bit Fletcher
    checksum computed over class..payload. Bytes after the first frame are
    ignored.

    Returns:
        * The length of the message if the message is valid
        * 0 if there is no valid UBX message at the start of msg
        * -1 if there is not enough data to validate, i.e. msg holds only
          the beginning of a possible UBX message (this includes inputs
          shorter than the 8-byte minimum frame that start with the preamble
          or a prefix of it)
    """
    try:
        if msg is None:
            raise RuntimeError("Message is None")
        # Compare only as many preamble bytes as are available, so that a
        # lone trailing 0xB5 counts as a possible start rather than garbage.
        if not b'\xB5\x62'.startswith(bytes(msg[:2])):
            raise RuntimeError("Message does not start with UBX header")
        if len(msg) < 8:
            print("Warning: Message is truncated, waiting for more data")
            return -1  # Header/checksum not complete yet
        payload_len = int.from_bytes(msg[4:6], "little")
        if len(msg) < payload_len + 8:
            print("Warning: Message is truncated, waiting for more data")
            return -1  # Payload not complete yet
        msg = msg[:payload_len + 8]  # Trim to the expected length
        chk = [0, 0]
        for byte in msg[2:-2]:
            chk[0] = (chk[0] + byte) & 0xFF
            chk[1] = (chk[1] + chk[0]) & 0xFF
        if bytes(chk) != msg[-2:]:
            raise RuntimeError("Checksum mismatch")
        return payload_len + 8
    except RuntimeError as e:
        print(f"Error: Failed to validate UBX message: {e}")
        return 0

def split_ubx_messages(msgs):
    """
    Yield the UBX messages contained in msgs one at a time, in order.

    msgs is expected to be a bytes or bytearray object consisting entirely of
    complete, valid UBX messages (for example an AssistNow response). Each
    yielded item is a slice of msgs holding exactly one message. An empty or
    None input yields nothing.

    Raises:
        RuntimeError: as soon as the remaining data does not start with a
            complete, valid UBX message.
    """
    if not msgs:
        return
    pos, total = 0, len(msgs)
    while pos < total:
        rest = msgs[pos:]
        size = validate_ubx_message(rest)
        if size <= 0:
            raise RuntimeError("Failed to split UBX messages")
        yield rest[:size]
        pos += size

class UBXReceiver(threading.Thread):
    """
    Handles communication with a u-blox GNSS receiver over a serial port.
    Supports the UBX protocol only. NMEA messages are discarded.

    The thread started by the constructor continuously reads the port, frames
    UBX messages and puts every valid message that starts with
    ``expected_prefix`` into ``queue``. The request methods set
    ``expected_prefix``, discard stale queued messages, send a request and
    then wait on ``queue`` for the matching response.

    Use it as a context manager (or call stop()) so that the thread is joined
    and the serial port is closed.
    """

    _PREAMBLE = b'\xB5\x62'
    _ACK_PREFIX = b'\xB5\x62\x05'          # UBX-ACK-ACK / UBX-ACK-NAK
    _MGA_ACK_PREFIX = b'\xB5\x62\x13\x60'  # UBX-MGA-ACK

    def __init__(self, device, baudrate=38400):
        """Open the serial device (a pyserial port name or URL) and start the reader thread."""
        super().__init__()
        self.ser = serial.serial_for_url(device, baudrate=baudrate, timeout=0.1)
        self.running = True
        self.buffer = b''
        self.queue = queue.Queue()
        self.expected_prefix = b''
        self.start()

    def _discard_until_preamble(self, min_bytes=0):
        """
        Drop bytes from the front of the buffer until it starts with the full
        UBX preamble (0xB5 0x62). At least min_bytes bytes are dropped.

        Searching for both preamble bytes means a stray 0xB5 that is not
        followed by 0x62 cannot stall the parser. A single 0xB5 at the very
        end of the buffer is kept, since its 0x62 may not have arrived yet.
        """
        idx = self.buffer.find(self._PREAMBLE, min_bytes)
        if idx == -1:
            last = len(self.buffer) - 1
            keep_tail = last >= min_bytes and self.buffer.endswith(self._PREAMBLE[:1])
            idx = last if keep_tail else len(self.buffer)
        self.buffer = self.buffer[idx:]

    def _process_buffer(self):
        """Extract and dispatch every complete UBX message currently in the buffer."""
        while True:
            self._discard_until_preamble()
            if len(self.buffer) < 8:
                return  # Header incomplete, wait for more data
            msg_len = int.from_bytes(self.buffer[4:6], "little") + 8
            if len(self.buffer) < msg_len:
                return  # Payload incomplete, wait for more data
            if validate_ubx_message(self.buffer) <= 0:
                # Corrupt frame: drop this preamble and resynchronize on the next one
                self._discard_until_preamble(min_bytes=1)
                continue
            msg, self.buffer = self.buffer[:msg_len], self.buffer[msg_len:]
            prefix = self.expected_prefix  # Read once; the caller thread may change it
            if prefix and msg.startswith(prefix):
                self.queue.put(msg)

    def run(self):
        """Thread method to read data from the serial port and process UBX messages."""
        try:
            while self.running:
                self.buffer += self.ser.read(1024)  # Read up to 1024 bytes
                self._process_buffer()
        except Exception as e:
            print(f"Error in UBXReceiver thread: {e}")
            import traceback
            traceback.print_exc()

    def stop(self):
        """Stop the UBXReceiver thread and close the serial port."""
        self.running = False
        self.join()  # Wait for the thread to finish
        if self.ser.is_open:
            self.ser.close()

    def __enter__(self):
        """Context manager entry method to start the UBXReceiver."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Context manager exit method to stop the UBXReceiver."""
        _ = exc_type, exc_value, traceback  # Unused
        self.stop()

    def _drain_queue(self):
        """Discard queued messages left over from earlier requests (late or duplicate responses)."""
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                return

    def _wait_for(self, accept, timeout):
        """
        Wait up to timeout seconds for a queued message for which accept(msg)
        is true. Messages that are not accepted are dropped.

        Returns:
            The accepted message, or None on timeout.
        """
        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            try:
                msg = self.queue.get(timeout=remaining)
            except queue.Empty:
                return None
            if accept(msg):
                return msg

    def _request(self, msg, prefix, accept, timeout):
        """
        Queue messages starting with prefix, send msg, and wait for a response
        for which accept(response) is true. Returns the response or None.
        """
        self.expected_prefix = prefix
        self._drain_queue()  # Don't return a stale response to this request
        self.ser.write(msg)
        resp = self._wait_for(accept, timeout)
        if resp is None:
            print("Error: No response received or timeout")
        return resp

    def send(self, msg):
        """Send the given bytes (usually a UBX message) to the receiver."""
        self.ser.write(msg)

    def send_and_wait(self, msg, prefix, timeout=SERIAL_TIMEOUT):
        """
        Send the given UBX message and wait for a response matching the prefix.

        Returns:
            The first response starting with prefix, or None on timeout.
        """
        return self._request(msg, prefix, lambda resp: resp.startswith(prefix), timeout)

    def poll_ubx_message(self, poll_msg):
        """Poll a UBX message from the receiver and return the message (None on timeout)."""
        expected = poll_msg[:4]  # Preamble + class + ID of the polled message
        return self.send_and_wait(poll_msg, expected)

    def send_acked(self, msg, timeout=SERIAL_TIMEOUT):
        """
        Send a UBX message and wait for the ACK/NAK that refers to it.

        Returns:
            True if UBX-ACK-ACK was received, False on NAK or timeout.
        """
        def is_ack_for_msg(resp):
            # UBX-ACK payload: class and ID of the acknowledged message
            return resp.startswith(self._ACK_PREFIX) and resp[6:8] == msg[2:4]

        ret = self._request(msg, self._ACK_PREFIX, is_ack_for_msg, timeout)
        if ret is None:
            return False  # Timeout
        if ret[3] != 0x01:  # ID 0x01 is ACK-ACK, anything else is a NAK
            print(f"Error: NAK received for message {msg.hex()}")
            return False
        return True

    def send_assistnow_data(self, data):
        """
        Send a sequence of MGA messages and wait for all MGA-ACKs.

        All messages are sent back to back. Whenever SERIAL_TIMEOUT seconds
        pass without an MGA-ACK, the still unacknowledged messages are resent,
        for at most three sending rounds in total.

        Returns:
            A tuple (number of sent messages, number of unacknowledged
            messages, number of unacknowledged MGA-ANO messages).

        Raises:
            RuntimeError: if data is not a sequence of valid UBX messages.
        """
        self.expected_prefix = self._MGA_ACK_PREFIX
        self._drain_queue()
        messages = list(split_ubx_messages(data))
        # An MGA-ACK identifies a message by its ID and first 4 payload bytes.
        # Several messages can share that key, so keep a FIFO list per key.
        pending = {}
        for msg in messages:
            pending.setdefault(msg[3:4] + msg[6:10], []).append(msg)
        for msg in messages:
            self.send(msg)  # Send each MGA message
        tries = 1
        while pending:
            try:
                ack = self.queue.get(timeout=SERIAL_TIMEOUT)  # Wait for an ACK
            except queue.Empty:
                if tries >= 3:
                    break
                tries += 1
                unacked = [msg for group in pending.values() for msg in group]
                print(f"Warning: Resending {len(unacked)} unacknowledged messages (try {tries})")
                for msg in unacked:
                    self.send(msg)  # Resend unacknowledged messages
                continue
            if not ack.startswith(self._MGA_ACK_PREFIX):
                continue
            # NOTE(review): the ACK's type/infoCode bytes are not inspected, so
            # an MGA-ACK reporting rejection still counts as acknowledged.
            key = ack[6+3:6+8]  # Acknowledged message ID + its first 4 payload bytes
            group = pending.get(key)
            if group is None:
                continue  # Duplicate ACK (e.g. after a resend) or unknown message
            group.pop(0)
            if not group:
                del pending[key]
        num_unacked = sum(len(group) for group in pending.values())
        # MGA-ANO messages have message ID 0x20
        num_unacked_ano = sum(len(group) for key, group in pending.items()
                              if key.startswith(b'\x20'))
        return (len(messages), num_unacked, num_unacked_ano)

class AssistNowError(Exception):
    """
    Expected, user-facing failure of the AssistNow example.

    main() reports this exception by printing only its message, so raise it
    where a stack trace would not help the user.
    """

def _fetch_assistnow_data(args, uniqid, monver):
    """
    Obtain an AssistNow chipcode via ZTP, then download the assistance data.

    Args:
        args: parsed command-line arguments (uses ztp_token and live).
        uniqid: raw UBX-SEC-UNIQID message polled from the receiver.
        monver: raw UBX-MON-VER message polled from the receiver.

    Returns:
        The body of the AssistNow response (the MGA messages to inject).

    Raises:
        AssistNowError: if an HTTP request fails or the ZTP response lacks
            the expected fields.
    """
    # While it is fine to issue a ZTP request every time before getting
    # the AssistNow data, the chipcode could also be cached, as long as the
    # application is able to repeat the ZTP request in cases where the
    # AssistNow data request fails.

    print("Sending ZTP request to obtain chipcode with payload:")
    params = {
        "token": str(args.ztp_token),
        "messages": {
            "UBX-SEC-UNIQID": uniqid.hex(),
            "UBX-MON-VER": monver.hex()
        }
    }
    print(json.dumps(params, indent=4))
    headers = {
        "Content-Type": "application/json"  # We are sending JSON data
    }
    try:
        ztp_response = requests.post(ZTP_ENDPOINT,
                                     data=json.dumps(params),
                                     headers=headers,
                                     timeout=HTTP_TIMEOUT)
    except requests.RequestException as e:
        raise AssistNowError(f"ZTP request failed: {e}") from e
    if ztp_response.status_code != 200:
        raise AssistNowError("ZTP request failed "
                             f"(status {ztp_response.status_code})\n"
                             f"{ztp_response.text}")
    try:
        ztp_data = ztp_response.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        raise AssistNowError(f"ZTP response is not valid JSON: {e}") from e

    if not isinstance(ztp_data, dict) or 'chipcode' not in ztp_data:
        raise AssistNowError("ZTP response does not contain chipcode")
    print("Received ZTP response with authorization chipcode")
    # Not printing the chipcode for security reasons
    if 'serviceUrl' not in ztp_data:
        raise AssistNowError("ZTP response does not contain serviceUrl")

    print(f"Available data types for this device: {ztp_data.get('allowedData')}")

    # Now that we have the chipcode, we can request the AssistNow data.

    if args.live:  # live orbits + almanac
        data_types = 'ulorb_l1,ukion,usvht,ualm'
        print(f"Requesting live orbits and almanac data: {data_types}")
    else:  # predictive orbits + almanac
        data_types = 'uporb_1,ualm'
        print(f"Requesting predictive orbits and almanac data: {data_types}")

    # Using the url and chipcode returned by ZTP, request assistance data
    params = {
        'chipcode': ztp_data['chipcode'],
        'data': data_types,
        'gnss': 'gps,gal,glo,bds,qzss',
    }
    print(f"AssistNow service URL: {ztp_data['serviceUrl']}")
    try:
        an_response = requests.get(ztp_data['serviceUrl'],
                                   params=params,
                                   timeout=HTTP_TIMEOUT)
    except requests.RequestException as e:
        # Report only the exception type: its message may contain the request
        # URL, which carries the chipcode as a query parameter.
        raise AssistNowError(f"AssistNow request failed ({type(e).__name__})") from e
    if an_response.status_code != 200:
        raise AssistNowError("AssistNow request failed "
                             f"(status {an_response.status_code})\n"
                             f"{an_response.text}")
    print("AssistNow data received successfully")
    return an_response.content


def run_assistnow(args):
    """
    Run the AssistNow predictive/live orbit client example.

    Identifies the receiver (UBX-SEC-UNIQID, UBX-MON-VER). Unless
    args.use_assist is false, it fetches AssistNow data via ZTP. It then
    cold-starts the receiver, injects the data, and polls UBX-NAV-STATUS for
    up to 60 s to report the time to first fix.

    Args:
        args: parsed command-line arguments (port, baudrate, ztp_token, live,
            use_assist).

    Raises:
        AssistNowError: on failures the user can act on (receiver not
            responding, HTTP or ZTP errors).
    """
    with UBXReceiver(args.port, args.baudrate) as gnss:
        assist_data = b''

        print("Polling UBX-SEC-UNIQID from receiver")
        uniqid = gnss.poll_ubx_message(CMD_POLL_UNIQID)
        if not uniqid:
            raise AssistNowError("Failed to retrieve UBX-SEC-UNIQID")
        # The unique ID runs from payload offset 4 up to the checksum; slicing
        # to -2 avoids assuming a fixed ID length (it differs between message
        # versions) and never includes the checksum bytes.
        print(f"UBX-SEC-UNIQID: {uniqid[6+4:-2].hex()}")

        print("Polling UBX-MON-VER from receiver")
        monver = gnss.poll_ubx_message(CMD_POLL_MONVER)
        if not monver:
            raise AssistNowError("Failed to retrieve UBX-MON-VER")
        print("UBX-MON-VER:")
        # The payload is a series of NUL-padded text fields
        text = monver[6:-2].decode('utf-8', errors='replace').strip()
        for field in re.split(r'\0+', text):
            if field:  # Skip empty fields
                print(f"  {field}")

        if args.use_assist:
            assist_data = _fetch_assistnow_data(args, uniqid, monver)

        # Cold boot the receiver to clear ephemeris, predictive orbits, almanac
        print("Resetting the receiver to simulate a cold start")
        gnss.send(CMD_RESET)

        reset_time = time.monotonic()

        # Make sure the receiver is ready to receive commands by polling UNIQID
        print("Waiting for receiver to be ready after reset...")
        for _ in range(10):
            if gnss.poll_ubx_message(CMD_POLL_UNIQID):
                break
        else:
            raise AssistNowError("Receiver not responding after reset")

        if args.use_assist:
            # Enable ACKAIDING in RAM so that each MGA message is answered by
            # an MGA-ACK. The receiver answered a poll above, so it is ready.
            print("Enabling CFG-NAVSPG-ACKAIDING")
            if not gnss.send_acked(CFG_RAM_ACKAIDING, timeout=1):
                raise AssistNowError("Failed to enable CFG-NAVSPG-ACKAIDING")

            print(f"Sending {len(assist_data)} bytes of AssistNow messages to the receiver...")

            # AssistNow data is split into individual MGA messages, each of
            # which is sent to the receiver, then we wait for all the
            # corresponding MGA-ACK messages.
            num_sent, num_unacked, num_unacked_ano = gnss.send_assistnow_data(assist_data)
            if num_unacked > 0:
                print(f"Warning: {num_unacked}/{num_sent} MGA messages not acknowledged")
                if num_unacked_ano > 0:
                    print(f"Error: {num_unacked_ano} predictive orbit messages not acknowledged")
                    print("  Please make sure the receiver supports predictive orbits, or")
                    print("  try using live orbits instead of predictive orbits.")
            else:
                print(f"All {num_sent} MGA messages acknowledged.")

        # Monotonic clock: durations must not be affected by wall-clock changes
        mga_end_time = time.monotonic()
        mga_duration_ms = int(1000 * (mga_end_time - reset_time))

        print("Waiting for UBX-NAV-STATUS to report time to first fix...")
        while time.monotonic() - mga_end_time < 60:
            nav_status = gnss.poll_ubx_message(CMD_POLL_NAV_STATUS)
            if not nav_status:
                print("Error: Failed to poll UBX-NAV-STATUS")
                return
            # Payload offset 8: time to first fix in ms (0 while there is no fix)
            ttff = int.from_bytes(nav_status[6+8:6+12], "little")
            if ttff > 0:
                if args.use_assist:
                    print(f"Time to inject AssistNow data:       {mga_duration_ms:5d} ms")
                print(f"Time from reset to first fix (TTFF): {ttff:5d} ms")
                break
            time.sleep(1)  # Wait for a while before polling again
        else:
            print("Warning: No fix after 60 seconds, check antenna and sky visibility")


def main():
    """
    Parse the command line and run the AssistNow example.

    --ztp_token is required unless --no_assist is given. The process exits
    with status 1 after an AssistNowError and with status 130 on Ctrl-C.
    """
    ap = argparse.ArgumentParser(description="AssistNow predictive/live orbit client example")
    ap.add_argument("-P", "--port", type=str, required=True,
                    help="Serial port to connect to (e.g. COM1 or /dev/ttyUSB0)")
    ap.add_argument("-B", "--baudrate", type=int, default=38400,
                    help="Baud rate for serial communication (default: 38400)")
    ap.add_argument("-z", "--ztp_token", type=uuid.UUID,
                    help="ZTP device profile token (required unless --no_assist is given)")
    ap.add_argument("-v", "--version", action='version', version=f"%(prog)s {__version__}",
                    help="Show version and exit")

    data_group = ap.add_argument_group(title="AssistNow Data Options")
    excl_group = data_group.add_mutually_exclusive_group(required=False)
    excl_group.add_argument("-p", "--predictive", action='store_false', dest='live', default=False,
                    help="Obtain predictive orbits from AssistNow (default)")
    excl_group.add_argument("-l", "--live", action='store_true',
                    help="Obtain live orbits instead of predictive orbits")
    excl_group.add_argument("-n", "--no_assist", action='store_false', dest='use_assist',
                    help="Measure TTFF without any AssistNow data")
    args = ap.parse_args()
    # The token is only used for the ZTP request, which --no_assist skips
    if args.use_assist and args.ztp_token is None:
        ap.error("the following arguments are required: -z/--ztp_token")

    try:
        run_assistnow(args)
    except KeyboardInterrupt:
        print("Interrupted by user, exiting...")
        raise SystemExit(130)  # Conventional status for termination by SIGINT
    except AssistNowError as e:
        print(f"Error: {e}")
        raise SystemExit(1)


if __name__ == "__main__":
    main()
