def on_epoch_end(self, iteration, writer=None, **kwargs):
    """Log training-set distance statistics at the end of an epoch.

    Nothing is logged when `writer` is None. Otherwise this logs, all at
    step `iteration`:
    * histograms of intra-class (`self.log_positive_`) and inter-class
      (`self.log_negative_`) distances,
    * the equal error rate of the same/different experiment built from
      those two sets of distances,
    * a histogram of raw triplet deltas (`self.log_delta_`),
    * a histogram of embedding norms (`self.log_norm_`).

    Parameters
    ----------
    iteration : int
        Global step attached to every logged value.
    writer : optional
        Object exposing `add_histogram` / `add_scalar` (presumably a
        TensorBoard SummaryWriter -- confirm with caller).
    """
    if writer is None:
        return

    def histogram(tag, values):
        # every histogram in this callback uses the same step and binning
        writer.add_histogram(tag, values, global_step=iteration, bins='doane')

    positive = np.hstack(self.log_positive_)
    negative = np.hstack(self.log_negative_)
    histogram('train/distance/intra_class', positive)
    histogram('train/distance/inter_class', negative)

    # same/different experiment: intra-class pairs labelled 1, others 0
    labels = np.hstack([np.ones(len(positive)), np.zeros(len(negative))])
    scores = np.hstack([positive, negative])
    _, _, _, eer = det_curve(labels, scores, distances=True)
    writer.add_scalar('train/eer', eer, global_step=iteration)

    # raw triplet loss, i.e. before max(0, .)
    histogram('train/triplet/delta', np.vstack(self.log_delta_))

    # distribution of embedding norms
    histogram('train/embedding/norm', np.hstack(self.log_norm_))
Exemple #2
0
    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the current embedding on `self.X_` / `self.y_`.

        Pairwise distances between embeddings of `self.X_` are used to
        compute the equal error rate (EER). The EER is appended to
        `{log_dir}/{subset}.eer.txt`, which is truncated at epoch 0. The
        "EER vs. epoch" curve is redrawn into `{log_dir}/{subset}.eer.png`.
        During the first 10 epochs, and then every 20 epochs, the distance
        distributions and the DET curve are also plotted.

        Parameters
        ----------
        epoch : int
            Index of the epoch that just ended.
        logs : dict, optional
            Unused. The default is None rather than a shared mutable `{}`.
        """
        # keep track of current time
        now = datetime.datetime.now().isoformat()

        embedding = self.extract_embedding(self.model)
        fX = embedding.predict(self.X_)
        distance = pdist(fX, metric=self.distance)
        prefix = self.log_dir + '/{subset}.plot.{epoch:04d}'.format(
            subset=self.subset, epoch=epoch)

        # detailed plots during the 10 first epochs, then every 20 epochs
        detailed_plots = (epoch < 10) or (epoch % 20 == 0)

        xlim = get_range(metric=self.distance)
        if detailed_plots:
            plot_distributions(self.y_,
                               distance,
                               prefix,
                               xlim=xlim,
                               ymax=3,
                               nbins=100,
                               dpi=75)
            # plot_det_curve also returns the equal error rate
            eer = plot_det_curve(self.y_,
                                 distance,
                                 prefix,
                                 distances=True,
                                 dpi=75)
        else:
            _, _, _, eer = det_curve(self.y_, distance, distances=True)

        # store equal error rate in file (start a fresh file at epoch 0)
        mode = 'a' if epoch else 'w'
        path = self.log_dir + '/{subset}.eer.txt'.format(subset=self.subset)
        with open(path, mode=mode) as fp:
            fp.write(self.EER_TEMPLATE_.format(epoch=epoch, eer=eer, now=now))
            fp.flush()

        # plot eer = f(epoch)
        self.eer_.append(eer)
        best_epoch = np.argmin(self.eer_)
        best_value = np.min(self.eer_)
        fig = plt.figure()
        try:
            plt.plot(self.eer_, 'b')
            plt.plot([best_epoch], [best_value], 'bo')
            plt.plot([0, epoch], [best_value, best_value], 'k--')
            plt.grid(True)
            plt.xlabel('epoch')
            plt.ylabel('EER on {subset}'.format(subset=self.subset))
            TITLE = 'EER = {best_value:.5g} on {subset} @ epoch #{best_epoch:d}'
            title = TITLE.format(best_value=best_value,
                                 best_epoch=best_epoch,
                                 subset=self.subset)
            plt.title(title)
            plt.tight_layout()
            path = self.log_dir + '/{subset}.eer.png'.format(
                subset=self.subset)
            plt.savefig(path, dpi=75)
        finally:
            # close the figure even if saving fails, so figures do not
            # pile up across epochs
            plt.close(fig)
Exemple #3
0
    def on_epoch_end(self, iteration, writer=None, **kwargs):
        """Log training equal error rate(s) of the classifier outputs.

        Nothing is logged when `writer` is None. Accumulated predictions
        (`self.log_y_pred_`) are reshaped to (-1, n_classes) and targets
        (`self.log_y_true_`) are flattened. With fewer than 3 classes, a
        single EER is logged under 'train/eer', using class-0 scores against
        `y_true == 0`. Otherwise one EER per class k is logged under
        'train/eer/{k}'.

        Parameters
        ----------
        iteration : int
            Global step attached to the logged scalars.
        writer : optional
            Object exposing `add_scalar` (presumably a TensorBoard
            SummaryWriter -- confirm with caller).
        """
        if writer is None:
            return

        y_pred = np.hstack(self.log_y_pred_).reshape((-1, self.n_classes))
        y_true = np.hstack(self.log_y_true_).reshape((-1, ))

        if self.n_classes >= 3:
            for k in range(self.n_classes):
                _, _, _, eer = det_curve(y_true == k, y_pred[:, k])
                writer.add_scalar(f'train/eer/{k}', eer,
                                  global_step=iteration)
            return

        # binary case: one EER based on class-0 scores
        _, _, _, eer = det_curve(y_true == 0, y_pred[:, 0])
        writer.add_scalar('train/eer', eer, global_step=iteration)
def speaker_diarization_xp(sequence_embedding, X, y, distance='angular'):
    """Same/different experiment over every pair of sequences.

    Parameters
    ----------
    sequence_embedding : object
        Provides `transform(X)` returning one embedding per sequence.
    X : array-like
        Input sequences.
    y : array-like
        Labels. Two sequences count as "same" when the Chebyshev distance
        between their label rows is below 1.
    distance : str, optional
        `pdist` metric used to compare embeddings. Defaults to 'angular'.

    Returns
    -------
    The result of `det_curve` on the pairwise distances (unpacked elsewhere
    in this file as `fpr, fnr, thresholds, eer`).
    """
    embeddings = sequence_embedding.transform(X)
    pair_distances = pdist(embeddings, metric=distance)
    same_pair = pdist(y, metric='chebyshev') < 1
    return det_curve(same_pair, pair_distances, distances=True)
def plot_det_curve(y_true, scores, save_to,
                   distances=False, dpi=150):
    """DET curve

    This function will create (and overwrite) the following files:
        - {save_to}.det.png
        - {save_to}.det.eps
        - {save_to}.det.txt

    Parameters
    ----------
    y_true : (n_samples, ) array-like
        Boolean reference.
    scores : (n_samples, ) array-like
        Predicted score.
    save_to : str
        Files path prefix.
    distances : boolean, optional
        When True, indicate that `scores` are actually `distances`
    dpi : int, optional
        Resolution of .png file. Defaults to 150.

    Returns
    -------
    eer : float
        Equal error rate
    """

    fpr, fnr, thresholds, eer = det_curve(y_true, scores, distances=distances)

    # plot DET curve (log-log), with the EER marked at (eer, eer)
    fig = plt.figure(figsize=(12, 12))
    try:
        plt.loglog(fpr, fnr, 'b')
        plt.loglog([eer], [eer], 'bo')
        plt.xlabel('False Positive Rate')
        plt.ylabel('False Negative Rate')
        plt.xlim(1e-2, 1.)
        plt.ylim(1e-2, 1.)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(save_to + '.det.png', dpi=dpi)
        plt.savefig(save_to + '.det.eps')
    finally:
        # always release the figure, even when saving fails
        plt.close(fig)

    # save DET curve in text file: one "threshold fpr fnr" line per point
    line = '{t:.6f} {fp:.6f} {fn:.6f}\n'
    with open(save_to + '.det.txt', 'w') as f:
        f.writelines(line.format(t=t, fp=fp, fn=fn)
                     for t, fp, fn in zip(thresholds, fpr, fnr))

    return eer
def plot_det_curve(y_true, scores, save_to, distances=False, dpi=150):
    """DET curve

    This function will create (and overwrite) the following files:
        - {save_to}.det.png
        - {save_to}.det.eps
        - {save_to}.det.txt

    Parameters
    ----------
    y_true : (n_samples, ) array-like
        Boolean reference.
    scores : (n_samples, ) array-like
        Predicted score.
    save_to : str
        Files path prefix.
    distances : boolean, optional
        When True, indicate that `scores` are actually `distances`
    dpi : int, optional
        Resolution of .png file. Defaults to 150.

    Returns
    -------
    eer : float
        Equal error rate
    """

    fpr, fnr, thresholds, eer = det_curve(y_true, scores, distances=distances)

    # plot DET curve (log-log), with the EER marked at (eer, eer)
    fig = plt.figure(figsize=(12, 12))
    try:
        plt.loglog(fpr, fnr, 'b')
        plt.loglog([eer], [eer], 'bo')
        plt.xlabel('False Positive Rate')
        plt.ylabel('False Negative Rate')
        plt.xlim(1e-2, 1.)
        plt.ylim(1e-2, 1.)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(save_to + '.det.png', dpi=dpi)
        plt.savefig(save_to + '.det.eps')
    finally:
        # always release the figure, even when saving fails
        plt.close(fig)

    # save DET curve in text file: one "threshold fpr fnr" line per point
    line = '{t:.6f} {fp:.6f} {fn:.6f}\n'
    with open(save_to + '.det.txt', 'w') as f:
        f.writelines(line.format(t=t, fp=fp, fn=fn)
                     for t, fp, fn in zip(thresholds, fpr, fnr))

    return eer
Exemple #7
0
def plotROC(scores, labels, epoch=0, seed=1):
    """Plot a DET curve and report EER and minimum detection costs.

    Prints the equal error rate and the normalized minimum detection cost
    (minDCF) under the SRE-2008 and SRE-2010 cost parameters. The DET curve
    is saved to './images/DET_<epoch>.png' and './images/DET_latest.png'.
    The './images' directory must already exist.

    Parameters
    ----------
    scores : array-like
        Scores, where higher means more likely target (`distances=False`).
        Flattened with `.ravel()`.
    labels : array-like
        Binary reference labels. Flattened with `.ravel()`.
    epoch : int, optional
        Only used to name the per-epoch image. Defaults to 0.
    seed : int, optional
        Unused; kept for backward compatibility.
    """
    fpr, fnr, thresholds, eer = det_curve(labels.ravel(),
                                          scores.ravel(),
                                          distances=False)

    # SRE-2008 performance parameters
    # normalization: 1 / min(Cmiss * P_tgt, Cfa * (1 - P_tgt)) = 1 / 0.1
    Cmiss, Cfa, P_tgt = 10, 1, 0.01
    Cdet08 = Cmiss * fnr * P_tgt + Cfa * fpr * (1 - P_tgt)
    dcf08 = 10 * np.min(Cdet08)

    # SRE-2010 performance parameters
    # normalization: 1 / min(Cmiss * P_tgt, Cfa * (1 - P_tgt)) = 1 / 0.001
    Cmiss, Cfa, P_tgt = 1, 1, 0.001
    Cdet10 = Cmiss * fnr * P_tgt + Cfa * fpr * (1 - P_tgt)
    dcf10 = 1000 * np.min(Cdet10)

    # operating points (indices into fpr / fnr / thresholds), computed once
    i_eer = np.argmin(np.abs(fpr - eer))
    i08 = np.argmin(Cdet08)
    i10 = np.argmin(Cdet10)

    fig = plt.figure(figsize=(12, 12))
    try:
        plt.loglog(fpr, fnr, color='darkorange', lw=2,
                   label='EER = %0.2f' % eer)
        print('EER = {0:.2f}%, THR = {1:.6f}'.format(
            eer * 100., thresholds[i_eer]))

        print('minDCF08 = {0:.4f}, THR = {1:.6f}'.format(
            dcf08, thresholds[i08]))
        print('minFPR = {0:.4f}%, minFNR = {1:.4f}%\n'.format(
            fpr[i08] * 100, fnr[i08] * 100))

        print('minDCF10 = {0:.4f}, THR = {1:.6f}'.format(
            dcf10, thresholds[i10]))
        print('minFPR = {0:.4f}%, minFNR = {1:.4f}%\n'.format(
            fpr[i10] * 100, fnr[i10] * 100))

        plt.loglog([eer], [eer], 'bo')
        plt.loglog([fpr[i08]], [fnr[i08]], 'ro')
        # the curve carries an 'EER = ...' label; draw the legend so it shows
        plt.legend()
        plt.xlabel('False Positive Rate')
        plt.ylabel('False Negative Rate')
        plt.xlim(1e-4, 1.)
        plt.ylim(1e-2, 1.)
        plt.grid(True)
        plt.tight_layout()
        fig.savefig('./images/DET_' + str(epoch) + '.png',
                    dpi=300,
                    orientation='portrait')
        fig.savefig('./images/DET_latest.png', dpi=300,
                    orientation='portrait')
    finally:
        plt.close(fig)
    def _validate_epoch_segment(self, epoch, protocol_name,
                                subset='development',
                                validation_data=None):
        """Segment-level same/different EER for the model saved at `epoch`.

        Embeds `validation_data['X']`, computes all pairwise distances with
        `self.metric`, and scores them against `validation_data['y']`.
        `protocol_name` and `subset` are not used here.

        Returns
        -------
        dict
            {'EER.<duration>s': {'minimize': True, 'value': eer}}
        """
        network = self.load_model(epoch).to(self.device)
        network.eval()

        embedder = SequenceEmbedding(network,
                                     self.feature_extraction_,
                                     batch_size=self.batch_size,
                                     device=self.device)

        embeddings = embedder.apply(validation_data['X'])
        pair_distances = pdist(embeddings, metric=self.metric)
        _, _, _, eer = det_curve(validation_data['y'], pair_distances,
                                 distances=True)

        result = {'minimize': True, 'value': eer}
        return {'EER.{0:g}s'.format(self.duration): result}
    def _validate_epoch_turn(self,
                             epoch,
                             protocol_name,
                             subset='development',
                             validation_data=None):
        """Speech-turn-level same/different EER for the model at `epoch`.

        Segments of `validation_data['X']` are embedded. Embeddings of
        consecutive segments that share the same speech-turn id in
        `validation_data['z']` are averaged into one embedding per turn.
        Pairwise distances between turn embeddings are then scored against
        `validation_data['y']`.

        Returns
        -------
        dict
            {'EER.turn': {'minimize': True, 'value': eer}}
        """
        network = self.load_model(epoch).to(self.device)
        network.eval()

        embedder = SequenceEmbedding(network,
                                     self.feature_extraction_,
                                     batch_size=self.batch_size,
                                     device=self.device)
        segment_embeddings = embedder.apply(validation_data['X'])

        turn_ids = validation_data['z']

        # rows are (segment index, speech turn id); groupby gathers runs of
        # consecutive rows with the same turn id
        index_turn = np.vstack([np.arange(len(turn_ids)), turn_ids]).T
        turn_embeddings = np.vstack([
            np.mean(segment_embeddings[np.stack(rows)[:, 0]], axis=0)
            for _, rows in itertools.groupby(index_turn, lambda row: row[1])
        ])

        pair_distances = pdist(turn_embeddings, metric=self.metric)
        _, _, _, eer = det_curve(validation_data['y'], pair_distances,
                                 distances=True)
        return {'EER.turn': {'minimize': True, 'value': eer}}
    def _validate_epoch_turn(self, epoch, protocol_name,
                             subset='development',
                             validation_data=None):
        """Compute the speech-turn EER of the model saved at `epoch`.

        Each run of consecutive segments sharing a speech-turn id in
        `validation_data['z']` is collapsed into the mean of its segment
        embeddings. The EER of pairwise turn distances is then computed
        against `validation_data['y']`. `protocol_name` and `subset` are
        unused.

        Returns
        -------
        metrics : dict
            Maps 'EER.turn' to {'minimize': True, 'value': eer}.
        """
        net = self.load_model(epoch).to(self.device)
        net.eval()

        seq_emb = SequenceEmbedding(net, self.feature_extraction_,
                                    batch_size=self.batch_size,
                                    device=self.device)
        per_segment = seq_emb.apply(validation_data['X'])

        z = validation_data['z']
        pairs = np.vstack([np.arange(len(z)), z]).T

        per_turn = []
        for _, group in itertools.groupby(pairs, lambda p: p[1]):
            # column 0 holds the indices of this turn's segments
            segment_indices = np.stack(group)[:, 0]
            per_turn.append(np.mean(per_segment[segment_indices], axis=0))

        dist = pdist(np.vstack(per_turn), metric=self.metric)
        _, _, _, eer = det_curve(validation_data['y'], dist, distances=True)

        metrics = {'EER.turn': {'minimize': True, 'value': eer}}
        return metrics
    def on_epoch_end(self, iteration, writer=None, **kwargs):
        """Log triplet-training statistics for the epoch at step `iteration`.

        Does nothing when `writer` is None. Otherwise it logs histograms of
        intra-class and inter-class distances, the EER of the same/different
        experiment on those distances, and histograms of raw triplet deltas
        and embedding norms.

        Parameters
        ----------
        iteration : int
            Global step for all logged values.
        writer : optional
            Object exposing `add_histogram` / `add_scalar` (presumably a
            TensorBoard SummaryWriter -- confirm with caller).
        """
        if writer is not None:
            intra = np.hstack(self.log_positive_)
            inter = np.hstack(self.log_negative_)
            for tag, values in (('train/distance/intra_class', intra),
                                ('train/distance/inter_class', inter)):
                writer.add_histogram(tag, values,
                                     global_step=iteration, bins='doane')

            # same/different experiment on training pairs (1 = same class)
            is_same = np.hstack([np.ones(len(intra)), np.zeros(len(inter))])
            all_distances = np.hstack([intra, inter])
            _, _, _, eer = det_curve(is_same, all_distances, distances=True)
            writer.add_scalar('train/eer', eer, global_step=iteration)

            # raw triplet loss (before max(0, .))
            writer.add_histogram('train/triplet/delta',
                                 np.vstack(self.log_delta_),
                                 global_step=iteration, bins='doane')

            # embedding norms
            writer.add_histogram('train/embedding/norm',
                                 np.hstack(self.log_norm_),
                                 global_step=iteration, bins='doane')
def speaker_recognition_xp(aggregation, protocol, subset='development',
                           distance='angular', threads=None):
    """Speaker recognition experiment over all keyed (enroll, test) pairs.

    Every recording from `<subset>_enroll` and `<subset>_test` is embedded
    once: `aggregation.apply` is called and its `.data` is summed along
    axis 0. Embeddings are L2-normalized and compared with `distance`.
    Pairs keyed 1 in `<subset>_keys` are positives and pairs keyed -1 are
    negatives; pairs keyed 0 are ignored.

    Parameters
    ----------
    threads : optional
        Unused.

    Returns
    -------
    The result of `det_curve(y_true, y_pred, distances=True)`.
    """
    enroll = getattr(protocol, f'{subset}_enroll')(yield_name=True)
    test = getattr(protocol, f'{subset}_test')(yield_name=True)

    # TODO parallelize using multiprocessing
    embeddings = {}
    for name, item in itertools.chain(enroll, test):
        if name not in embeddings:
            embeddings[name] = np.sum(aggregation.apply(item).data, axis=0)

    keys = getattr(protocol, f'{subset}_keys')()

    # rows of `keys` are enroll names, columns are test names
    enroll_fX = l2_normalize(np.vstack([embeddings[n] for n in keys.index]))
    test_fX = l2_normalize(np.vstack([embeddings[n] for n in keys]))

    # compare all possible (enroll, test) pairs at once
    D = cdist(enroll_fX, test_fX, metric=distance)

    positive = D[np.where(keys == 1)]
    negative = D[np.where(keys == -1)]

    y_true = np.hstack([np.ones(positive.shape[0]),
                        np.zeros(negative.shape[0])])
    y_pred = np.hstack([positive, negative])
    return det_curve(y_true, y_pred, distances=True)
    def _validate_epoch_segment(self,
                                epoch,
                                protocol_name,
                                subset='development',
                                validation_data=None):
        """Return the segment-level EER of the model saved at `epoch`.

        `validation_data['X']` holds the segments to embed and
        `validation_data['y']` the same/different reference for every
        pair (as consumed by `det_curve`). `protocol_name` and `subset`
        are unused.

        Returns
        -------
        dict
            Single entry keyed 'EER.<duration>s' whose value is
            {'minimize': True, 'value': eer}.
        """
        loaded = self.load_model(epoch)
        loaded = loaded.to(self.device)
        loaded.eval()

        extractor = SequenceEmbedding(loaded, self.feature_extraction_,
                                      batch_size=self.batch_size,
                                      device=self.device)

        fX = extractor.apply(validation_data['X'])
        eer = det_curve(validation_data['y'],
                        pdist(fX, metric=self.metric),
                        distances=True)
        _, _, _, eer = eer

        key = 'EER.{0:g}s'.format(self.duration)
        return {key: {'minimize': True, 'value': eer}}
Exemple #14
0
    def eval(self, model, partition: str = 'development'):
        """Run the `<partition>_trial` verification trials and score them.

        Each trial's 'file1' and 'file2' are embedded through
        `self._file_embedding`, which shares one `cache` dict across trials.
        The two embeddings are compared with
        `self.distance.to_sklearn_metric()`.

        Returns
        -------
        tuple
            (1 - eer, y_pred, y_true): y_pred is the list of trial
            distances and y_true the list of trial references.
        """
        model.eval()
        config = self.config
        sequence_embedding = SequenceEmbedding(
            model=model,
            feature_extraction=config.feature_extraction,
            duration=config.duration,
            step=.5 * config.duration,
            batch_size=self.batch_size,
            device=common.DEVICE)
        protocol = get_protocol(config.protocol_name,
                                progress=False,
                                preprocessors=config.preprocessors)

        y_true, y_pred, cache = [], [], {}

        for trial in getattr(protocol, f"{partition}_trial")():
            first = self._file_embedding(trial['file1'], sequence_embedding,
                                         cache)
            second = self._file_embedding(trial['file2'], sequence_embedding,
                                          cache)
            metric = self.distance.to_sklearn_metric()
            y_pred.append(cdist(first, second, metric=metric)[0, 0])
            y_true.append(trial['reference'])

        _, _, _, eer = det_curve(np.array(y_true), np.array(y_pred),
                                 distances=True)

        # Returning 1-eer because the evaluator keeps track of the highest metric value
        return 1 - eer, y_pred, y_true
    def _validate_epoch_verification(self,
                                     epoch,
                                     protocol_name,
                                     subset='development',
                                     validation_data=None):
        """Perform a speaker verification experiment using model at `epoch`

        Enrolment models are the mean embedding of their 'enrol_with'
        excerpts. Trial embeddings are the mean embedding of their
        'try_with' excerpts. A trial reuses an enrolment (or an earlier
        trial) whenever the same file and crop were already embedded. The
        distance (`self.metric`) between the enrolment model and the trial
        embedding is scored against the trial's 'reference'.

        Parameters
        ----------
        epoch : int
            Epoch to validate.
        protocol_name : str
            Name of speaker verification protocol
        subset : {'train', 'development', 'test'}, optional
            Name of subset.
        validation_data : provided by `validate_init`
            Unused here.

        Returns
        -------
        metrics : dict
            {'EER': {'minimize': True, 'value': eer}}
        """

        # load current model (kept distinct from the embedding arrays below)
        network = self.load_model(epoch).to(self.device)
        network.eval()

        # use user-provided --duration when available
        # otherwise use 'duration' used for training
        if self.duration is None:
            duration = self.task_.duration
        else:
            duration = self.duration
        min_duration = None

        # if 'duration' is still None, it means that
        # network was trained with variable lengths
        if duration is None:
            duration = self.task_.max_duration
            min_duration = self.task_.min_duration

        step = .5 * duration

        # NOTE: this setting persists on self.feature_extraction_ after return
        if isinstance(self.feature_extraction_, Precomputed):
            self.feature_extraction_.use_memmap = False

        # initialize embedding extraction
        sequence_embedding = SequenceEmbedding(network,
                                               self.feature_extraction_,
                                               duration=duration,
                                               step=step,
                                               min_duration=min_duration,
                                               batch_size=self.batch_size,
                                               device=self.device)

        metrics = {}
        protocol = get_protocol(protocol_name,
                                progress=False,
                                preprocessors=self.preprocessors_)

        enrolment_models, enrolment_khashes = {}, {}
        enrolments = getattr(protocol, '{0}_enrolment'.format(subset))()
        for enrolment in enrolments:
            data = sequence_embedding.apply(enrolment,
                                            crop=enrolment['enrol_with'])
            model_id = enrolment['model_id']
            # NOTE(review): enrolment uses np.stack(data) while trials use
            # np.mean(data) directly -- kept as-is; confirm both agree for
            # the type returned by SequenceEmbedding.apply.
            enrolment_models[model_id] = np.mean(np.stack(data), axis=0,
                                                 keepdims=True)

            # in some specific speaker verification protocols,
            # enrolment data may be  used later as trial data.
            # therefore, we cache information about enrolment data
            # to speed things up by reusing the enrolment as trial
            h = hash((get_unique_identifier(enrolment),
                      tuple(enrolment['enrol_with'])))
            enrolment_khashes[h] = model_id

        trial_models = {}
        trials = getattr(protocol, '{0}_trial'.format(subset))()
        y_true, y_pred = [], []
        for trial in trials:
            model_id = trial['model_id']

            h = hash((get_unique_identifier(trial), tuple(trial['try_with'])))

            # re-use enrolment model whenever possible
            if h in enrolment_khashes:
                trial_embedding = enrolment_models[enrolment_khashes[h]]

            # re-use trial model whenever possible
            elif h in trial_models:
                trial_embedding = trial_models[h]

            else:
                data = sequence_embedding.apply(trial, crop=trial['try_with'])
                trial_embedding = np.mean(data, axis=0, keepdims=True)
                # cache trial model for later re-use
                trial_models[h] = trial_embedding

            distance = cdist(enrolment_models[model_id],
                             trial_embedding,
                             metric=self.metric)[0, 0]
            y_pred.append(distance)
            y_true.append(trial['reference'])

        _, _, _, eer = det_curve(np.array(y_true),
                                 np.array(y_pred),
                                 distances=True)
        metrics['EER'] = {'minimize': True, 'value': eer}

        return metrics
Exemple #16
0
    def validate(self, protocol_name, subset='development'):
        """Continuously validate each epoch saved in the training directory.

        For epoch 0, 1, 2, ... this waits, polling every 60 seconds, until
        the epoch's weight file exists. It then computes the frame-level
        equal error rate averaged over the files of `subset`, appends it to
        the validation text file and redraws the "EER vs. epoch" plot
        (PNG and EPS).

        The loop never ends on its own. The progress bar is closed when the
        loop is left through an exception (e.g. KeyboardInterrupt).

        Parameters
        ----------
        protocol_name : str
            Protocol providing the validation files.
        subset : str, optional
            Protocol subset to validate on. Defaults to 'development'.
        """

        # prepare paths
        validate_dir = self.VALIDATE_DIR.format(train_dir=self.train_dir_,
                                                protocol=protocol_name)
        validate_txt = self.VALIDATE_TXT.format(validate_dir=validate_dir,
                                                subset=subset)
        validate_png = self.VALIDATE_PNG.format(validate_dir=validate_dir,
                                                subset=subset)
        validate_eps = self.VALIDATE_EPS.format(validate_dir=validate_dir,
                                                subset=subset)

        # create validation directory
        mkdir_p(validate_dir)

        # Build validation set
        y = self._validation_set(protocol_name, subset=subset)

        # list of equal error rates, and current epoch
        eers, epoch = [], 0

        desc_format = ('EER = {eer:.2f}% @ epoch #{epoch:d} ::'
                       ' Best EER = {best_eer:.2f}% @ epoch #{best_epoch:d} :')
        progress_bar = tqdm(unit='epoch', total=1000)

        try:
            with open(validate_txt, mode='w') as fp:

                # watch and evaluate forever
                while True:

                    weights_h5 = LoggingCallback.WEIGHTS_H5.format(
                        log_dir=self.train_dir_, epoch=epoch)

                    # wait until weight file is available
                    if not isfile(weights_h5):
                        time.sleep(60)
                        continue

                    # load model for current epoch
                    sequence_labeling = SequenceLabeling.from_disk(
                        self.train_dir_, epoch)

                    # initialize sequence labeling
                    duration = self.config_['sequences']['duration']
                    step = duration  # hack to make things faster
                    # step = self.config_['sequences']['step']
                    aggregation = SequenceLabelingAggregation(
                        sequence_labeling,
                        self.feature_extraction_,
                        duration=duration,
                        step=step)
                    aggregation.cache_preprocessed_ = False

                    # estimate equal error rate (average of all files)
                    eers_ = []
                    protocol = get_protocol(protocol_name,
                                            progress=False,
                                            preprocessors=self.preprocessors_)
                    file_generator = getattr(protocol, subset)()
                    for current_file in file_generator:
                        identifier = get_unique_identifier(current_file)
                        uem = get_annotated(current_file)
                        y_true = y[identifier].crop(uem)[:, 1]
                        counts = Counter(y_true)
                        # EER is undefined unless both classes are present
                        if counts[0] * counts[1] == 0:
                            continue
                        y_pred = aggregation.apply(current_file).crop(uem)[:, 1]

                        _, _, _, eer = det_curve(y_true, y_pred,
                                                 distances=False)

                        eers_.append(eer)
                    # NOTE(review): if every file was skipped, eers_ is empty
                    # and np.mean returns nan (with a RuntimeWarning)
                    eer = np.mean(eers_)
                    eers.append(eer)

                    # save equal error rate to file
                    fp.write(
                        self.VALIDATE_TXT_TEMPLATE.format(epoch=epoch,
                                                          eer=eer))
                    fp.flush()

                    # keep track of best epoch so far
                    best_epoch, best_eer = np.argmin(eers), np.min(eers)

                    progress_bar.set_description(
                        desc_format.format(epoch=epoch,
                                           eer=100 * eer,
                                           best_epoch=best_epoch,
                                           best_eer=100 * best_eer))
                    progress_bar.update(1)

                    # plot
                    fig = plt.figure()
                    try:
                        plt.plot(eers, 'b')
                        plt.plot([best_epoch], [best_eer], 'bo')
                        plt.plot([0, epoch], [best_eer, best_eer], 'k--')
                        plt.grid(True)
                        plt.xlabel('epoch')
                        plt.ylabel('EER on {subset}'.format(subset=subset))
                        TITLE = '{best_eer:.5g} @ epoch #{best_epoch:d}'
                        title = TITLE.format(best_eer=best_eer,
                                             best_epoch=best_epoch,
                                             subset=subset)
                        plt.title(title)
                        plt.tight_layout()
                        plt.savefig(validate_png, dpi=75)
                        plt.savefig(validate_eps)
                    finally:
                        plt.close(fig)

                    # validate next epoch
                    epoch += 1
        finally:
            # previously unreachable after the infinite loop
            progress_bar.close()
Exemple #17
0
    def on_epoch_end(self, epoch, logs=None):
        """Run a speaker recognition experiment with the current embedding.

        Recordings from `<subset>_enroll` and `<subset>_test` are embedded
        once each by summing the aggregated embeddings (`internal=-2`).
        Embeddings are L2-normalized and compared with `self.glue.distance`.
        Keys from `<subset>_keys` map 1 to "same" and -1 to "different";
        0 ("untested") pairs are skipped. The resulting EER is appended to
        `{log_dir}/{subset}.eer.txt`, which is truncated at epoch 0, and
        the "EER vs. epoch" curve is redrawn. The DET curve is plotted
        during the first 10 epochs and then every 20 epochs.

        Parameters
        ----------
        epoch : int
            Index of the epoch that just ended.
        logs : dict, optional
            Unused. The default is None rather than a shared mutable `{}`.
        """
        # keep track of current time
        now = datetime.datetime.now().isoformat()
        prefix = self.log_dir + '/{subset}.plot.{epoch:04d}'.format(
            epoch=epoch, subset=self.subset)

        from pyannote.audio.embedding.base import SequenceEmbedding
        sequence_embedding = SequenceEmbedding()
        sequence_embedding.embedding_ = self.glue.extract_embedding(self.model)

        from pyannote.audio.embedding.aggregation import \
            SequenceEmbeddingAggregation
        aggregation = SequenceEmbeddingAggregation(
            sequence_embedding,
            self.glue.feature_extractor,
            duration=self.glue.duration,
            min_duration=self.glue.min_duration,
            step=self.glue.step,
            internal=-2)

        # TODO / pass internal as parameter
        aggregation.cache_preprocessed_ = False

        # embed enroll and test recordings

        method = '{subset}_enroll'.format(subset=self.subset)
        enroll = getattr(self.protocol, method)(yield_name=True)

        method = '{subset}_test'.format(subset=self.subset)
        test = getattr(self.protocol, method)(yield_name=True)

        fX = {}
        for name, item in itertools.chain(enroll, test):
            if name in fX:
                continue
            embeddings = aggregation.apply(item)
            fX[name] = np.sum(embeddings.data, axis=0)

        # perform trials

        method = '{subset}_keys'.format(subset=self.subset)
        keys = getattr(self.protocol, method)()

        # rows of `keys` are enroll names, columns are test names
        enroll_fX = l2_normalize(np.vstack([fX[name] for name in keys.index]))
        test_fX = l2_normalize(np.vstack([fX[name] for name in keys]))

        D = cdist(enroll_fX, test_fX, metric=self.glue.distance)

        # 1 -> same (1), -1 -> different (0), 0 -> untested (skipped);
        # any other key value raises KeyError
        key_mapping = {0: None, -1: 0, 1: 1}
        y_true = []
        y_pred = []
        # D and keys share the (enroll, test) layout, so walk them row by row
        for key_row, distance_row in zip(keys.values, D):
            for key, distance in zip(key_row, distance_row):
                y = key_mapping[key]
                if y is None:
                    continue
                y_true.append(y)
                y_pred.append(distance)

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        # plot DET curve once every 20 epochs (and 10 first epochs)
        if (epoch < 10) or (epoch % 20 == 0):
            eer = plot_det_curve(y_true,
                                 y_pred,
                                 prefix,
                                 distances=True,
                                 dpi=75)
        else:
            _, _, _, eer = det_curve(y_true, y_pred, distances=True)

        # store equal error rate in file (start a fresh file at epoch 0)
        mode = 'a' if epoch else 'w'
        path = self.log_dir + '/{subset}.eer.txt'.format(subset=self.subset)
        with open(path, mode=mode) as fp:
            fp.write(self.EER_TEMPLATE_.format(epoch=epoch, eer=eer, now=now))
            fp.flush()

        # plot eer = f(epoch)
        self.eer_.append(eer)
        best_epoch = np.argmin(self.eer_)
        best_value = np.min(self.eer_)
        fig = plt.figure()
        try:
            plt.plot(self.eer_, 'b')
            plt.plot([best_epoch], [best_value], 'bo')
            plt.plot([0, epoch], [best_value, best_value], 'k--')
            plt.grid(True)
            plt.xlabel('epoch')
            plt.ylabel('EER on {subset}'.format(subset=self.subset))
            TITLE = 'EER = {best_value:.5g} on {subset} @ epoch #{best_epoch:d}'
            title = TITLE.format(best_value=best_value,
                                 best_epoch=best_epoch,
                                 subset=self.subset)
            plt.title(title)
            plt.tight_layout()
            path = self.log_dir + '/{subset}.eer.png'.format(
                subset=self.subset)
            plt.savefig(path, dpi=75)
        finally:
            # close the figure even if saving fails
            plt.close(fig)
    def _validate_epoch_verification(self,
                                     epoch,
                                     validation_data,
                                     protocol=None,
                                     subset='development',
                                     device: Optional[torch.device] = None,
                                     batch_size: int = 32,
                                     n_jobs: int = 1,
                                     duration: Optional[float] = None,
                                     step: float = 0.25,
                                     metric: Optional[str] = None,
                                     **kwargs):
        """Equal error rate of a speaker verification experiment at `epoch`.

        Every trial of `<subset>_trial` compares the embeddings of its
        'file1' and 'file2' with `cdist(..., metric=metric)`. The distances
        are scored against the trials' 'reference'. Each file is embedded
        at most once; embeddings are cached by `self.get_hash(file)`.
        `validation_data`, `n_jobs` and extra `kwargs` are unused here.

        Returns
        -------
        dict
            {'metric': 'equal_error_rate', 'minimize': True,
             'value': eer}
        """
        # initialize embedding extraction
        pretrained = Pretrained(validate_dir=self.validate_dir_,
                                epoch=epoch,
                                duration=duration,
                                step=step,
                                batch_size=batch_size,
                                device=device)

        _protocol = get_protocol(protocol,
                                 progress=False,
                                 preprocessors=self.preprocessors_)

        y_true, y_pred, cache = [], [], {}

        def embed(current_file):
            # a file usually appears in many trials: embed it only once
            key = self.get_hash(current_file)
            if key not in cache:
                cache[key] = self.get_embedding(current_file, pretrained)
            return cache[key]

        for trial in getattr(_protocol, '{0}_trial'.format(subset))():
            emb1 = embed(trial['file1'])
            emb2 = embed(trial['file2'])

            # compare embeddings
            distance = cdist(emb1, emb2, metric=metric)[0, 0]
            y_pred.append(distance)
            y_true.append(trial['reference'])

        _, _, _, eer = det_curve(np.array(y_true),
                                 np.array(y_pred),
                                 distances=True)

        return {
            'metric': 'equal_error_rate',
            'minimize': True,
            'value': float(eer)
        }
    def _validate_epoch_verification(
        self,
        epoch,
        validation_data,
        protocol=None,
        subset: Subset = "development",
        device: Optional[torch.device] = None,
        batch_size: int = 32,
        n_jobs: int = 1,
        duration: Optional[float] = None,
        step: float = 0.25,
        metric: Optional[str] = None,
        **kwargs,
    ):
        """Score one epoch on the speaker verification trials of a protocol.

        Embeddings are extracted with the weights saved at `epoch`, each trial
        is scored by the `metric` distance between the embeddings of its two
        files, and the equal error rate over all trials is returned.

        Parameters
        ----------
        epoch : int
            Epoch whose weights are loaded (through `Pretrained`).
        validation_data :
            Not used by this method.
        protocol :
            Protocol, passed to `get_protocol`.
        subset : Subset, optional
            Subset whose `<subset>_trial` generator is iterated.
            Defaults to "development".
        device, batch_size, duration, step :
            Forwarded to `Pretrained`.
        n_jobs :
            Not used by this method.
        metric : str, optional
            Distance metric forwarded to `scipy.spatial.distance.cdist`.

        Returns
        -------
        dict
            ``{"metric": "equal_error_rate", "minimize": True, "value": eer}``
        """

        # initialize embedding extraction
        pretrained = Pretrained(
            validate_dir=self.validate_dir_,
            epoch=epoch,
            duration=duration,
            step=step,
            batch_size=batch_size,
            device=device,
        )

        # Work on a copy so that the defaults added below do not leak into
        # self.preprocessors_ (which would otherwise be mutated on every
        # validation call).
        preprocessors = dict(self.preprocessors_)
        # Only build the defaults when missing: FileFinder() is deliberately
        # not instantiated when an "audio" preprocessor is already provided.
        if "audio" not in preprocessors:
            preprocessors["audio"] = FileFinder()
        if "duration" not in preprocessors:
            preprocessors["duration"] = get_audio_duration
        _protocol = get_protocol(protocol, preprocessors=preprocessors)

        cache = {}

        def embed(current_file):
            # memoized by file hash: a file involved in several trials is
            # only embedded once
            key = self.get_hash(current_file)
            if key not in cache:
                cache[key] = self.get_embedding(current_file, pretrained)
            return cache[key]

        y_true, y_pred = [], []
        for trial in getattr(_protocol, f"{subset}_trial")():
            emb1 = embed(trial["file1"])
            emb2 = embed(trial["file2"])

            # compare embeddings
            y_pred.append(cdist(emb1, emb2, metric=metric)[0, 0])
            y_true.append(trial["reference"])

        _, _, _, eer = det_curve(np.array(y_true),
                                 np.array(y_pred),
                                 distances=True)

        return {
            "metric": "equal_error_rate",
            "minimize": True,
            "value": float(eer),
        }
    def fit(self,
            model,
            feature_extraction,
            protocol,
            log_dir,
            subset='train',
            epochs=1000,
            restart=0,
            gpu=False):
        """Train `model` with triplet loss plus a variant-specific confidence loss.

        The total loss of each batch is the mean triplet loss plus the mean
        confidence loss selected by `self.variant`. Per-epoch average losses
        are written to tensorboard. Every 5 epochs, the following are logged:

        - histograms of intra-/inter-class distances;
        - histograms of embedding norms and triplet deltas;
        - the training-pair EER.

        Parameters
        ----------
        model : torch.nn.Module
            Embedding network. When its `batch_first` attribute is False, the
            first two axes of each batch are swapped before the forward pass.
        feature_extraction :
            Passed to `SpeechSegmentGenerator`.
        protocol :
            Protocol the training batches are drawn from.
        log_dir : str
            Directory for tensorboard logs, checkpoints and confidence
            batch-norm weights.
        subset : str, optional
            Protocol subset used for training. Defaults to 'train'.
        epochs : int, optional
            Last epoch to train (inclusive). Defaults to 1000.
        restart : int, optional
            When > 0, resume from the model weights and optimizer state saved
            at epoch `restart`. Defaults to 0 (train from scratch).
        gpu : bool, optional
            Run on CUDA. Defaults to False.
        """

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(log_dir=log_dir, restart=restart > 0)

        batch_generator = SpeechSegmentGenerator(feature_extraction,
                                                 per_label=self.per_label,
                                                 per_fold=self.per_fold,
                                                 duration=self.duration,
                                                 parallel=self.parallel)
        batches = batch_generator(protocol, subset=subset)
        # The first batch is drawn and discarded; this is kept so that the
        # sequence of training batches is unchanged.
        next(batches)

        batches_per_epoch = batch_generator.batches_per_epoch

        if restart > 0:
            weights_pt = checkpoint.WEIGHTS_PT.format(log_dir=log_dir,
                                                      epoch=restart)
            model.load_state_dict(torch.load(weights_pt))

        if gpu:
            model = model.cuda()

        model.internal = False

        # confidence batch-norm layers (if any) are optimized with the model
        parameters = list(model.parameters())
        parameters += self._init_confidence_bn(gpu=gpu)

        optimizer = Adam(parameters)
        if restart > 0:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(log_dir=log_dir,
                                                          epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                # make sure optimizer state tensors live on the GPU
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        epoch = restart if restart > 0 else -1
        while True:
            epoch += 1
            if epoch > epochs:
                break

            # detailed statistics are only gathered every 5 epochs
            log_stats = epoch % 5 == 0

            loss_avg, tloss_avg, closs_avg = 0., 0., 0.

            if log_stats:
                log_positive, log_negative = [], []
                log_delta, log_norm = [], []

            desc = 'Epoch #{0}'.format(epoch)
            for _ in tqdm(range(batches_per_epoch), desc=desc):

                # Zero the gradients of *all* optimized parameters: model
                # AND confidence batch-norm layers. model.zero_grad() alone
                # lets the batch-norm gradients accumulate across batches.
                optimizer.zero_grad()

                batch = next(batches)

                X = batch['X']
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))

                if gpu:
                    X = X.cuda()

                fX = model(X)

                # pre-compute pairwise distances
                distances = self.pdist(fX)

                # sample triplets
                triplets = getattr(self, 'batch_{0}'.format(self.sampling))
                anchors, positives, negatives = triplets(batch['y'], distances)

                # compute triplet loss
                tlosses, deltas, pos_index, neg_index = self.triplet_loss(
                    distances,
                    anchors,
                    positives,
                    negatives,
                    return_delta=True)
                tloss = torch.mean(tlosses)

                # compute confidence loss
                closses = self._confidence_losses(fX, distances, deltas,
                                                  anchors, positives,
                                                  negatives, pos_index,
                                                  neg_index)
                closs = torch.mean(closses)

                if log_stats:
                    # .cpu() returns the tensor itself when already on CPU
                    fX_npy = fX.data.cpu().numpy()
                    pdist_npy = distances.data.cpu().numpy()
                    delta_npy = deltas.data.cpu().numpy()

                    log_norm.append(np.linalg.norm(fX_npy, axis=1))

                    same_speaker = pdist(batch['y'].reshape((-1, 1)),
                                         metric='chebyshev') < 1
                    log_positive.append(pdist_npy[np.where(same_speaker)])
                    log_negative.append(pdist_npy[np.where(~same_speaker)])

                    log_delta.append(delta_npy)

                # log loss
                tloss_ = float(tloss.data.cpu().numpy())
                closs_ = float(closs.data.cpu().numpy())
                tloss_avg += tloss_
                closs_avg += closs_
                loss_avg += tloss_ + closs_

                loss = tloss + closs
                loss.backward()
                optimizer.step()

            tloss_avg /= batches_per_epoch
            writer.add_scalar('tloss', tloss_avg, global_step=epoch)

            closs_avg /= batches_per_epoch
            writer.add_scalar('closs', closs_avg, global_step=epoch)

            loss_avg /= batches_per_epoch
            writer.add_scalar('loss', loss_avg, global_step=epoch)

            if log_stats:

                log_positive = np.hstack(log_positive)
                writer.add_histogram('embedding/pairwise_distance/positive',
                                     log_positive,
                                     global_step=epoch,
                                     bins=np.linspace(0, np.pi, 50))

                log_negative = np.hstack(log_negative)
                writer.add_histogram('embedding/pairwise_distance/negative',
                                     log_negative,
                                     global_step=epoch,
                                     bins=np.linspace(0, np.pi, 50))

                # same/different experiment on training pairs
                _, _, _, eer = det_curve(
                    np.hstack([np.ones(len(log_positive)),
                               np.zeros(len(log_negative))]),
                    np.hstack([log_positive, log_negative]),
                    distances=True)
                writer.add_scalar('eer', eer, global_step=epoch)

                log_norm = np.hstack(log_norm)
                writer.add_histogram('norm',
                                     log_norm,
                                     global_step=epoch,
                                     bins='doane')

                log_delta = np.vstack(log_delta)
                writer.add_histogram('delta',
                                     log_delta,
                                     global_step=epoch,
                                     bins='doane')

            checkpoint.on_epoch_end(epoch, model, optimizer)

            if hasattr(self, 'norm_bn'):
                confidence_pt = self.CONFIDENCE_PT.format(log_dir=log_dir,
                                                          epoch=epoch)
                torch.save(self.norm_bn.state_dict(), confidence_pt)

    def _init_confidence_bn(self, gpu=False):
        """Create the batch-norm layers required by `self.variant`.

        Depending on the variant, this sets some of `self.norm_bn`,
        `self.positive_bn`, `self.negative_bn` and `self.delta_bn`. Each is
        a 1-feature `nn.BatchNorm1d`, moved to CUDA when `gpu` is True.

        Returns
        -------
        list
            Trainable parameters of the created layers. Only the affine
            `norm_bn` of variants 2-8 has any.
        """

        def _bn(affine):
            bn = nn.BatchNorm1d(1, eps=1e-5, momentum=0.1, affine=affine)
            return bn.cuda() if gpu else bn

        parameters = []

        if self.variant in [2, 3, 4, 5, 6, 7, 8]:
            # norm batch-normalization
            self.norm_bn = _bn(affine=True)
            parameters += list(self.norm_bn.parameters())

        if self.variant in [9]:
            # norm batch-normalization
            self.norm_bn = _bn(affine=False)
            parameters += list(self.norm_bn.parameters())

        if self.variant in [5, 6, 7]:
            self.positive_bn = _bn(affine=False)
            self.negative_bn = _bn(affine=False)
            parameters += list(self.positive_bn.parameters())
            parameters += list(self.negative_bn.parameters())

        if self.variant in [8, 9]:
            self.delta_bn = _bn(affine=False)
            parameters += list(self.delta_bn.parameters())

        return parameters

    def _confidence_losses(self, fX, distances, deltas, anchors, positives,
                           negatives, pos_index, neg_index):
        """Compute the per-triplet confidence losses of `self.variant`.

        Parameters
        ----------
        fX :
            Batch of embeddings. Norms are taken along dim 1, and rows are
            indexed by `anchors` / `positives` / `negatives`.
        distances :
            Pairwise distances returned by `self.pdist(fX)`.
        deltas :
            Raw triplet deltas d(a, p) - d(a, n) from `self.triplet_loss`.
        anchors, positives, negatives :
            Triplet indices into `fX`.
        pos_index, neg_index :
            Indices of anchor-positive / anchor-negative pairs in `distances`.

        Returns
        -------
        Confidence losses, one per triplet (averaged by the caller).

        Raises
        ------
        ValueError
            If `self.variant` is not one of 1 to 9.
        """

        variant = self.variant

        if variant == 1:

            # if d(a, p) < d(a, n) (i.e. good case)
            #   --> sign(delta) < 0
            #   --> loss decreases when norm increases.
            #       i.e. encourages longer anchor

            # if d(a, p) > d(a, n) (i.e. bad case)
            #   --> sign(delta) > 0
            #   --> loss increases when norm increases
            #       i.e. encourages shorter anchor
            return F.sigmoid(
                F.softsign(deltas) *
                torch.norm(fX[anchors], 2, 1, keepdim=True))

        elif variant == 2:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            norms_ = F.sigmoid(self.norm_bn(norms_))

            confidence = (norms_[anchors] + norms_[positives] +
                          norms_[negatives]) / 3
            # if |x| is average
            #    --> normalized |x| = 0
            #    --> confidence = 0.5

            # if |x| is bigger than average
            #    --> normalized |x| >> 0
            #    --> confidence = 1

            # if |x| is smaller than average
            #    --> normalized |x| << 0
            #    --> confidence = 0

            correctness = F.sigmoid(-deltas / np.pi * 6)
            # if d(a, p) = d(a, n) (i.e. uncertain case)
            #    --> correctness = 0.5

            # if d(a, p) - d(a, n) = -𝛑 (i.e. best possible case)
            #    --> correctness = 1

            # if d(a, p) - d(a, n) = +𝛑 (i.e. worst possible case)
            #    --> correctness = 0

            # small if (and only if) confidence & correctness agree
            return torch.abs(confidence - correctness)

        elif variant == 3:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            norms_ = F.sigmoid(self.norm_bn(norms_))
            # NOTE(review): product of the three normalized norms divided
            # by 3; intent (arithmetic vs geometric mean) is unclear.
            confidence = (norms_[anchors] * norms_[positives] *
                          norms_[negatives]) / 3

            correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
            # correctness = 0.5 at delta == -pi/4
            # correctness ~ 1 for delta == -pi
            # correctness -> 0 as delta grows above 0 (~0.18 at delta == 0)

            return torch.abs(confidence - correctness)

        elif variant == 4:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            norms_ = F.sigmoid(self.norm_bn(norms_))
            # geometric mean of the three normalized norms. Note: this must
            # be `** (1. / 3)`; `** 1 / 3` would parse as `(...) / 3`.
            confidence = (norms_[anchors] * norms_[positives] *
                          norms_[negatives]) ** (1. / 3)

            correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
            # correctness = 0.5 at delta == -pi/4
            # correctness ~ 1 for delta == -pi
            # correctness -> 0 as delta grows above 0 (~0.18 at delta == 0)

            # delta = pos - neg ... should be < 0

            return torch.abs(confidence - correctness)

        elif variant == 5:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            confidence = F.sigmoid(self.norm_bn(norms_))

            confidence_pos = .5 * (confidence[anchors] +
                                   confidence[positives])
            # low positive distance == high correctness
            correctness_pos = F.sigmoid(
                -self.positive_bn(distances[pos_index].view(-1, 1)))

            confidence_neg = .5 * (confidence[anchors] +
                                   confidence[negatives])
            # high negative distance == high correctness
            correctness_neg = F.sigmoid(
                self.negative_bn(distances[neg_index].view(-1, 1)))

            return .5 * (torch.abs(confidence_pos - correctness_pos) +
                         torch.abs(confidence_neg - correctness_neg))

        elif variant == 6:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            confidence = F.sigmoid(self.norm_bn(norms_))

            confidence_pos = .5 * (confidence[anchors] +
                                   confidence[positives])
            # low positive distance == high correctness
            correctness_pos = F.sigmoid(
                -self.positive_bn(distances[pos_index].view(-1, 1)))

            return torch.abs(confidence_pos - correctness_pos)

        elif variant == 7:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            confidence = F.sigmoid(self.norm_bn(norms_))

            confidence_neg = .5 * (confidence[anchors] +
                                   confidence[negatives])
            # high negative distance == high correctness
            correctness_neg = F.sigmoid(
                self.negative_bn(distances[neg_index].view(-1, 1)))

            return torch.abs(confidence_neg - correctness_neg)

        elif variant in [8, 9]:

            norms_ = torch.norm(fX, 2, 1, keepdim=True)
            norms_ = F.sigmoid(self.norm_bn(norms_))
            confidence = (norms_[anchors] * norms_[positives] *
                          norms_[negatives]) / 3

            correctness = F.sigmoid(-self.delta_bn(deltas))
            return torch.abs(confidence - correctness)

        raise ValueError(
            'Unknown confidence loss variant: {0!r}'.format(variant))
Exemple #21
0
    def validate(self,
                 protocol_name,
                 subset='development',
                 aggregate=False,
                 every=1,
                 start=0):
        """Continuously validate saved epochs by equal error rate (EER).

        This method runs forever:

        - it polls `self.train_dir_` for completed epochs;
        - it loads their Keras weights and embeds the validation set of
          `protocol_name`;
        - it computes the EER of pairwise same/different decisions;
        - it appends that EER to a text log and redraws the EER-vs-epoch
          plot (PNG and EPS).

        Epochs are visited from `start` in steps of `every`. The most recently
        completed epoch is validated as soon as it appears; older epochs are
        then back-filled.

        Parameters
        ----------
        protocol_name : str
            Protocol whose `subset` is used for validation.
        subset : str, optional
            Defaults to 'development'.
        aggregate : bool, optional
            When True, use the output of the 'internal' layer, sum it over
            the first two axes of each group of `n` sequences (as returned
            by `_validation_set_z`) and l2-normalize the result.
            Defaults to False (use `embedding.predict` directly).
        every : int, optional
            Validate one epoch every `every` epochs. Defaults to 1.
        start : int, optional
            First epoch to validate. Defaults to 0.
        """

        # prepare paths
        validate_dir = self.VALIDATE_DIR.format(train_dir=self.train_dir_,
                                                protocol=protocol_name)
        aggregate_tag = 'aggregate.' if aggregate else ''
        validate_txt = self.VALIDATE_TXT.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate=aggregate_tag)
        validate_png = self.VALIDATE_PNG.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate=aggregate_tag)
        validate_eps = self.VALIDATE_EPS.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate=aggregate_tag)

        # create validation directory
        mkdir_p(validate_dir)

        # Build validation set
        if aggregate:
            X, n, y = self._validation_set_z(protocol_name, subset=subset)
        else:
            X, y = self._validation_set_y(protocol_name, subset=subset)

        # equal error rates keyed by epoch (sorted), and epoch to process
        eers, epoch = SortedDict(), start

        desc_format = ('Best EER = {best_eer:.2f}% @ epoch #{best_epoch:d} ::'
                       ' EER = {eer:.2f}% @ epoch #{epoch:d} :')

        progress_bar = tqdm(unit='epoch')

        try:
            with open(validate_txt, mode='w') as fp:

                # watch and evaluate forever
                while True:

                    # last completed epochs
                    completed_epochs = self.get_epochs(self.train_dir_) - 1

                    if completed_epochs < epoch:
                        time.sleep(60)
                        continue

                    # if last completed epoch has already been processed
                    # go back to first epoch that hasn't been processed yet
                    process_epoch = epoch if completed_epochs in eers \
                        else completed_epochs

                    # do not validate this epoch if it has been done before...
                    if process_epoch == epoch and epoch in eers:
                        epoch += every
                        progress_bar.update(every)
                        continue

                    weights_h5 = LoggingCallback.WEIGHTS_H5.format(
                        log_dir=self.train_dir_, epoch=process_epoch)

                    # this is needed for corner case when training is started
                    # from an epoch > 0
                    if not isfile(weights_h5):
                        time.sleep(60)
                        continue

                    # sleep 5 seconds to let the checkpoint callback finish
                    time.sleep(5)

                    embedding = keras.models.load_model(
                        weights_h5, custom_objects=CUSTOM_OBJECTS,
                        compile=False)

                    if aggregate:

                        def embed(X):
                            func = K.function([
                                embedding.get_layer(name='input').input,
                                K.learning_phase()
                            ], [embedding.get_layer(name='internal').output])
                            return func([X, 0])[0]
                    else:
                        embed = embedding.predict

                    # embed all validation sequences
                    fX = embed(X)

                    if aggregate:
                        indices = np.hstack([[0], np.cumsum(n)])
                        fX = np.stack([
                            np.sum(np.sum(fX[i:j], axis=0), axis=0)
                            for i, j in pairwise(indices)
                        ])
                        fX = l2_normalize(fX)

                    # compute pairwise distances
                    y_pred = pdist(fX, metric=self.approach_.metric)
                    # compute pairwise groundtruth
                    y_true = pdist(y, metric='chebyshev') < 1
                    # estimate equal error rate
                    _, _, _, eer = det_curve(y_true, y_pred, distances=True)
                    eers[process_epoch] = eer

                    # save equal error rate to file
                    fp.write(
                        self.VALIDATE_TXT_TEMPLATE.format(epoch=process_epoch,
                                                          eer=eer))
                    fp.flush()

                    # Keep track of best epoch so far. The items are in
                    # ascending epoch order and min() keeps the first
                    # minimum, so ties resolve to the earliest epoch. This
                    # avoids SortedDict.iloc, which was removed in
                    # sortedcontainers 2.x.
                    best_epoch, best_eer = min(eers.items(),
                                               key=lambda item: item[1])

                    progress_bar.set_description(
                        desc_format.format(epoch=process_epoch,
                                           eer=100 * eer,
                                           best_epoch=best_epoch,
                                           best_eer=100 * best_eer))

                    # plot
                    validated_epochs = list(eers.keys())
                    fig = plt.figure()
                    plt.plot(validated_epochs, list(eers.values()), 'b')
                    plt.plot([best_epoch], [best_eer], 'bo')
                    plt.plot([validated_epochs[0], validated_epochs[-1]],
                             [best_eer, best_eer], 'k--')
                    plt.grid(True)
                    plt.xlabel('epoch')
                    plt.ylabel('EER on {subset}'.format(subset=subset))
                    TITLE = '{best_eer:.5g} @ epoch #{best_epoch:d}'
                    title = TITLE.format(best_eer=best_eer,
                                         best_epoch=best_epoch,
                                         subset=subset)
                    plt.title(title)
                    plt.tight_layout()
                    plt.savefig(validate_png, dpi=75)
                    plt.savefig(validate_eps)
                    plt.close(fig)

                    # go to next epoch
                    if epoch == process_epoch:
                        epoch += every
                        progress_bar.update(every)
                    else:
                        progress_bar.update(0)
        finally:
            # the loop above only exits through an exception (e.g. Ctrl-C)
            progress_bar.close()
    def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
            epochs=1000, restart=0, gpu=False):

        import tensorboardX
        writer = tensorboardX.SummaryWriter(log_dir=log_dir)

        checkpoint = Checkpoint(log_dir=log_dir,
                                      restart=restart > 0)

        batch_generator = SpeechSegmentGenerator(
            feature_extraction,
            per_label=self.per_label, per_fold=self.per_fold,
            duration=self.duration, parallel=self.parallel)
        batches = batch_generator(protocol, subset=subset)
        batch = next(batches)

        batches_per_epoch = batch_generator.batches_per_epoch

        if restart > 0:
            weights_pt = checkpoint.WEIGHTS_PT.format(
                log_dir=log_dir, epoch=restart)
            model.load_state_dict(torch.load(weights_pt))

        if gpu:
            model = model.cuda()

        model.internal = False

        parameters = list(model.parameters())

        if self.variant in [2, 3, 4, 5, 6, 7, 8]:

            # norm batch-normalization
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=True)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        if self.variant in [9]:
            # norm batch-normalization
            self.norm_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.norm_bn = self.norm_bn.cuda()
            parameters += list(self.norm_bn.parameters())

        if self.variant in [5, 6, 7]:
            self.positive_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            self.negative_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.positive_bn = self.positive_bn.cuda()
                self.negative_bn = self.negative_bn.cuda()
            parameters += list(self.positive_bn.parameters())
            parameters += list(self.negative_bn.parameters())

        if self.variant in [8, 9]:

            self.delta_bn = nn.BatchNorm1d(
                1, eps=1e-5, momentum=0.1, affine=False)
            if gpu:
                self.delta_bn = self.delta_bn.cuda()
            parameters += list(self.delta_bn.parameters())

        optimizer = Adam(parameters)
        if restart > 0:
            optimizer_pt = checkpoint.OPTIMIZER_PT.format(
                log_dir=log_dir, epoch=restart)
            optimizer.load_state_dict(torch.load(optimizer_pt))
            if gpu:
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()

        epoch = restart if restart > 0 else -1
        while True:
            epoch += 1
            if epoch > epochs:
                break

            loss_avg, tloss_avg, closs_avg = 0., 0., 0.

            if epoch % 5 == 0:
                log_positive = []
                log_negative = []
                log_delta = []
                log_norm = []

            desc = 'Epoch #{0}'.format(epoch)
            for i in tqdm(range(batches_per_epoch), desc=desc):

                model.zero_grad()

                batch = next(batches)

                X = batch['X']
                if not getattr(model, 'batch_first', True):
                    X = np.rollaxis(X, 0, 2)
                X = np.array(X, dtype=np.float32)
                X = Variable(torch.from_numpy(X))

                if gpu:
                    X = X.cuda()

                fX = model(X)

                # pre-compute pairwise distances
                distances = self.pdist(fX)

                # sample triplets
                triplets = getattr(self, 'batch_{0}'.format(self.sampling))
                anchors, positives, negatives = triplets(batch['y'], distances)

                # compute triplet loss
                tlosses, deltas, pos_index, neg_index  = self.triplet_loss(
                    distances, anchors, positives, negatives,
                    return_delta=True)

                tloss = torch.mean(tlosses)

                if self.variant == 1:

                    closses = F.sigmoid(
                        F.softsign(deltas) * torch.norm(fX[anchors], 2, 1, keepdim=True))

                    # if d(a, p) < d(a, n) (i.e. good case)
                    #   --> sign(delta) < 0
                    #   --> loss decreases when norm increases.
                    #       i.e. encourages longer anchor

                    # if d(a, p) > d(a, n) (i.e. bad case)
                    #   --> sign(delta) > 0
                    #   --> loss increases when norm increases
                    #       i.e. encourages shorter anchor

                elif self.variant == 2:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))

                    confidence = (norms_[anchors] + norms_[positives] + norms_[negatives]) / 3
                    # if |x| is average
                    #    --> normalized |x| = 0
                    #    --> confidence = 0.5

                    # if |x| is bigger than average
                    #    --> normalized |x| >> 0
                    #    --> confidence = 1

                    # if |x| is smaller than average
                    #    --> normalized |x| << 0
                    #    --> confidence = 0

                    correctness = F.sigmoid(-deltas / np.pi * 6)
                    # if d(a, p) = d(a, n) (i.e. uncertain case)
                    #    --> correctness = 0.5

                    # if d(a, p) - d(a, n) = -𝛑 (i.e. best possible case)
                    #    --> correctness = 1

                    # if d(a, p) - d(a, n) = +𝛑 (i.e. worst possible case)
                    #    --> correctness = 0

                    closses = torch.abs(confidence - correctness)
                    # small if (and only if) confidence & correctness agree

                elif self.variant == 3:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness = 0 for delta < 0

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 4:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) ** 1/3

                    correctness = F.sigmoid(-(deltas + np.pi / 4) / np.pi * 6)
                    # correctness = 0.5 at delta == -pi/4
                    # correctness = 1 for delta == -pi
                    # correctness = 0 for delta < 0

                    # delta = pos - neg ... should be < 0

                    closses = torch.abs(confidence - correctness)

                elif self.variant == 5:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = .5 * (torch.abs(confidence_pos - correctness_pos) \
                                  + torch.abs(confidence_neg - correctness_neg))

                elif self.variant == 6:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_pos = .5 * (confidence[anchors] + confidence[positives])
                    # low positive distance == high correctness
                    correctness_pos = F.sigmoid(
                        -self.positive_bn(distances[pos_index].view(-1, 1)))

                    closses = torch.abs(confidence_pos - correctness_pos)

                elif self.variant == 7:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    confidence = F.sigmoid(self.norm_bn(norms_))

                    confidence_neg = .5 * (confidence[anchors] + confidence[negatives])
                    # high negative distance == high correctness
                    correctness_neg = F.sigmoid(
                        self.negative_bn(distances[neg_index].view(-1, 1)))

                    closses = torch.abs(confidence_neg - correctness_neg)

                elif self.variant in [8, 9]:

                    norms_ = torch.norm(fX, 2, 1, keepdim=True)
                    norms_ = F.sigmoid(self.norm_bn(norms_))
                    confidence = (norms_[anchors] * norms_[positives] * norms_[negatives]) / 3

                    correctness = F.sigmoid(-self.delta_bn(deltas))
                    closses = torch.abs(confidence - correctness)

                closs = torch.mean(closses)

                if epoch % 5 == 0:

                    if gpu:
                        fX_npy = fX.data.cpu().numpy()
                        pdist_npy = distances.data.cpu().numpy()
                        delta_npy = deltas.data.cpu().numpy()
                    else:
                        fX_npy = fX.data.numpy()
                        pdist_npy = distances.data.numpy()
                        delta_npy = deltas.data.numpy()

                    log_norm.append(np.linalg.norm(fX_npy, axis=1))

                    same_speaker = pdist(batch['y'].reshape((-1, 1)), metric='chebyshev') < 1
                    log_positive.append(pdist_npy[np.where(same_speaker)])
                    log_negative.append(pdist_npy[np.where(~same_speaker)])

                    log_delta.append(delta_npy)

                # log loss
                if gpu:
                    tloss_ = float(tloss.data.cpu().numpy())
                    closs_ = float(closs.data.cpu().numpy())
                else:
                    tloss_ = float(tloss.data.numpy())
                    closs_ = float(closs.data.numpy())
                tloss_avg += tloss_
                closs_avg += closs_
                loss_avg += tloss_ + closs_

                loss = tloss + closs
                loss.backward()
                optimizer.step()

            tloss_avg /= batches_per_epoch
            writer.add_scalar('tloss', tloss_avg, global_step=epoch)

            closs_avg /= batches_per_epoch
            writer.add_scalar('closs', closs_avg, global_step=epoch)

            loss_avg /= batches_per_epoch
            writer.add_scalar('loss', loss_avg, global_step=epoch)

            if epoch % 5 == 0:

                log_positive = np.hstack(log_positive)
                writer.add_histogram(
                    'embedding/pairwise_distance/positive', log_positive,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))
                log_negative = np.hstack(log_negative)

                writer.add_histogram(
                    'embedding/pairwise_distance/negative', log_negative,
                    global_step=epoch, bins=np.linspace(0, np.pi, 50))

                _, _, _, eer = det_curve(
                    np.hstack([np.ones(len(log_positive)), np.zeros(len(log_negative))]),
                    np.hstack([log_positive, log_negative]), distances=True)
                writer.add_scalar('eer', eer, global_step=epoch)

                log_norm = np.hstack(log_norm)
                writer.add_histogram(
                    'norm', log_norm,
                    global_step=epoch, bins='doane')

                log_delta = np.vstack(log_delta)
                writer.add_histogram(
                    'delta', log_delta,
                    global_step=epoch, bins='doane')

            checkpoint.on_epoch_end(epoch, model, optimizer)

            if hasattr(self, 'norm_bn'):
                confidence_pt = self.CONFIDENCE_PT.format(
                    log_dir=log_dir, epoch=epoch)
                torch.save(self.norm_bn.state_dict(), confidence_pt)
    def _validate_epoch_verification(self, epoch, protocol_name,
                                     subset='development',
                                     validation_data=None):
        """Perform a speaker verification experiment using model at `epoch`

        Every enrolment of the protocol is embedded with the network saved
        at `epoch` and averaged into one model vector per `model_id`. Every
        trial is embedded the same way (re-using cached embeddings whenever
        the same file/segments were already processed) and scored by its
        `self.metric` distance to the target model. The equal error rate of
        these distances is returned.

        Parameters
        ----------
        epoch : int
            Epoch to validate.
        protocol_name : str
            Name of speaker verification protocol
        subset : {'train', 'development', 'test'}, optional
            Name of subset. Defaults to 'development'.
        validation_data : provided by `validate_init`
            Not used by this method.

        Returns
        -------
        metrics : dict
            {'EER': {'minimize': True, 'value': eer}}
        """

        # load current network (kept under its own name so that it is not
        # confused with the per-speaker embedding "models" computed below)
        network = self.load_model(epoch).to(self.device)
        network.eval()

        # use user-provided --duration when available
        # otherwise use 'duration' used for training
        if self.duration is None:
            duration = self.task_.duration
        else:
            duration = self.duration
        min_duration = None

        # if 'duration' is still None, it means that
        # network was trained with variable lengths
        if duration is None:
            duration = self.task_.max_duration
            min_duration = self.task_.min_duration

        # consecutive windows overlap by half their duration
        step = .5 * duration

        if isinstance(self.feature_extraction_, Precomputed):
            # NOTE(review): memory-mapping of precomputed features is turned
            # off here and never restored — confirm this is intended.
            self.feature_extraction_.use_memmap = False

        # initialize embedding extraction
        sequence_embedding = SequenceEmbedding(
            network, self.feature_extraction_, duration=duration,
            step=step, min_duration=min_duration,
            batch_size=self.batch_size, device=self.device)

        metrics = {}
        protocol = get_protocol(protocol_name, progress=False,
                                preprocessors=self.preprocessors_)

        # model_id -> averaged enrolment embedding (keepdims=True keeps the
        # leading axis so that it can be passed directly to `cdist`)
        enrolment_models = {}
        # (file identifier, segments) -> model_id
        # The tuple itself is used as key (rather than its hash) so that a
        # hash collision can never make two different inputs share a model.
        enrolment_keys = {}
        enrolments = getattr(protocol, '{0}_enrolment'.format(subset))()
        for enrolment in enrolments:
            data = sequence_embedding.apply(enrolment,
                                            crop=enrolment['enrol_with'])
            model_id = enrolment['model_id']
            enrolment_models[model_id] = np.mean(
                np.stack(data), axis=0, keepdims=True)

            # in some specific speaker verification protocols,
            # enrolment data may be used later as trial data.
            # therefore, we cache information about enrolment data
            # to speed things up by reusing the enrolment as trial
            key = (get_unique_identifier(enrolment),
                   tuple(enrolment['enrol_with']))
            enrolment_keys[key] = model_id

        # (file identifier, segments) -> averaged trial embedding
        trial_models = {}
        trials = getattr(protocol, '{0}_trial'.format(subset))()
        y_true, y_pred = [], []
        for trial in trials:
            model_id = trial['model_id']

            key = (get_unique_identifier(trial),
                   tuple(trial['try_with']))

            # re-use enrolment model whenever possible
            if key in enrolment_keys:
                trial_embedding = enrolment_models[enrolment_keys[key]]

            # re-use trial model whenever possible
            elif key in trial_models:
                trial_embedding = trial_models[key]

            else:
                data = sequence_embedding.apply(trial, crop=trial['try_with'])
                trial_embedding = np.mean(data, axis=0, keepdims=True)
                # cache trial model for later re-use
                trial_models[key] = trial_embedding

            distance = cdist(enrolment_models[model_id], trial_embedding,
                             metric=self.metric)[0, 0]
            y_pred.append(distance)
            y_true.append(trial['reference'])

        # `distances=True`: lower y_pred means "more likely same speaker"
        _, _, _, eer = det_curve(np.array(y_true), np.array(y_pred),
                                 distances=True)
        metrics['EER'] = {'minimize': True, 'value': eer}

        return metrics