from geopy.distance import geodesic
import mysql.connector
import json
import random
import numpy as np
from sklearn.cluster import DBSCAN
from shapely.geometry import Point, LineString, MultiPoint, Polygon
from shapely.ops import transform
from pyproj import Proj, transform as pyproj_transform
import numpy as np
from scipy.spatial.distance import cdist
from itertools import combinations

def generate_random_color():
    """Return a random colour as a lowercase ``#rrggbb`` hex string."""
    # Draw the red, green and blue channels in that order, one byte each.
    channels = [random.randint(0, 255) for _ in range(3)]
    return "#" + "".join(f"{value:02x}" for value in channels)



def haversine(lat1, lon1, lat2, lon2):
    """
    Great-circle distance in kilometres between two points given in degrees.

    All four arguments go through NumPy ufuncs, so scalars and arrays are
    both accepted and array arguments broadcast against each other.
    """
    earth_radius_km = 6371

    # Work in radians from here on.
    phi1, lam1, phi2, lam2 = (np.radians(v) for v in (lat1, lon1, lat2, lon2))

    half_dphi = (phi2 - phi1) / 2
    half_dlam = (lam2 - lam1) / 2

    # Haversine of the central angle between the two points.
    h = np.sin(half_dphi) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(half_dlam) ** 2
    central_angle = 2 * np.arctan2(np.sqrt(h), np.sqrt(1 - h))
    return earth_radius_km * central_angle

def haversine_distance_matrix(coordinates):
    """
    Compute the full pairwise haversine distance matrix for a set of points.

    Args:
        coordinates: Sequence of (latitude, longitude) pairs in degrees.

    Returns:
        numpy.ndarray of shape (n, n) where entry [i, j] is the value returned
        by ``haversine`` for points i and j. An empty input yields an empty
        (0, 0) array.
    """
    coords = np.asarray(coordinates, dtype=float)
    if coords.size == 0:
        # Nothing to compare; avoids indexing a 1-D empty array below.
        return np.empty((0, 0))

    # Keep the values in degrees: haversine() performs the degree -> radian
    # conversion itself, so converting here as well would apply it twice and
    # shrink every distance.
    lat = coords[:, 0][:, np.newaxis]
    lon = coords[:, 1][:, np.newaxis]

    # Column vector vs. row vector broadcasts to the (n, n) matrix.
    return haversine(lat, lon, lat.T, lon.T)

def max_distance_in_group(coordinates):
    """
    Calculate the maximum distance (diameter) between two points in a group.

    Args:
        coordinates: Iterable of (longitude, latitude) sequences in degrees.
            Entries with fewer than two elements are ignored; elements past
            the second are dropped.

    Returns:
        The largest pairwise geodesic distance in kilometres, or 0 when fewer
        than two valid points are supplied.
    """
    # Callers provide (lon, lat), but geopy's geodesic() expects (lat, lon).
    # Swap once here so the distance is measured between the intended points.
    lat_lon_points = [(coord[1], coord[0]) for coord in coordinates if len(coord) >= 2]

    if len(lat_lon_points) < 2:
        return 0  # No distance for a single point or no valid points

    # Stream the pairwise distances instead of materialising the whole list.
    return max(
        geodesic(point_a, point_b).km
        for point_a, point_b in combinations(lat_lon_points, 2)
    )


def _make_group(group_id, members, total_sales):
    """Build the output record for one finished subgroup of clients."""
    return {
        "group_id": group_id,
        "clients": members,
        "total_sales": total_sales,
        "color": generate_random_color(),
        "coordinates": [(c["longitude"], c["latitude"]) for c in members],
    }


def split_group_by_sales_with_dbscan(groups, max_chiffre, max_diameter_km, eps=0.5, min_samples=1):
    """
    Divide groups into subgroups based on max_chiffre while using DBSCAN to
    group nearby positions and respect a maximum diameter.

    Each input cluster is first re-clustered with DBSCAN (haversine metric).
    Every resulting sub-cluster is then walked in order, greedily adding
    clients to the current subgroup until adding the next one would exceed
    ``max_chiffre`` or ``max_diameter_km``; at that point the subgroup is
    emitted and a new one is started with that client. A client that exceeds
    the limits on its own still ends up in a (single-member) subgroup.

    Args:
        groups (dict): Mapping whose values are dicts with a "clients" list.
            Each client is a dict with "latitude", "longitude" (degrees) and
            "chiffre_vente".
        max_chiffre (float): Maximum total sales per group.
        max_diameter_km (float): Maximum allowable diameter for a group.
        eps (float): DBSCAN epsilon parameter (in kilometers).
        min_samples (int): Minimum number of samples per cluster for DBSCAN.
            NOTE: with values above 1, DBSCAN labels noise points -1 and they
            are all processed together as one sub-cluster.

    Returns:
        list: Group records with keys "group_id" (1-based, in creation order),
        "clients", "total_sales", "color" and "coordinates" (list of
        (longitude, latitude) tuples).
    """
    final_groups = []

    for cluster in groups.values():
        clients = cluster["clients"]
        if not clients:
            # DBSCAN cannot be fitted on zero samples.
            continue

        # scikit-learn's haversine metric expects [latitude, longitude] in
        # radians, so build the array in that order. eps is divided by the
        # Earth radius (km) to express it as an angle.
        lat_lon_rad = np.radians([(c["latitude"], c["longitude"]) for c in clients])
        dbscan = DBSCAN(eps=eps / 6371.0, min_samples=min_samples, metric="haversine")
        labels = dbscan.fit_predict(lat_lon_rad)

        # Bucket clients by DBSCAN label, preserving first-seen label order.
        sub_clusters = {}
        for label, client in zip(labels, clients):
            sub_clusters.setdefault(label, []).append(client)

        # Split each DBSCAN cluster into smaller groups based on constraints.
        for members in sub_clusters.values():
            current_subgroup = []
            current_sales_sum = 0

            for client in members:
                client_sales = client["chiffre_vente"]
                # Candidate subgroup = current members plus this client,
                # as (lon, lat) tuples as max_distance_in_group documents.
                candidate_coordinates = [
                    (c["longitude"], c["latitude"]) for c in current_subgroup
                ]
                candidate_coordinates.append((client["longitude"], client["latitude"]))

                fits = (
                    current_sales_sum + client_sales <= max_chiffre
                    and max_distance_in_group(candidate_coordinates) <= max_diameter_km
                )
                if fits:
                    current_subgroup.append(client)
                    current_sales_sum += client_sales
                    continue

                # Constraint violated: flush what we have, then start a new
                # subgroup seeded with the current client.
                if current_subgroup:
                    final_groups.append(
                        _make_group(len(final_groups) + 1, current_subgroup, current_sales_sum)
                    )
                current_subgroup = [client]
                current_sales_sum = client_sales

            # Flush the trailing subgroup of this sub-cluster.
            if current_subgroup:
                final_groups.append(
                    _make_group(len(final_groups) + 1, current_subgroup, current_sales_sum)
                )

    return final_groups

def _polygon_area_km2(polygon):
    """Return the geodesic (WGS84 ellipsoid) area of a lon/lat polygon in km²."""
    from pyproj import Geod

    # geometry_area_perimeter returns a signed area in m² whose sign depends
    # on ring orientation, hence abs().
    area_m2, _perimeter_m = Geod(ellps="WGS84").geometry_area_perimeter(polygon)
    return abs(area_m2) / 1_000_000


def add_geometric_properties(groups):
    """
    Add geometric properties to each group, in place.

    For every group the following keys are set:
        - "line": the group's ordered list of (longitude, latitude) points
          (the same list object as "coordinates").
        - "line_length": geodesic length in km of the path visiting the
          points in order; 0 with fewer than two points.
        - "layer_polygon_wkt": WKT of the convex hull, or None when the hull
          is not a polygon (fewer than three points, or collinear points).
        - "layer_polygon": hull exterior as a list of [lon, lat] pairs, or [].
        - "polygon_area": hull area in km², or 0 when there is no polygon.

    Args:
        groups: List of dicts, each with a "coordinates" list of
            (longitude, latitude) pairs in degrees.

    Returns:
        The same ``groups`` list, for convenience.
    """
    for group in groups:
        coordinates = group["coordinates"]
        group["line"] = coordinates  # Ordered list of GPS points

        # Total length of the line through consecutive points. Points are
        # stored as (lon, lat) but geopy's geodesic() expects (lat, lon).
        # With fewer than two points the sum is empty and yields 0.
        group["line_length"] = sum(
            geodesic((lat_a, lon_a), (lat_b, lon_b)).km
            for (lon_a, lat_a), (lon_b, lat_b) in zip(coordinates, coordinates[1:])
        )

        # Defaults for groups without a proper polygon.
        group["layer_polygon_wkt"] = None
        group["layer_polygon"] = []
        group["polygon_area"] = 0

        if len(coordinates) < 3:
            continue

        # Convex hull enclosing the group's points (x = lon, y = lat).
        hull = MultiPoint([Point(lon, lat) for lon, lat in coordinates]).convex_hull
        if not isinstance(hull, Polygon):
            # Collinear or coincident points collapse to a line or a point.
            continue

        group["layer_polygon_wkt"] = hull.wkt
        group["layer_polygon"] = list(map(list, hull.exterior.coords))
        # Measure the area on the ellipsoid: a Mercator projection is not
        # equal-area and overstates areas away from the equator.
        group["polygon_area"] = _polygon_area_km2(hull)

    return groups


def _rows_to_clients(rows):
    """Convert raw client rows into plain client dicts, skipping rows without usable coordinates."""
    clients = []
    for row in rows:
        latitude = float(row["latitude"]) if row["latitude"] else None
        longitude = float(row["longitude"]) if row["longitude"] else None

        # Truthiness test kept on purpose: rows whose latitude or longitude is
        # missing, empty or exactly 0 are skipped.
        if not (latitude and longitude):
            continue

        clients.append({
            "id": row["id"],
            "code": row["code"],
            "nom": row["nom"],
            "prenom": row["prenom"],
            "gouvernorat": row["gouvernorat"],
            "delegation": row["delegation"],
            "localite": row["localite"],
            "region": row["region"],
            "zone": row["zone"],
            "latitude": latitude,
            "longitude": longitude,
            # The SUM() may come back as None (no matching rows) or,
            # presumably, as decimal.Decimal if net_a_payer is a DECIMAL
            # column; json.dumps cannot serialise Decimal, so normalise to
            # float here.
            "chiffre_vente": float(row.get("chiffre_vente") or 0),
        })
    return clients


def group_clients_by_sales(min_chiffre, max_chiffre, max_diameter_km):
    """
    Load active, geolocated clients from the database, group them, and print
    the resulting groups as JSON on stdout.

    Steps: fetch clients with their sales total, cluster them with DBSCAN on
    a precomputed haversine distance matrix, split each cluster with
    ``split_group_by_sales_with_dbscan``, enrich the groups with
    ``add_geometric_properties``, sort by "total_sales" descending and print.

    Args:
        min_chiffre (float): Currently unused; accepted so existing callers
            and the command line keep working.
        max_chiffre (float): Maximum total sales per group.
        max_diameter_km (float): Maximum allowable diameter for a group.

    Returns:
        None. The JSON document is written to stdout.
    """
    import os

    # NOTE(security): credentials should not live in source control. Each
    # setting can be overridden with a TOURNEES_DB_* environment variable; the
    # historical values remain as defaults only so existing deployments keep
    # working, and should be removed (and the password rotated).
    conn = mysql.connector.connect(
        host=os.environ.get("TOURNEES_DB_HOST", "51.77.140.160"),
        user=os.environ.get("TOURNEES_DB_USER", "bassem"),
        password=os.environ.get("TOURNEES_DB_PASSWORD", "Clediss1234++"),
        database=os.environ.get("TOURNEES_DB_NAME", "dist_utic"),
        port=int(os.environ.get("TOURNEES_DB_PORT", "3306")),
    )

    # SQL query fetching client data together with each client's sales total.
    query = """
            SELECT 
                c.id, c.code, c.nom, c.prenom, c.latitude, c.longitude, 
                c.gouvernorat, c.delegation, c.localite, c.region,c.zone,
                (SELECT SUM(e.net_a_payer)  
                 FROM entetecommercials as e 
                 WHERE e.client_code=c.code AND e.type IN ('bl', 'avoir')) AS chiffre_vente
            FROM clients as c
            WHERE c.isactif=1 AND c.latitude IS NOT NULL AND c.longitude IS NOT NULL
        """

    # Hold the connection only for the fetch; the clustering below can be
    # slow and does not need the database.
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query)
            rows = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        conn.close()

    clients = _rows_to_clients(rows)
    if not clients:
        # No usable client: emit an empty result instead of failing while
        # building a distance matrix / fitting DBSCAN on zero samples.
        print(json.dumps([], indent=4, ensure_ascii=False))
        return

    # Geographic clustering with DBSCAN on a precomputed distance matrix.
    # eps=2 is expressed in the unit of that matrix (intended to be km).
    coordinates = [(client["latitude"], client["longitude"]) for client in clients]
    distance_matrix = haversine_distance_matrix(coordinates)
    labels = DBSCAN(eps=2, min_samples=1, metric="precomputed").fit_predict(distance_matrix)

    # Bucket clients by cluster label.
    clusters = {}
    for label, client in zip(labels, clients):
        bucket = clusters.setdefault(label, {"clients": [], "total_sales": 0})
        bucket["clients"].append(client)
        bucket["total_sales"] += client["chiffre_vente"]

    # Split clusters according to the max_chiffre and max_diameter_km limits.
    valid_groups = split_group_by_sales_with_dbscan(clusters, max_chiffre, max_diameter_km)

    # Add geometric properties.
    final_groups = add_geometric_properties(valid_groups)

    # Sort groups by total_sales in descending order.
    final_groups_sorted = sorted(final_groups, key=lambda x: x["total_sales"], reverse=True)

    # Print the result as JSON.
    response = json.dumps(final_groups_sorted, indent=4, ensure_ascii=False)
    print(response)


if __name__ == "__main__":
    import sys

    # Exactly three positional arguments are expected after the script name.
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python Tournees.py <min_chiffre> <max_chiffre> <max_diameter_km>")
        sys.exit(1)

    # Parse the three numeric thresholds in order and run the grouping.
    min_chiffre, max_chiffre, max_diameter_km = (float(arg) for arg in cli_args)
    group_clients_by_sales(min_chiffre, max_chiffre, max_diameter_km)
