"""
This module contain the implementation of artificial data generator for set data

parameters:
* data_size : (int) an integer number specifies number of total number of data that will be generated
* size_of_clusters : (numpy arry) specifies size for each cluster. If empty array is passed then the size of all cluster will be the same.
                Note that len of array should equal to number_of_cluster and sum of this array should equal to data_size
* number_of_cluster : (int) an integer number specifies number of cluster to create
* dimension : (int) an integer number specifies total number of features that will be generate in the data set
* distance_threshold : (float) a number specifies the maximum distance away from the cluster representative according to Jaccard's method
* size_of_set : (tuple(int,int)) a tuple of intergers specifies the minimum and maximum feature that each data has to contain
* all_features : (string[]) an array of string containing all possible features of the dataset
* gt_representative : (string[][]) if empty the program will randomly generate the cluster representatives, else program uses the provides values as cluster representatives

returns:
type : tuple
* artificially generated data set as a list of numpy array containing strings
* ground truths labels  of the data set
* list of numpy array (cluster representative / medoids)
"""

import math
import numpy as np
import random
import distances

def generate(
    data_size,
    size_of_clusters,
    number_of_cluster,
    dimension,
    distance_threshold,
    size_of_set,
    all_features,
    gt_representative):
    """
    Generate an artificial, shuffled set data set together with its ground truth.

    parameters:
    * data_size : (int) total number of data objects to generate
    * size_of_clusters : (numpy array) size of each cluster; empty means the sizes are derived from data_size
    * number_of_cluster : (int) number of clusters to create
    * dimension : (int) not used by this function
    * distance_threshold : (float) maximum Jaccard distance between a member and its representative
    * size_of_set : (tuple(int,int)) minimum and maximum number of features per data object
    * all_features : (string[]) all possible features of the data set
    * gt_representative : (string[][]) cluster representatives to use; empty means they are generated randomly

    returns:
    type : tuple
    * data : list of numpy arrays (shuffled)
    * ground_truth_labels : numpy array of cluster ids, shuffled in step with data
    * representatives : list of cluster representatives
    * overlap_percentage : value computed by calculate_overlap before shuffling

    raises:
    * ValueError : when the number of representatives differs from number_of_cluster
    """
    if len(gt_representative) == 0:
        print('=== Generating representatives... ===')
        representatives = create_cluster_representatives(
            number_of_cluster,
            size_of_set,
            all_features)
    else:
        representatives = gt_representative

    if len(representatives) != number_of_cluster:
        # A bare string cannot be raised; use a real exception so the
        # intended message actually reaches the caller.
        raise ValueError('number of representatives and number of clusters are not equal')

    print('=== Generating data and ground truths labels... ===')
    data, ground_truth_labels = generate_cluster_members(
        data_size,
        representatives,
        size_of_set,
        all_features,
        distance_threshold,
        size_of_clusters)

    print('=== Calculating pairwise distances and then overlap... ===')
    overlap_percentage = calculate_overlap(data, ground_truth_labels, representatives)

    print('=== Shuffling data and lables... ===')
    # Shuffle the (object, label) pairs together so labels stay aligned.
    combined = list(zip(data, ground_truth_labels))
    random.shuffle(combined)
    # Slice assignment writes the shuffled values back in place, so `data`
    # stays a list and `ground_truth_labels` keeps its original container type.
    data[:], ground_truth_labels[:] = zip(*combined)

    return (data, ground_truth_labels, representatives, overlap_percentage)

def _find_closest_member_from_other_clusters(own_cluster_id, pw_dist, ground_truth_labels, data_id):
    """
    Return the index of the data object nearest to `data_id` that belongs to a different cluster.

    parameters:
    * own_cluster_id : cluster label of the data object at `data_id`
    * pw_dist : 2-D array indexed as pw_dist[data_id, other_id]
    * ground_truth_labels : cluster label of every data object
    * data_id : (int) row of pw_dist to search

    returns:
    * index (column of pw_dist) with the smallest distance among members of other clusters;
      0 when no other cluster has any member, because every candidate is then infinite
    """
    # Coerce to an array so that `!=` is evaluated element-wise even for a plain list.
    labels = np.asarray(ground_truth_labels)
    other_cluster_member_ids = np.where(labels != own_cluster_id)[0]

    # Start from +inf everywhere so members of the own cluster can never win the argmin.
    # (`np.inf` is used because the `np.Inf` alias no longer exists in NumPy 2.)
    masked_distances = np.full(pw_dist.shape[1], np.inf)
    masked_distances[other_cluster_member_ids] = pw_dist[data_id, other_cluster_member_ids]

    return np.argmin(masked_distances)

def calculate_overlap(data, ground_truth_labels, representatives):
    """
    Compute the percentage of data objects that overlap with another cluster.

    A data object counts as overlapping when its closest member of any other
    cluster is strictly nearer (distances.jaccard_seq) than the representative
    of its own cluster.

    parameters:
    * data : list of data objects
    * ground_truth_labels : cluster id of every data object, used to index `representatives`
    * representatives : list of cluster representatives

    returns:
    * overlap percentage (overlapping objects * 100 / number of objects)
    """
    pairwise = distances.calculate_pairwise_distance(np.array(data))

    overlapping = 0
    for index, sample in enumerate(data):
        label = ground_truth_labels[index]
        rival_index = _find_closest_member_from_other_clusters(
            label, pairwise, ground_truth_labels, index)

        distance_to_rival = distances.jaccard_seq(sample, data[rival_index])
        distance_to_own_representative = distances.jaccard_seq(sample, representatives[label])
        if distance_to_rival < distance_to_own_representative:
            overlapping += 1

    return (overlapping * 100) / len(data)

# Placeholder value (three spaces) that marks a feature slot which has not
# been assigned a real feature yet.
DEFAULT_STRING = '   '

def create_cluster_representatives(number_of_cluster, size_of_set, all_features):
    """
    Randomly create one representative per cluster.

    parameters:
    * number_of_cluster : (int) how many representatives to create
    * size_of_set : (tuple(int,int)) currently not consulted; every representative has exactly 6 features
    * all_features : array of all possible features to sample from

    returns:
    * list of numpy arrays, one representative per cluster
    """
    return [
        np.array(_create_cluster_member_random(all_features, 6, 6))
        for _ in range(number_of_cluster)
    ]

def _create_cluster_member_random(all_features, min_feature, max_feature):
    """
    Draw a random set of distinct features.

    parameters:
    * all_features : array (must provide .tolist()) of all possible features
    * min_feature : (int) smallest allowed number of features
    * max_feature : (int) largest allowed number of features (inclusive)

    returns:
    * list of sampled features whose length lies in [min_feature, max_feature]
    """
    # randint is inclusive on both ends, i.e. randrange(min_feature, max_feature + 1).
    feature_count = random.randint(min_feature, max_feature)
    candidate_pool = all_features.tolist()
    return random.sample(candidate_pool, feature_count)

def _find_number_of_member_per_cluster(data_size, representatives):
    """
    Split data_size as evenly as possible over the clusters.

    parameters:
    * data_size : (int) total number of data objects
    * representatives : one entry per cluster; only its length is used

    returns:
    * list with one size per cluster; the first (data_size % clusters) clusters get one extra object
    """
    cluster_count = len(representatives)
    base_size, leftover = divmod(data_size, cluster_count)

    sizes = [int(base_size)] * cluster_count
    # Hand out the remainder one object at a time, starting with the first cluster.
    for position in range(leftover):
        sizes[position] += 1

    return sizes

def _create_cluster_member_random_different_feature_length(
    representative,
    all_features,
    min_feature,
    max_feature,
    distance_threshold):
    """
    Create one random candidate member derived from a cluster representative.

    The member size is drawn between min_feature and
    floor(len(representative) / (1 - distance_threshold)). The member keeps a
    random number of the representative's features (at least a lower bound
    derived from the similarity threshold, rounded down) and all remaining
    slots are filled with features the representative does not contain.
    Because the bound is rounded down, the caller is still expected to check
    the resulting distance.

    parameters:
    * representative : array of features of the cluster representative
    * all_features : all possible features
    * min_feature : (int) smallest member size
    * max_feature : (int) not used; the upper size bound is derived from distance_threshold
    * distance_threshold : (float) maximum distance from the representative;
      a value of exactly 1 leads to a division by zero

    returns:
    * numpy array holding the features of the new member (full feature names, never truncated)
    """
    representative = np.asarray(representative)
    # Work on plain Python lists: fixed-width numpy string arrays silently
    # truncate any feature name longer than the array's item size on assignment.
    rep_features = representative.tolist()
    rep_size = len(rep_features)

    similarity_threshold = 1 - distance_threshold
    max_feature_temp = math.floor(rep_size / similarity_threshold)

    number_of_member_features = random.randrange(min_feature, max_feature_temp + 1)
    min_intersect = math.floor(
        (similarity_threshold * (rep_size + number_of_member_features)) / (1 + similarity_threshold))

    max_intersect = min(number_of_member_features, rep_size)
    if min_intersect > max_intersect:
        print('--- Bad data object ---')
        # The required overlap is unreachable for this size. Clamp it so the
        # draw below has a valid range instead of raising ValueError.
        min_intersect = max_intersect

    number_of_intersect = random.randint(min_intersect, max_intersect)

    rep_lookup = set(rep_features)
    total_available_features = [x for x in all_features if x not in rep_lookup]

    # NOTE: the order of the random draws inside each branch is kept as it
    # always was, so seeded runs consume the random stream identically.
    extra_slots = number_of_member_features - rep_size
    if extra_slots > 0:
        # Member is larger than the representative: keep the representative's
        # slots, append `extra_slots` new features, then swap out some slots.
        id_to_change = random.sample(range(rep_size), rep_size - number_of_intersect)
        new_features_pool = random.sample(total_available_features, len(id_to_change) + extra_slots)
        new_member = rep_features + new_features_pool[:extra_slots]
        replacements = new_features_pool[extra_slots:]
    elif extra_slots < 0:
        # Member is smaller: start from a random subset of the representative.
        selected_id = random.sample(range(rep_size), number_of_member_features)
        new_member = [rep_features[i] for i in selected_id]
        replacements = random.sample(
            total_available_features, number_of_member_features - number_of_intersect)
        id_to_change = random.sample(
            range(number_of_member_features), number_of_member_features - number_of_intersect)
    else:
        # Same size: start from a copy of the representative.
        new_member = list(rep_features)
        id_to_change = random.sample(
            range(number_of_member_features), number_of_member_features - number_of_intersect)
        replacements = random.sample(total_available_features, len(id_to_change))

    # Replace the chosen slots with features from outside the representative.
    for position, feature in zip(id_to_change, replacements):
        new_member[position] = feature

    if not new_member:
        # Keep a sensible dtype for the degenerate empty member.
        return np.array(new_member, dtype=representative.dtype)
    return np.array(new_member)

def generate_cluster_members(
    data_size,
    representatives,
    size_of_set,
    all_features,
    distance_threshold,
    size_of_clusters):
    """
    Generate the members of every cluster around its representative.

    Each cluster starts with its representative as first member, so every
    cluster contains at least one data object. Random candidates are then
    generated and kept only when their Jaccard distance to the representative
    is strictly below distance_threshold, until the cluster reaches its size.

    parameters:
    * data_size : (int) total number of data objects; only used when size_of_clusters is empty
    * representatives : list of numpy arrays, one representative per cluster
    * size_of_set : (tuple(int,int)) minimum and maximum number of features per data object
    * all_features : all possible features
    * distance_threshold : (float) maximum distance from the representative
    * size_of_clusters : size of each cluster; empty means the sizes are derived from data_size

    returns:
    type : tuple
    * data : list of numpy arrays, grouped by cluster
    * ground truth labels : numpy array with the cluster id of every data object

    raises:
    * Exception : when data and labels end up with different lengths
    """
    if len(size_of_clusters) == 0:
        number_of_data_per_cluster = _find_number_of_member_per_cluster(data_size, representatives)
    else:
        number_of_data_per_cluster = size_of_clusters

    data = []
    ground_truth_labels = []

    for cluster_id, representative in enumerate(representatives):
        # The representative itself is always the first member of its cluster.
        data.append(representative)
        ground_truth_labels.append(cluster_id)
        cluster_len = 1

        target_size = number_of_data_per_cluster[cluster_id]

        # `<` instead of `!=`: a requested size below 1 can never be matched
        # exactly (the count starts at 1) and would otherwise loop forever.
        while cluster_len < target_size:
            member = _create_cluster_member_random_different_feature_length(
                representative,
                all_features,
                size_of_set[0],
                size_of_set[1],
                distance_threshold)

            # Reject candidates that ended up too far from the representative.
            if distances.jaccard_seq(representative, member) < distance_threshold:
                data.append(member)
                ground_truth_labels.append(cluster_id)
                cluster_len = cluster_len + 1

    labels = np.array(ground_truth_labels)
    if len(labels) != len(data):
        raise Exception('Program terminated, lengths of data and ground truths labels are not equal')

    return (data, labels)


