#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# License: GNU General Public License v2
# Author: thl-cmk[at]outlook[dot]com
# URL   : https://thl-cmk.hopto.org
# Date  : 2023-10-12
# File  : create_topology_utils.py

from os import environ
from json import dumps
from json import loads as json_loads
import shutil
import socket
from ast import literal_eval
from time import time as now_time
from tomllib import loads as toml_loads
from tomllib import TOMLDecodeError
from re import match as re_match
from pathlib import Path
from typing import List, Dict, Any
from enum import Enum, unique


@unique
class ExitCodes(Enum):
    """Exit status values of the topology tools; use ExitCodes.<NAME>.value as process exit code.

    @unique guarantees that no two members share the same value.
    """
    OK = 0  # successful run
    # NOTE(review): not used in this module; presumably signals an invalid command line option list.
    BAD_OPTION_LIST = 1
    # The user data TOML file could not be parsed (see get_data_from_toml).
    BAD_TOML_FORMAT = 3


# constants
# Tool version string.
# NOTE(review): the date part "202310128" has nine digits while the file header says 2023-10-12;
# looks like a typo - confirm before changing (the value may be shown to users or compared elsewhere).
CREATE_TOPOLOGY_VERSION = '0.1.0-202310128'
# NOTE(review): presumably the install location of the main script; "~" is not expanded here.
SCRIPT = '~/local/bin/network-topology/create_topology_data.py'
# Example seed host names, space separated (presumably for help/sample output - confirm).
SAMPLE_SEEDS = 'Core01 Core02'
# Inventory node paths, comma separated - the format get_table_from_inventory() splits on ','.
PATH_CDP = 'networking,cdp_cache'
PATH_LLDP = 'networking,lldp_cache'
PATH_INTERFACES = 'networking,interfaces'
# Labels of the CDP / LLDP layers. NOTE(review): their consumer is outside this view.
LABEL_CDP = 'inv_CDP'
LABEL_LLDP = 'inv_LLDP'
# Comma separated column names read from the LLDP / CDP inventory tables
# (presumably: neighbour name, local port, neighbour port - verify against the caller).
COLUMNS_LLDP = 'system_name,local_port_num,port_id'
COLUMNS_CDP = 'device_id,local_port,device_port'
# Default file name of the TOML user data file (assumption: passed to get_data_from_toml()).
USER_DATA_FILE = 'create_topology_data.toml'
# NOTE(review): key name for cached interface items; usage is outside this view.
CACHE_INTERFACES_ITEM = 'interface_items'
# Settings for the CDP and LLDP layers: inventory path, table columns and label.
LAYERS = {
    'CDP': {'path': PATH_CDP, 'columns': COLUMNS_CDP, 'label': LABEL_CDP},
    'LLDP': {'path': PATH_LLDP, 'columns': COLUMNS_LLDP, 'label': LABEL_LLDP},
}
# Site root directory, used for the Livestatus socket path and the rm_tree() safety check.
# Read at import time: importing this module raises KeyError if OMD_ROOT is not set in the
# environment (normally set inside a Checkmk/OMD site - assumption).
OMD_ROOT = environ["OMD_ROOT"]


def get_data_form_live_status(query: str):
    """
    Send a query to the site's Livestatus UNIX socket and return the parsed reply.
    Args:
        query: the complete Livestatus query. NOTE(review): the reply is parsed with
               ast.literal_eval, so the query presumably requests Python output format - confirm.

    Returns:
        the reply parsed by ast.literal_eval, or None if the reply is empty or equals [[b'']]
    """
    address = f'{OMD_ROOT}/tmp/run/live'
    # "with" closes the socket even if connect/send/recv raise.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        sock.connect(address)
        sock.sendall(query.encode())
        # Half-close: the peer sees EOF after the query.
        sock.shutdown(socket.SHUT_WR)
        chunks = []
        while chunk := sock.recv(4096):  # b'' means the peer closed the connection
            chunks.append(chunk)
    # Decode once after joining: a multi-byte UTF-8 character may be split across two recv() calls.
    reply = b''.join(chunks).decode().strip()
    if not reply:
        # literal_eval('') would raise SyntaxError; an empty reply simply means "no data".
        return None
    reply = literal_eval(reply)
    if reply != [[b'']]:
        return reply
    return None


def get_data_from_toml(file: str, debug: bool = False) -> Dict:
    data = {}
    toml_file = Path(file)
    if toml_file.exists():
        try:
            data = toml_loads(toml_file.read_text())
        except TOMLDecodeError as e:
            print(
                f'ERROR: data file {toml_file} is not in valid TOML format! ({e}), (see https://toml.io/en/)')
            exit(code=ExitCodes.BAD_TOML_FORMAT.value)
    else:
        print(f'WARNING: User data {file} not found.')
    if debug:
        print(f'TOML file read: {file}')
        print(f'Data from toml file: {data}')
    return data


def rm_tree(root: Path):
    """
    Recursively delete the directory root.

    Safety: root must be located below "$OMD_ROOT/var/topology_data" (checked on the resolved
    paths) and must not itself be a symlink; otherwise a warning is printed and nothing is deleted.

    Args:
        root: the directory to delete

    Returns:
        None
    """
    safe_base = Path(f'{OMD_ROOT}/var/topology_data').resolve()
    # A plain string prefix test would accept ".../topology_data/../.." and sibling directories
    # such as ".../topology_data_old"; comparing resolved paths does not. A symlinked root is
    # refused because deleting "through" it would remove content outside the safe base.
    if root.is_symlink() or not root.resolve().is_relative_to(safe_base):
        print(f"WARNING: bad path to remove, {str(root)}, don\'t delete it.")
        return
    # shutil.rmtree unlinks symlinks found inside the tree instead of following them.
    shutil.rmtree(root)


def remove_old_data(keep: int, min_age: int, path: str, protected: List[str], debug: bool = False):
    """
    Delete the oldest topology directories below path so that at most `keep` topologies remain.

    The entry "default" and, if it is a symlink, the directory it points to are never deleted;
    "default" counts against `keep`. Protected topologies are never deleted. The remaining
    directories are ordered by st_ctime (oldest first); deletion stops at the first directory
    that is younger than min_age days.

    Args:
        keep: number of topologies to keep (including "default")
        min_age: minimum age in days before a topology may be deleted
        path: the directory that contains the topology directories
        protected: names (relative to path) of topologies that must not be deleted
        debug: optional. If True, details are printed when a protected topology is not found

    Returns:
        None
    """
    path = Path(path)
    default_topo = path.joinpath('default')
    directories = [str(directory) for directory in path.iterdir()]
    # keep default topology
    if str(default_topo) in directories:
        directories.remove(str(default_topo))
        keep -= 1
        if default_topo.is_symlink():
            link_target = default_topo.readlink()
            # A relative link target is relative to the directory that contains the link.
            if not link_target.is_absolute():
                link_target = path.joinpath(link_target)
            try:
                directories.remove(str(link_target))
            except ValueError:
                pass

    # keep protected topologies
    for directory in protected:
        try:
            directories.remove(str(path.joinpath(directory)))
        except ValueError as e:
            if debug:
                print(e)
                print(directory)
                print(str(path.joinpath(directory)))
                print(directories)
        else:
            print(f'Protected topology: {directory}, will not be deleted.')

    # Nothing to delete if we are already within the limit.
    if len(directories) <= keep:
        return

    # (ctime, directory) pairs, oldest first. A list instead of a dict keyed by ctime, so
    # directories with an identical ctime are not silently dropped.
    topo_by_age = sorted(
        (Path(directory).stat().st_ctime, directory)
        for directory in directories
        if Path(directory).is_dir()
    )

    # keep may be <= 0 here (e.g. keep=0 plus a default topology); never index past the list.
    excess = max(len(topo_by_age) - max(keep, 0), 0)
    for ctime, directory in topo_by_age[:excess]:
        if min_age * 86400 > now_time() - ctime:
            print(f'Topology "{Path(directory).name}" not older then {min_age} day(s). not deleted.')
            return
        print(f'delete old topology: {directory}')
        rm_tree(Path(directory))


def save_data_to_file(data: Dict, path: str, file: str, make_default: bool):
    """
    Write data as JSON to "<path>/<file>", creating missing directories on the way.

    Args:
        data: the topology data (anything json.dumps can serialize)
        path: directory the JSON file is written to
        file: name of the JSON file inside path
        make_default: if True, (re)create the symlink "default" next to path, pointing to path

    Returns:
        None
    """
    target_file = Path(f'{path}/{file}')
    target_file.parent.mkdir(parents=True, exist_ok=True)
    target_file.write_text(dumps(data))

    if not make_default:
        return

    default_link = Path(f'{Path(f"{path}").parent}/default')
    # Drop a previous link first; missing_ok covers the very first run.
    default_link.unlink(missing_ok=True)
    default_link.symlink_to(target=Path(path), target_is_directory=True)


def save_topology(
        data: dict,
        base_directory: str,
        output_directory: str,
        dont_compare: bool,
        make_default: bool,
        topology_file_name: str,
):
    """
    Save the topology data as JSON in base_directory/output_directory.

    Unless dont_compare is set, the data is first checked against the file
    base_directory/default/<topology_file_name> with is_equal_with_default(); if that reports
    a match, nothing is written and a notice is printed instead.

    Args:
        data: the topology data
        base_directory: directory that holds all topologies (and the "default" link)
        output_directory: sub directory of base_directory for this topology
        dont_compare: if True, save without comparing against the default topology
        make_default: passed on to save_data_to_file()
        topology_file_name: file name of the topology JSON file

    Returns:
        None
    """
    unchanged = not dont_compare and is_equal_with_default(
        data=data,
        file=f'{base_directory}/default/{topology_file_name}',
    )
    if unchanged:
        print(
            'Topology matches default topology, not saved! Use "--dont-compare" to save identical topologies.'
        )
        return

    save_data_to_file(
        data=data,
        path=f'{base_directory}/{output_directory}',
        file=topology_file_name,
        make_default=make_default,
    )


def is_mac_address(mac_address: str, debug: bool = False) -> bool:
    """
    Checks if mac_address is a valid MAC address. Will only accept MAC address in the form "AA:BB:CC:DD:EE:FF"
    (lower case is also ok).
    Args:
        mac_address: the MAC address to check
        debug: optional. If True the result of the function will be printed to stdout

    Returns:
        True if mac_address is a valid MAC address
        False if mac_address not a valid MAC address
    """
    # Hex digits only. re.match anchors at the start only, so \Z is required to reject
    # trailing text (including a trailing newline, which "$" would accept).
    re_mac_pattern = r'([0-9A-F]{2}:){5}[0-9A-F]{2}\Z'
    is_mac = re_match(re_mac_pattern, mac_address.upper()) is not None
    if debug:
        print(f'mac: {mac_address}, {"match" if is_mac else "no match"}')
    return is_mac


def is_list_of_str_equal(list1: List[str], list2: List[str]) -> bool:
    """
    Compare two lists of strings regardless of their order.

    The lists are compared as sorted copies; the arguments themselves are not modified.
    Duplicates count, so ['a'] and ['a', 'a'] are not equal.

    Args:
        list1: first list of strings
        list2: second list of strings

    Returns:
        True if both lists contain the same items the same number of times
        False if they don't match
    """
    return sorted(list1) == sorted(list2)


def merge_topologies(topo_pri: Dict, topo_sec: Dict) -> Dict:
    """
    Merge the topology topo_pri into topo_sec.

    Hosts that exist only in topo_pri are copied over as they are. For hosts present in both,
    the "connections" of topo_pri are added to topo_sec (same-named entries of topo_pri win),
    and the "interfaces" lists are combined without duplicates.

    Note: topo_sec is modified in place and topo_pri is emptied (every host is popped from it).

    Args:
        topo_pri: topology whose data takes precedence; emptied by this function
        topo_sec: topology the data of topo_pri is merged into

    Returns:
        Dict: topo_sec, now containing the merged data of topo_sec and topo_pri
    """
    for key in list(topo_pri):  # copy of the keys: topo_pri shrinks while we iterate
        pri_entry = topo_pri.pop(key)
        if key not in topo_sec:
            topo_sec[key] = pri_entry
            continue
        sec_entry = topo_sec[key]
        # setdefault/get tolerate entries that lack "connections" or "interfaces".
        sec_entry.setdefault('connections', {}).update(pri_entry.get('connections', {}))
        # dict.fromkeys de-duplicates while keeping a stable order (set order varies per run).
        sec_entry['interfaces'] = list(dict.fromkeys(
            sec_entry.get('interfaces', []) + pri_entry.get('interfaces', [])
        ))
    return topo_sec


def compare_dicts(dict1: Dict, dict2: Dict) -> bool:
    """
    Deep-compare two dicts, ignoring the order of list items.

    Nested dicts are compared recursively, lists with is_list_of_str_equal() (order-insensitive),
    and all other values by type and equality.

    Args:
        dict1: first dict
        dict2: second dict

    Returns:
        True if both dicts have the same keys and matching values
        False otherwise
    """
    # Key views compare like sets: same keys regardless of order, no sorting needed.
    if dict1.keys() != dict2.keys():
        return False

    for key, value in dict1.items():
        other = dict2[key]
        if isinstance(value, dict):
            if not isinstance(other, dict) or not compare_dicts(value, other):
                return False
        elif isinstance(value, list):
            if not isinstance(other, list) or not is_list_of_str_equal(value, other):
                return False
        # Scalars (str, int, float, bool, None): same type and equal value. The type check
        # keeps e.g. True and 1 apart, which compare equal in Python.
        elif type(value) is not type(other) or value != other:
            return False

    return True


def is_equal_with_default(data: Dict, file: str) -> bool:
    """
    Check whether data matches the topology stored in the JSON file `file`.

    Args:
        data: the topology data to check
        file: path of the JSON file holding the default topology

    Returns:
        True if the file exists, holds a valid JSON object and compare_dicts() reports it equal
        to data; False otherwise (missing, unreadable or invalid file, or different data)
    """
    default_file = Path(file)
    if not default_file.exists():
        return False
    try:
        # Topology files are written with json.dumps (see save_data_to_file), so parse them
        # as JSON: ast.literal_eval rejects JSON literals such as true, false, null and NaN.
        default_data = json_loads(default_file.read_text(encoding='utf-8'))
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        return False
    if not isinstance(default_data, dict):
        return False
    return compare_dicts(data, default_data)


def get_table_from_inventory(inventory: Dict[str, Any], path: List[str]) -> List | None:
    path = ('Nodes,' + ',Nodes,'.join(path.split(',')) + ',Table,Rows').split(',')
    table = inventory.copy()
    for m in path:
        try:
            table = table[m]
        except KeyError:
            return
    return table