import numpy as np

from CTimingEstimationResult import CTimingEstimationResult
from CAlgorithmBase import CAlgorithmBase


class CAlgorithmSinglePhoton(CAlgorithmBase):
    """Timing estimator that uses one ranked photon timestamp as the event time.

    No averaging or weighting is performed: the timestamp stored at index
    ``photon_count`` is returned directly.
    """

    def __init__(self, photon_count):
        # NOTE(review): photon_count is used below as a zero-based column
        # index (photon_count=0 selects the first timestamp), whereas
        # CAlgorithmMean treats it as a number of photons. Confirm intent.
        self.__photon_count = photon_count

    @property
    def algorithm_name(self):
        """Label reported in the timing-estimation results."""
        return "Single"

    @property
    def photon_count(self):
        """Index of the timestamp column used as the estimate."""
        return self.__photon_count

    def evaluate_collection_timestamps(self, coincidence_collection):
        """Select timestamp column ``photon_count`` on both detectors.

        The selected columns are copied with ``np.copy`` so the result does not
        alias the collection's arrays, then wrapped in a
        ``CTimingEstimationResult`` together with the algorithm name and
        ``photon_count``.
        """
        timestamps_detector1 = np.copy(coincidence_collection.detector1.timestamps[:, self.photon_count])
        timestamps_detector2 = np.copy(coincidence_collection.detector2.timestamps[:, self.photon_count])

        timing_estimation_results = CTimingEstimationResult(self.algorithm_name, self.photon_count, timestamps_detector1, timestamps_detector2)
        return timing_estimation_results

    def evaluate_single_timestamp(self, single_event):
        """Return a copy of the event's timestamp at index ``photon_count``."""
        # NOTE(review): the other estimators read ``single_event.photon_timestamps``;
        # this one reads ``single_event.timestamps``. Confirm the single-event
        # type actually exposes ``timestamps``.
        timestamps = np.copy(single_event.timestamps[self.photon_count])
        return timestamps

# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call; the class already inherits from it) and sanity-check at
# import time. The assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmSinglePhoton)
assert issubclass(CAlgorithmSinglePhoton, CAlgorithmBase)
    #
    #     for label in slices:
    #         datasets[label] = slice_data(data, slices[label])
    #     return datasets

    def __train_network(self):
        """Build and train ``self.__neural_network`` to regress the interaction time.

        Inputs are the first ``photon_count - 1`` detector1 timestamps, each
        taken relative to the timestamp at index ``photon_count - 1``. The
        target is detector1's interaction time relative to that same
        reference timestamp, as an (N, 1) column. Training progress is written
        to stdout, overwriting a single line via a carriage return.
        """

        # climate.enable_default_logging()

        detector = self.__event_collection.detector1
        last = self.photon_count - 1
        # (N, 1) column holding each event's reference timestamp; slicing keeps
        # it 2-D so it broadcasts across the input columns.
        reference = detector.timestamps[:, last:self.photon_count]

        neural_input = detector.timestamps[:, :last] - reference
        # Plain ndarray column vector. The previous np.transpose(np.matrix(v))
        # produced the same (N, 1) shape but relied on the deprecated np.matrix.
        neural_target = (detector.interaction_time.ravel() - reference.ravel()).reshape(-1, 1)

        self.__neural_network = theanets.Experiment(
            # Neural network for regression (sigmoid hidden, linear output)
            theanets.Regressor,
            # Input layer, hidden layer, output layer
            layers=(self.photon_count - 1, self.__hidden_layers, 1)
        )

        # NOTE(review): __hidden_layers is the middle entry of ``layers`` (a
        # single hidden layer's size, presumably), although the progress
        # message calls it "hidden layers".
        training = self.__neural_network.itertrain([neural_input, neural_target], optimize='rmsprop')
        for i, (train, _valid) in enumerate(training):
            # The reported error is sqrt(err) / (sqrt(2) / 2), i.e. sqrt(2 * err).
            sys.stdout.write('\rNeural network with %d inputs, %d hidden layers - Iteration %d: Training error: %f ' %
                             (self.photon_count - 1, self.__hidden_layers, i, np.sqrt(train['err']) /(np.sqrt(2)/2)))
            sys.stdout.flush()


# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call) and sanity-check the relationship at import time. The
# assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmNeuralNetwork)
assert issubclass(CAlgorithmNeuralNetwork, CAlgorithmBase)
from CTimingEstimationResult import CTimingEstimationResult
from CAlgorithmBase import CAlgorithmBase


class CAlgorithmMean(CAlgorithmBase):
    """Timing estimator that averages the first ``photon_count`` photon timestamps."""

    def __init__(self, photon_count):
        # Kept in a name-mangled attribute and exposed through the property below.
        self.__photon_count = photon_count

    @property
    def algorithm_name(self):
        """Label reported in the timing-estimation results."""
        return "Mean"

    @property
    def photon_count(self):
        """Number of leading photon timestamps that are averaged."""
        return self.__photon_count

    def evaluate_collection_timestamps(self, coincidence_collection):
        """Average the leading timestamps of every event on both detectors.

        Returns a ``CTimingEstimationResult`` holding the per-event means for
        detector1 and detector2.
        """
        n_photons = self.photon_count

        def _row_means(detector):
            # Mean over the first n_photons columns, one value per event row.
            return np.mean(detector.timestamps[:, :n_photons], axis=1)

        first_detector = _row_means(coincidence_collection.detector1)
        second_detector = _row_means(coincidence_collection.detector2)
        return CTimingEstimationResult(self.algorithm_name, n_photons, first_detector, second_detector)

    def evaluate_single_timestamp(self, single_event):
        """Return the mean of the event's first ``photon_count`` photon timestamps."""
        leading = single_event.photon_timestamps[:self.photon_count]
        return np.mean(leading)

# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call; the class already inherits from it) and sanity-check at
# import time. The assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmMean)
assert issubclass(CAlgorithmMean, CAlgorithmBase)
    @property
    def photon_count(self):
        """Number of leading photon timestamps used to fit and apply the weights."""
        # Name-mangled attribute; presumably assigned in __init__ (not visible here).
        return self.__photon_count

    def _calculate_coefficients(self):
        corrected_timestamps = self._training_coincidence_collection.detector2.timestamps[:, :self.photon_count] - self._training_coincidence_collection.detector2.interaction_time[:, None]
        covariance = np.cov(corrected_timestamps[:, :self.photon_count], rowvar=0)
        unity = np.ones(self.photon_count)
        inverse_covariance = np.linalg.inv(covariance)
        w = np.dot(unity, inverse_covariance)
        n = np.dot(w, unity.T)
        self._mlh_coefficients = w / n

    def evaluate_collection_timestamps(self, coincidence_collection):
        """Apply the stored weights to every event on both detectors.

        Each event's estimate is the dot product of its first
        ``len(self._mlh_coefficients)`` timestamps with the weights. The two
        per-detector estimate arrays are wrapped in a ``CTimingEstimationResult``.
        """
        weights = self._mlh_coefficients
        used = len(weights)
        detectors = (coincidence_collection.detector1, coincidence_collection.detector2)
        per_detector = [np.dot(detector.timestamps[:, :used], weights) for detector in detectors]
        return CTimingEstimationResult(self.algorithm_name, self.photon_count, per_detector[0], per_detector[1])

    def evaluate_single_timestamp(self, single_event):
        """Estimate one event's time as a weighted sum of its leading photon timestamps.

        Only as many timestamps as there are stored coefficients are used.
        """
        weights = self._mlh_coefficients
        leading = single_event.photon_timestamps[:len(weights)]
        return np.dot(leading, weights)

    def print_coefficients(self):
        """Recompute the weights via ``_calculate_coefficients`` and print them to stdout."""
        self._calculate_coefficients()
        weights = self._mlh_coefficients
        print(weights)

# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call) and sanity-check the relationship at import time. The
# assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmBlue)
assert issubclass(CAlgorithmBlue, CAlgorithmBase)
    @property
    def photon_count(self):
        """Number of leading photon timestamps used to fit the weights."""
        # Name-mangled attribute; presumably assigned in __init__ (not visible here).
        return self.__photon_count

    def _calculate_coefficients(self):
        corrected_timestamps = self._training_coincidence_collection.detector1.timestamps[:, :self.photon_count] - self._training_coincidence_collection.detector2.timestamps[:, :self.photon_count]
        covariance = np.cov(corrected_timestamps[:, :self.photon_count], rowvar=0)
        unity = np.ones(self.photon_count)
        inverse_covariance = np.linalg.inv(covariance)
        w = np.dot(unity, inverse_covariance)
        n = np.dot(w, unity.T)
        self._mlh_coefficients = w / n

    def evaluate_collection_timestamps(self, coincidence_collection):
        """Apply the stored weights to every event on both detectors.

        Each event's estimate is the dot product of its first
        ``len(self._mlh_coefficients)`` timestamps with the weights. The two
        per-detector estimate arrays are wrapped in a ``CTimingEstimationResult``.
        """
        weights = self._mlh_coefficients
        used = len(weights)
        detectors = (coincidence_collection.detector1, coincidence_collection.detector2)
        per_detector = [np.dot(detector.timestamps[:, :used], weights) for detector in detectors]
        return CTimingEstimationResult(self.algorithm_name, self.photon_count, per_detector[0], per_detector[1])

    def evaluate_single_timestamp(self, single_event):
        """Estimate one event's time as a weighted sum of its leading photon timestamps.

        Only as many timestamps as there are stored coefficients are used.
        """
        weights = self._mlh_coefficients
        leading = single_event.photon_timestamps[:len(weights)]
        return np.dot(leading, weights)

    def print_coefficients(self):
        """Recompute the weights via ``_calculate_coefficients`` and print them to stdout."""
        self._calculate_coefficients()
        weights = self._mlh_coefficients
        print(weights)

# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call) and sanity-check the relationship at import time. The
# assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmBlueDifferential)
assert issubclass(CAlgorithmBlueDifferential, CAlgorithmBase)
        # Calcul des coefficients pour le detecteur 1
        corrected_timestamps = self._training_coincidence_collection.detector1.timestamps[:, :self.photon_count] - timestamps_detector2[:,None]

        covariance = np.cov(corrected_timestamps[:, :self.photon_count], rowvar=0)
        unity = np.ones(self.photon_count)
        inverse_covariance = np.linalg.inv(covariance)
        w = np.dot(unity, inverse_covariance)
        n = np.dot(w, unity.T)
        self._mlh_coefficients = w / n




    def evaluate_collection_timestamps(self, coincidence_collection):
        """Apply the stored weights to every event on both detectors.

        Each event's estimate is the dot product of its first
        ``len(self._mlh_coefficients)`` timestamps with the weights. The two
        per-detector estimate arrays are wrapped in a ``CTimingEstimationResult``.
        """
        weights = self._mlh_coefficients
        used = len(weights)
        detectors = (coincidence_collection.detector1, coincidence_collection.detector2)
        per_detector = [np.dot(detector.timestamps[:, :used], weights) for detector in detectors]
        return CTimingEstimationResult(self.algorithm_name, self.photon_count, per_detector[0], per_detector[1])

    def evaluate_single_timestamp(self, single_event):
        """Estimate one event's time as a weighted sum of its leading photon timestamps.

        Only as many timestamps as there are stored coefficients are used.
        """
        weights = self._mlh_coefficients
        leading = single_event.photon_timestamps[:len(weights)]
        return np.dot(leading, weights)

    def print_coefficients(self):
        """Recompute the weights via ``_calculate_coefficients`` and print them to stdout."""
        self._calculate_coefficients()
        weights = self._mlh_coefficients
        print(weights)

# Register with CAlgorithmBase (presumably an abc.ABCMeta base, given the
# ``register`` call) and sanity-check the relationship at import time. The
# assert is skipped when Python runs with -O.
CAlgorithmBase.register(CAlgorithmBlueExpectationMaximisation)
assert issubclass(CAlgorithmBlueExpectationMaximisation, CAlgorithmBase)