class WindowMultiTargetClassificationMeasurements(BaseObject):
    """ This class will maintain a fixed sized window of the newest information
    about one classifier. It can provide, as requested, any of the relevant
    current metrics about the classifier, measured inside the window.

    This class will keep updated statistics about a multi output classifier,
    using a confusion matrix adapted to multi output problems, the
    MOLConfusionMatrix, alongside other relevant attributes of the classifier
    stored in FastComplexBuffer objects, which simulate fixed sized windows.

    Its functionality is somewhat similar to that of the
    MultiTargetClassificationMeasurements class. The difference is that the
    statistics kept by this class are local, or partial, while the statistics
    kept by the MultiTargetClassificationMeasurements class are global.

    At any given moment, it can compute the following statistics: hamming_loss,
    hamming_score, exact_match and j_index.

    Parameters
    ----------
    targets: list
        A list containing the possible labels.

    dtype: data type (Default: numpy.int64)
        The data type of the existing labels.

    window_size: int (Default: 200)
        The width of the window. Determines how many samples the object
        can see.

    """
    def __init__(self, targets=None, dtype=np.int64, window_size=200):
        super().__init__()
        self.n_targets = len(targets) if targets is not None else 0
        self.confusion_matrix = MOLConfusionMatrix(self.n_targets, dtype)
        self.last_true_label = None
        self.last_prediction = None

        self.targets = targets
        self.window_size = window_size
        # Not updated by this class (window metrics are recomputed from the
        # buffers); kept so the attribute set matches the global counterpart.
        self.exact_match_count = 0
        self.j_sum = 0
        # Buffers holding the true labels / predictions inside the window.
        self.true_labels = FastComplexBuffer(window_size, self.n_targets)
        self.predictions = FastComplexBuffer(window_size, self.n_targets)

    def reset(self):
        """ Restores the object to its freshly-constructed state, keeping
        ``targets`` and ``window_size``."""
        self.n_targets = len(self.targets) if self.targets is not None else 0
        self.confusion_matrix.restart(self.n_targets)
        self.last_true_label = None
        self.last_prediction = None
        self.exact_match_count = 0
        self.j_sum = 0
        self.true_labels = FastComplexBuffer(self.window_size, self.n_targets)
        self.predictions = FastComplexBuffer(self.window_size, self.n_targets)

    def add_result(self, y_true, y_pred):
        """ Updates its statistics with the results of a prediction.

        Adds the result to the MOLConfusionMatrix, and updates the
        FastComplexBuffer objects. Any sample handed back by the buffers when
        the new one is added is removed from the confusion matrix again.

        Parameters
        ----------
        y_true: list, tuple or numpy.ndarray
            The true label.

        y_pred: list, tuple or numpy.ndarray
            The classifier's prediction

        """
        self.last_true_label = y_true
        self.last_prediction = y_pred
        # Number of targets in this sample: arrays expose it via `.size`,
        # plain sequences (lists, tuples) via len().
        if hasattr(y_true, 'size'):
            m = y_true.size
        elif hasattr(y_true, '__len__'):
            m = len(y_true)
        else:
            m = 0
        self.n_targets = m

        for i in range(m):
            self.confusion_matrix.update(i, y_true[i], y_pred[i])

        # NOTE(review): add_element appears to return the evicted element(s)
        # once the window is full, and None otherwise — confirm against
        # FastComplexBuffer.
        old_true = self.true_labels.add_element(y_true)
        old_predict = self.predictions.add_element(y_pred)
        if (old_true is not None) and (old_predict is not None):
            for i in range(m):
                # Pass the target index first, mirroring the update() call
                # above, so the evicted sample is subtracted from the same
                # cell it was originally added to.
                self.confusion_matrix.remove(i, old_true[0][i],
                                             old_predict[0][i])

    def get_last(self):
        """ Returns the most recent (true label, prediction) pair, or
        ``(None, None)`` if nothing was added since construction/reset."""
        return self.last_true_label, self.last_prediction

    def get_hamming_loss(self):
        """ Computes the window/current Hamming loss, which is the
        complement of the Hamming score metric.

        Returns
        -------
        float
            The window/current hamming loss.

        """
        return 1.0 - self.get_hamming_score()

    def get_hamming_score(self):
        """ Computes the window/current Hamming score, defined as the number of
        correctly classified labels divided by the total number of labels
        classified.

        The value is delegated to the module-level ``hamming_score`` helper,
        applied to the samples currently held in the window buffers.

        Returns
        -------
        float
            The window/current hamming score.

        """
        return hamming_score(self.true_labels.get_queue(),
                             self.predictions.get_queue())

    def get_exact_match(self):
        """ Computes the window/current exact match metric.

        This is the most strict multi output metric, defined as the number of
        samples that have all their labels correctly classified, divided by the
        total number of samples. Delegated to the module-level ``exact_match``
        helper over the window buffers.

        Returns
        -------
        float
            The window/current exact match metric.

        """
        return exact_match(self.true_labels.get_queue(),
                           self.predictions.get_queue())

    def get_j_index(self):
        """ Computes the window/current Jaccard index, also known as the intersection
        over union metric. It is calculated by dividing the number of correctly
        classified labels by the union of predicted and true labels. Delegated
        to the module-level ``j_index`` helper over the window buffers.

        Returns
        -------
        float
            The window/current Jaccard index.

        """
        return j_index(self.true_labels.get_queue(),
                       self.predictions.get_queue())

    def get_total_sum(self):
        """ Returns the total count reported by the confusion matrix."""
        return self.confusion_matrix.get_total_sum()

    @property
    def matrix(self):
        """ The underlying matrix of the MOLConfusionMatrix."""
        return self.confusion_matrix.matrix

    @property
    def sample_count(self):
        """ Number of samples currently inside the window."""
        return self.true_labels.get_current_size()

    def get_info(self):
        """ Returns a one-line, human readable summary of the window metrics."""
        return '{}:'.format(type(self).__name__) + \
               ' - sample_count: {}'.format(self.sample_count) + \
               ' - hamming_loss: {:.6f}'.format(self.get_hamming_loss()) + \
               ' - hamming_score: {:.6f}'.format(self.get_hamming_score()) + \
               ' - exact_match: {:.6f}'.format(self.get_exact_match()) + \
               ' - j_index: {:.6f}'.format(self.get_j_index())

    def get_class_type(self):
        return 'measurement'
class MultiTargetClassificationMeasurements(BaseObject):
    """ This class will keep updated statistics about a multi output classifier,
    using a confusion matrix adapted to multi output problems, the
    MOLConfusionMatrix, alongside other relevant attributes.

    The performance metrics for multi output tasks are different from those used
    for normal classification tasks. Thus, the statistics provided by this class
    are different from those provided by the ClassificationMeasurements and from
    the WindowClassificationMeasurements.

    At any given moment, it can compute the following statistics: hamming_loss,
    hamming_score, exact_match and j_index. Before any sample has been added,
    every score is reported as 0.0 (and hamming_loss as 1.0).

    Parameters
    ----------
    targets: list
        A list containing the possible labels.

    dtype: data type (Default: numpy.int64)
        The data type of the existing labels.

    """
    def __init__(self, targets=None, dtype=np.int64):
        super().__init__()
        self.n_targets = len(targets) if targets is not None else 0
        self.confusion_matrix = MOLConfusionMatrix(self.n_targets, dtype)
        self.last_true_label = None
        self.last_prediction = None
        self.sample_count = 0
        self.targets = targets
        # Number of samples whose labels were all predicted correctly.
        self.exact_match_count = 0
        # Running sum of the per-sample Jaccard index.
        self.j_sum = 0

    def reset(self):
        """ Restores the object to its freshly-constructed state, keeping
        ``targets``."""
        self.n_targets = len(self.targets) if self.targets is not None else 0
        self.confusion_matrix.restart(self.n_targets)
        self.last_true_label = None
        self.last_prediction = None
        self.sample_count = 0
        self.exact_match_count = 0
        self.j_sum = 0

    def add_result(self, y_true, y_pred):
        """ Updates its statistics with the results of a prediction.

        Adds the result to the MOLConfusionMatrix and updates the exact match
        and j-index sum counts.

        Parameters
        ----------
        y_true: list, tuple or numpy.ndarray
            The true label.

        y_pred: list, tuple or numpy.ndarray
            The classifier's prediction

        """
        self.last_true_label = y_true
        self.last_prediction = y_pred
        # Element-wise arithmetic below requires arrays: with plain lists
        # `y_true * y_pred` raises TypeError and `y_true + y_pred` concatenates.
        true_arr = np.asarray(y_true)
        pred_arr = np.asarray(y_pred)
        m = true_arr.size
        self.n_targets = m

        equal = True
        for i in range(m):
            self.confusion_matrix.update(i, y_true[i], y_pred[i])
            # A single mismatching target rules out an exact match.
            if y_true[i] != y_pred[i]:
                equal = False

        if equal:
            self.exact_match_count += 1

        # Per-sample Jaccard index: labels positive in both / positive in either.
        inter = np.count_nonzero((true_arr * pred_arr) > 0)
        union = np.count_nonzero((true_arr + pred_arr) > 0)
        if union > 0:
            self.j_sum += inter / union
        elif np.sum(true_arr) == 0:
            # Nothing positive in either vector: treated as a perfect match.
            self.j_sum += 1

        self.sample_count += 1

    def get_last(self):
        """ Returns the most recent (true label, prediction) pair, or
        ``(None, None)`` if nothing was added since construction/reset."""
        return self.last_true_label, self.last_prediction

    def get_hamming_loss(self):
        """ Computes the Hamming loss, which is the complement of the
        Hamming score metric.

        Returns
        -------
        float
            The hamming loss.

        """
        return 1.0 - self.get_hamming_score()

    def get_hamming_score(self):
        """ Computes the Hamming score, defined as the number of correctly
        classified labels divided by the total number of labels classified.

        Returns
        -------
        float
            The Hamming score, or 0.0 if no labels were classified yet.

        """
        # Checked explicitly: if the diagonal sum is a numpy scalar, dividing
        # by zero yields nan/inf with a warning instead of ZeroDivisionError.
        n_labels = self.sample_count * self.n_targets
        if n_labels == 0:
            return 0.0
        return self.confusion_matrix.get_sum_main_diagonal() / n_labels

    def get_exact_match(self):
        """ Computes the exact match metric.

        This is the most strict multi output metric, defined as the number of
        samples that have all their labels correctly classified, divided by the
        total number of samples.

        Returns
        -------
        float
            The exact match metric, or 0.0 if no sample was added yet.

        """
        if self.sample_count == 0:
            return 0.0
        return self.exact_match_count / self.sample_count

    def get_j_index(self):
        """ Computes the Jaccard index, also known as the intersection over union
        metric: the per-sample ratio of labels positive in both the true and
        predicted vectors to labels positive in either, averaged over all
        samples seen.

        Returns
        -------
        float
            The Jaccard index, or 0.0 if no sample was added yet.

        """
        if self.sample_count == 0:
            return 0.0
        return self.j_sum / self.sample_count

    def get_total_sum(self):
        """ Returns the total count reported by the confusion matrix."""
        return self.confusion_matrix.get_total_sum()

    @property
    def matrix(self):
        """ The underlying matrix of the MOLConfusionMatrix (same accessor name
        as the windowed counterpart)."""
        return self.confusion_matrix.matrix

    @property
    def _matrix(self):
        # Kept for backward compatibility; prefer `matrix`.
        return self.confusion_matrix.matrix

    def get_info(self):
        """ Returns a one-line, human readable summary of the global metrics."""
        return '{}:'.format(type(self).__name__) + \
               ' - sample_count: {}'.format(self.sample_count) + \
               ' - hamming_loss: {:.6f}'.format(self.get_hamming_loss()) + \
               ' - hamming_score: {:.6f}'.format(self.get_hamming_score()) + \
               ' - exact_match: {:.6f}'.format(self.get_exact_match()) + \
               ' - j_index: {:.6f}'.format(self.get_j_index())

    def get_class_type(self):
        return 'measurement'