"""
generate_spartnkey.py

This script generates a UBX-RXM-SPARTNKEY message used to transfer dynamic SPARTN keys to a u-blox receiver. 
The UBX-RXM-SPARTNKEY message is structured as follows:

- Header: Sync characters, message class, and ID.
- Length: Specifies the length of the payload.
- Payload: Contains the message content, including version, number of keys, reserved bytes, and key data.
- Checksum: Ensures data integrity.

The payload consists of several parts:
1. Version: Message version (typically 0x01).
2. numKeys: Number of keys in the message (0, 1, or 2).
3. reserved0: Reserved bytes.
4. Repeated group for each key: reserved1, keyLengthBytes, validFromWno (GPS week number), and validFromTow (GPS time of week).
5. Key values: the raw bytes of each key, appended after all of the repeated groups and in the same order.

The script reads a JSON configuration file to retrieve the 'current' and 'next' SPARTN keys. Each key's data is then processed to form its respective payload segment. The final UBX message is constructed by concatenating the header, calculated payload length, payload, and checksum. The message is printed as a space-separated hexadecimal string and then written to the configured serial port.
"""


import json
import struct
import datetime

import serial


# Constants - EDIT THESE
JSON_FILE_PATH = r"C:\Users\Jon\response.json"  # JSON file holding the SPARTN dynamic keys
PORT = "COM71"  # Serial port the u-blox receiver is connected to
BAUD = 38400  # Baud rate as an int, which is the type pyserial documents for `baudrate`

# DON'T TOUCH ANYTHING BELOW THIS LINE

# Start of GPS time (week 0, time of week 0), expressed as an aware UTC datetime.
GPS_EPOCH = datetime.datetime(1980, 1, 6, 0, 0, 0, tzinfo=datetime.timezone.utc)
# GPS time runs ahead of UTC by this many seconds (18 s since 1 January 2017).
# It must be updated by hand if a new leap second is introduced.
LEAP_SECONDS = 18

def read_json_file(file_path):
    """
    Load and return the JSON document stored at ``file_path``.

    Args:
        file_path (str): Path of the UTF-8 encoded JSON file holding the key data.

    Returns:
        The decoded JSON value (the caller expects a dict of key data).
    """
    with open(file_path, mode="r", encoding="utf-8") as json_file:
        document = json.load(json_file)
    return document


def calculate_gps_time_of_week(timestamp_ms):
    """
    Return the GPS time of week, in whole seconds, for a Unix millisecond timestamp.

    The UTC-based timestamp is shifted forward by ``LEAP_SECONDS`` to obtain GPS
    time. The elapsed time since ``GPS_EPOCH`` is then reduced modulo one week
    (604800 s), and any fractional second is truncated.

    Args:
        timestamp_ms (int): Unix time in milliseconds (UTC-based).

    Returns:
        int: Seconds into the current GPS week (0-604799).
    """
    gps_instant = datetime.datetime.fromtimestamp(
        timestamp_ms / 1000.0 + LEAP_SECONDS, datetime.timezone.utc
    )
    seconds_since_epoch = (gps_instant - GPS_EPOCH).total_seconds()
    return int(seconds_since_epoch % 604800)


def calculate_gps_week_number(timestamp_ms):
    """
    Return the GPS week number for a Unix millisecond timestamp.

    The UTC-based timestamp is shifted forward by ``LEAP_SECONDS`` to obtain GPS
    time. The result is the number of whole weeks elapsed since ``GPS_EPOCH``.

    Args:
        timestamp_ms (int): Unix time in milliseconds (UTC-based).

    Returns:
        int: The (non-rolled-over) GPS week number.
    """
    gps_instant = datetime.datetime.fromtimestamp(
        timestamp_ms / 1000.0 + LEAP_SECONDS, datetime.timezone.utc
    )
    one_week = datetime.timedelta(weeks=1)
    # timedelta // timedelta floors to an int. This matches floor(days / 7)
    # because the elapsed days are themselves already floored.
    return (gps_instant - GPS_EPOCH) // one_week

def create_key_payload(key_data):
    """
    Build the 24-byte binary segment describing one SPARTN key.

    Layout (little-endian): reserved1 (U1, always 0), keyLengthBytes (U1, 16),
    validFromWno (U2), validFromTow (U4), followed by the 16 raw key bytes.

    Args:
        key_data (dict): Must provide ``"start"`` (Unix time in ms) and
            ``"value"`` (the key as a hex string).

    Returns:
        bytes: The 8-byte metadata group followed by the 16-byte key.

    Raises:
        ValueError: If the decoded key is not exactly 16 bytes long, or if
            ``"value"`` is not valid hex.
    """
    start_ms = key_data["start"]
    wno = calculate_gps_week_number(start_ms)
    tow = calculate_gps_time_of_week(start_ms)
    secret = bytes.fromhex(key_data["value"])

    if len(secret) == 16:
        metadata = struct.pack("<BBHI", 0, 16, wno, tow)
        return metadata + secret
    raise ValueError("Key length is not 16 bytes")


def calculate_checksum(data):
    """
    Compute the UBX 8-bit Fletcher checksum of ``data``.

    CK_A accumulates each byte and CK_B accumulates the running value of CK_A.
    Both sums wrap at 8 bits.

    Args:
        data (bytes): The bytes to checksum (for a UBX frame: class, ID,
            length and payload).

    Returns:
        tuple: ``(ck_a, ck_b)`` as integers in the range 0-255.
    """
    running_sum, sum_of_sums = 0, 0
    for octet in data:
        running_sum = (running_sum + octet) & 0xFF
        sum_of_sums = (sum_of_sums + running_sum) & 0xFF
    return (running_sum, sum_of_sums)


# Size of the per-key metadata group packed by create_key_payload:
# reserved1 (U1), keyLengthBytes (U1), validFromWno (U2), validFromTow (U4).
_KEY_GROUP_SIZE = struct.calcsize("<BBHI")


def _extract_dynamic_keys(json_data):
    """
    Return the ``(current, next)`` dynamic-key dicts found in parsed JSON data.

    Two layouts are recognised:
      * ``{"MQTT": {"dynamickeys": {...}}}``
      * ``{"pointPerfectThingCreds": {<entry>: {"dynamickeys": {...}}}}``,
        where only the first entry is used.

    Either returned value may be ``None`` if the corresponding key is absent.

    Raises:
        ValueError: If neither layout is present, or the credentials mapping is empty.
    """
    if "MQTT" in json_data and "dynamickeys" in json_data["MQTT"]:
        dynamic_keys = json_data["MQTT"]["dynamickeys"]
    elif "pointPerfectThingCreds" in json_data:
        creds = json_data["pointPerfectThingCreds"]
        first_entry = next(iter(creds), None)
        if first_entry is None:
            # Report an explicit error instead of letting StopIteration escape.
            raise ValueError("No entries found under 'pointPerfectThingCreds'. Please verify your JSON structure.")
        dynamic_keys = creds[first_entry]["dynamickeys"]
    else:
        raise ValueError("Unrecognized JSON format. Please verify your JSON structure.")
    return dynamic_keys.get("current"), dynamic_keys.get("next")


def _require_key_fields(key, label):
    """Raise ValueError unless ``key`` is a non-empty dict with 'start' and 'value'."""
    if not key or "start" not in key or "value" not in key:
        raise ValueError(f"Missing 'start' and/or 'value' in the {label} key. Check your JSON file.")


def _build_spartnkey_message(key_payloads):
    """
    Assemble a complete UBX-RXM-SPARTNKEY frame from per-key payloads.

    Each element of ``key_payloads`` is the output of create_key_payload: an
    8-byte metadata group followed by the key bytes. The UBX payload holds, in
    order: version (1), numKeys, two reserved bytes, every metadata group, and
    then every key value.

    Returns:
        bytes: Sync chars, class/ID, length, payload and checksum.
    """
    header = b"\xb5\x62\x02\x36"  # sync chars 0xB5 0x62, class 0x02, ID 0x36
    payload = struct.pack("<BB2s", 1, len(key_payloads), b"\x00\x00")
    payload += b"".join(p[:_KEY_GROUP_SIZE] for p in key_payloads)
    payload += b"".join(p[_KEY_GROUP_SIZE:] for p in key_payloads)
    # Derive the length from the payload actually built, not a hard-coded value.
    length_bytes = struct.pack("<H", len(payload))
    # The checksum covers class, ID, length and payload, but not the sync chars.
    ck_a, ck_b = calculate_checksum(header[2:] + length_bytes + payload)
    return header + length_bytes + payload + bytes([ck_a, ck_b])


def main():
    """
    Main function for generating and sending the UBX-RXM-SPARTNKEY message.

    Steps:
      1. Load JSON key data from JSON_FILE_PATH.
      2. Extract 'current' and 'next' keys (MQTT or pointPerfectThingCreds format).
      3. Create binary payloads for the keys.
      4. Construct the UBX frame (header, payload length, payload, checksum).
      5. Print the frame in hex and write it to PORT at BAUD.

    Raises:
        ValueError: If the JSON layout is unrecognised, a key lacks 'start' or
            'value', or a key is not 16 bytes of valid hex.
    """
    json_data = read_json_file(JSON_FILE_PATH)
    if json_data is None:
        print("Failed to load JSON data.")
        return

    current_key, next_key = _extract_dynamic_keys(json_data)
    _require_key_fields(current_key, "current")
    _require_key_fields(next_key, "next")

    ubx_message = _build_spartnkey_message(
        [create_key_payload(current_key), create_key_payload(next_key)]
    )
    print(" ".join(f"{byte:02x}" for byte in ubx_message))

    # Write the bytes directly; there is no need to round-trip through the hex string.
    try:
        with serial.Serial(PORT, BAUD, timeout=1) as ser:
            ser.write(ubx_message)
            # Wait until the output buffer is drained before the port is closed.
            ser.flush()
        print("Message sent successfully.")
    except serial.SerialException as e:
        print(f"Serial communication error: {e}")

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
