def speaker_spotting_try_system4(current_trial):
    """Score a speaker spotting trial using online clustering.

    Embedding windows located inside ``current_trial['try_with']`` and
    fully included in reference speech are fed one at a time to a
    ``clustering.OnlineClustering`` instance. After each update, the
    window's score is ``2 - d``, where ``d`` is the smallest value returned
    by ``modelClusterDistance`` for the target model. The score is 0 while
    the clustering is still empty.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'`` (key into the
        module-level ``models``), ``'try_with'`` and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        One score per embedding window, from the first to the last window
        fully included in ``try_with``. Skipped windows (those not fully
        included in reference speech) keep a score of 0.
    """
    # target model: its identifier and its embedding
    model_id = current_trial['model_id']
    model = {'mid': model_id, 'embedding': models[model_id]}

    # where to look for this target
    try_with = current_trial['try_with']

    # precomputed embedding
    embeddings = precomputed(current_trial)

    # find index of first and last embedding fully included in 'try_with'
    indices = embeddings.sliding_window.crop(try_with, mode='strict')
    first, last = indices[0], indices[-1]

    # Indices of embeddings fully included in reference speech.
    # Kept as a set because membership is tested once per window below;
    # `in` on an array/list scans it (O(n) per test, O(n^2) overall).
    speech_timeline = REFERENCE[current_trial['uri']].crop(
        try_with).get_timeline().support()
    indices_speech = set(embeddings.sliding_window.crop(speech_timeline,
                                                        mode='strict'))

    onlineClustering = clustering.OnlineClustering(
        current_trial['uri'],
        cdist(embeddings.data, embeddings.data, metric='cosine'))
    start = embeddings.sliding_window[0].start
    data = np.zeros((len(embeddings.data), 1))
    for i, (window, _) in enumerate(embeddings):
        # skip windows before 'try_with' or outside reference speech
        if i < first or i not in indices_speech:
            start = window.end
            continue
        if i > last:
            break

        # signal between the end of the previous window and this window's end
        so_far = Segment(start, window.end)
        example = {}
        example['segment'] = so_far
        example['embedding'] = embeddings.crop(so_far, mode='center')
        example['indice'] = [i]
        example['distances'] = {}
        example['distances'][model['mid']] = list(
            cdist(example['embedding'], model['embedding'],
                  metric='cosine').flatten())

        # NOTE: `upadateCluster2` (sic) is the name exposed by `clustering`
        onlineClustering.upadateCluster2(example)
        score = 0.
        if not onlineClustering.empty():
            min_dist = min(onlineClustering.modelClusterDistance(model))
            score = max(score, 2 - min_dist)
        data[i] = score
        start = window.end

    data = data[first:last + 1]
    sliding_window = SlidingWindow(
        start=embeddings.sliding_window[first].start,
        duration=embeddings.sliding_window.duration,
        step=embeddings.sliding_window.step)

    return SlidingWindowFeature(data, sliding_window)
Beispiel #2
0
def speaker_spotting_try(current_trial):
    """Score a speaker spotting trial by direct comparison with the model.

    Each embedding window that lies between the first and last window
    fully included in ``current_trial['try_with']``, and that is fully
    included in reference speech, receives ``2 - cosine distance`` to the
    target model. Every other window receives 0.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'`` (key into the
        module-level ``models``), ``'try_with'`` and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        Scores from the first to the last window fully included in
        ``try_with``.
    """
    # target model
    model = models[current_trial['model_id']]
    # where to look for this target
    try_with = current_trial['try_with']

    # precomputed embedding
    embeddings = precomputed(current_trial)

    # find index of first and last embedding fully included in 'try_with'
    indices = embeddings.sliding_window.crop(try_with, mode='strict')
    first, last = indices[0], indices[-1]

    # Indices of embeddings fully included in reference speech. A set makes
    # the per-window membership test below O(1) instead of a linear scan.
    speech_timeline = REFERENCE[current_trial['uri']].crop(
        try_with).get_timeline().support()
    indices_speech = set(embeddings.sliding_window.crop(speech_timeline,
                                                        mode='strict'))

    # compare all embeddings to target model
    scores = 2. - cdist(embeddings.data, model, metric='cosine')

    data = np.zeros((len(embeddings.data), 1))
    for i, (window, _) in enumerate(embeddings):
        # make sure the current segment is in 'try_with' and in speech
        if i < first or i not in indices_speech:
            continue
        if i > last:
            break
        data[i] = scores[i]

    data = data[first:last + 1]
    sliding_window = SlidingWindow(
        start=embeddings.sliding_window[first].start,
        duration=embeddings.sliding_window.duration,
        step=embeddings.sliding_window.step)

    return SlidingWindowFeature(data, sliding_window)
Beispiel #3
0
def speaker_spotting_try(current_trial):
    """Score the speech windows of a trial against the target model.

    Windows returned by a centre-mode crop of the SAD timeline of the file
    get ``2 - mean cosine distance`` to the target model, averaged over the
    model's rows (the model is presumably one or more stacked embeddings;
    confirm against `models`). All other windows get 0.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'``, ``'try_with'``
        and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        Scores from the first to the last embedding window fully included
        in ``try_with``.
    """
    target = models[current_trial['model_id']]
    region = current_trial['try_with']
    embeddings = precomputed(current_trial)
    window = embeddings.sliding_window

    inside = window.crop(region, mode='strict')
    first, last = inside[0], inside[-1]

    speech = window.crop(SAD[current_trial['uri']], mode='center')

    # 2 minus the average cosine distance to the target model rows,
    # one value per embedding, shaped as a column
    similarity = 2. - cdist(embeddings.data, target,
                            metric='cosine').mean(axis=1, keepdims=True)

    # two extra zero rows beyond the last embedding, so that slicing up to
    # `last + 1` stays in bounds even when `last` overshoots slightly
    padded = np.zeros((len(embeddings.data) + 2, 1))
    speech = [k for k in speech if k < len(similarity)]
    padded[speech] = similarity[speech]

    return SlidingWindowFeature(
        padded[first:last + 1],
        SlidingWindow(start=window[first].start,
                      duration=window.duration,
                      step=window.step))
def speaker_spotting_try_segment(current_trial):
    """Score a speaker spotting trial restricted to SPEECH regions.

    Each embedding window that lies between the first and last window
    fully included in ``current_trial['try_with']``, and that is fully
    included in the ``SPEECH`` timeline of the file, receives
    ``2 - cosine distance`` to the target model. Every other window
    receives 0.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'`` (key into the
        module-level ``models``), ``'try_with'`` and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        Scores from the first to the last window fully included in
        ``try_with``.
    """
    # target model
    model = models[current_trial['model_id']]
    # where to look for this target
    try_with = current_trial['try_with']

    # precomputed embedding
    embeddings = precomputed(current_trial)

    # find index of first and last embedding fully included in 'try_with'
    indices = embeddings.sliding_window.crop(try_with, mode='strict')
    first, last = indices[0], indices[-1]

    # Indices of embeddings fully included in speech. A set makes the
    # per-window membership test below O(1) instead of a linear scan.
    speech_timeline = SPEECH[current_trial['uri']].crop(
        try_with).get_timeline().support()
    indices_speech = set(embeddings.sliding_window.crop(speech_timeline,
                                                        mode='strict'))

    # compare all embeddings to target model
    scores = 2. - cdist(embeddings.data, model, metric='cosine')

    data = np.zeros((len(embeddings.data), 1))
    for i, (window, _) in enumerate(embeddings):
        # make sure the current segment is in 'try_with' and in speech
        if i < first or i not in indices_speech:
            continue
        if i > last:
            break
        data[i] = scores[i]

    data = data[first:last + 1]
    sliding_window = SlidingWindow(
        start=embeddings.sliding_window[first].start,
        duration=embeddings.sliding_window.duration,
        step=embeddings.sliding_window.step)

    return SlidingWindowFeature(data, sliding_window)
Beispiel #5
0
    def apply(self, fX):
        """Cluster embeddings with HDBSCAN, then assign leftover points.

        Parameters
        ----------
        fX : (n_samples, dimension) np.ndarray
            Embeddings to cluster.

        Returns
        -------
        cluster_labels : (n_samples, ) np.ndarray
            Cluster label of each embedding. When HDBSCAN yields fewer than
            two clusters, every embedding gets label 0. Otherwise, points
            HDBSCAN left unassigned (label -1) are given the label of the
            closest cluster representative.
        """
        from hdbscan import HDBSCAN
        clusterer = HDBSCAN(min_cluster_size=self.min_cluster_size,
                            min_samples=self.min_samples,
                            metric='precomputed')
        distance_matrix = squareform(pdist(fX, metric=self.metric))

        # apply clustering (-1 marks points HDBSCAN did not assign)
        cluster_labels = clusterer.fit_predict(distance_matrix)

        # cluster embedding
        n_clusters = np.max(cluster_labels) + 1

        if n_clusters < 2:
            # builtin `int`: the `np.int` alias was deprecated in NumPy 1.20
            # and removed in 1.24 (AttributeError on recent NumPy)
            return np.zeros(fX.shape[0], dtype=int)

        # one representative per cluster: L2-normalized sum of its members
        fC = l2_normalize(
            np.vstack([np.sum(fX[cluster_labels == k, :], axis=0)
                       for k in range(n_clusters)]))

        # tag each undefined embedding to closest cluster
        undefined = cluster_labels == -1
        closest_cluster = np.argmin(
            cdist(fC, fX[undefined, :], metric=self.metric), axis=0)
        cluster_labels[undefined] = closest_cluster

        return cluster_labels
Beispiel #6
0
    def compute_similarity(self, cluster1, cluster2, parent=None):
        """Return the similarity between two clusters.

        Each of ``self[cluster1]`` and ``self[cluster2]`` is unpacked as a
        pair whose first element is L2-normalized. The similarity is the
        negated ``self.distance`` between the two normalized vectors.
        ``parent`` is accepted but not used.
        """
        first_vec, _ = self[cluster1]
        second_vec, _ = self[cluster2]

        first_unit = l2_normalize(first_vec)
        second_unit = l2_normalize(second_vec)

        distance = cdist([first_unit], [second_unit], metric=self.distance)
        return -distance[0, 0]
Beispiel #7
0
    def compute_similarities(self, cluster, clusters, parent=None):
        """Return similarities between `cluster` and each of `clusters`.

        The similarity is the negated ``self.distance`` between the
        L2-normalized first elements of ``self[...]``. ``parent`` is
        accepted but not used.

        Returns
        -------
        matrix : ValueSortedDict
            Symmetric mapping: both ``(cluster, other)`` and
            ``(other, cluster)`` map to the same similarity value.
        """
        anchor = self[cluster][0].reshape((1, -1))
        others = np.vstack([self[c][0] for c in clusters])

        # first (and only) row: similarity of the anchor to every other cluster
        row = -cdist(l2_normalize(anchor), l2_normalize(others),
                     metric=self.distance)[0]

        matrix = ValueSortedDict()
        for other, value in zip(clusters, row):
            matrix[cluster, other] = value
            matrix[other, cluster] = value
        return matrix
Beispiel #8
0
        def generator():
            """Endlessly yield training samples grouped by nearby classes.

            Loops forever over all class indices ``0 .. n_classes - 1`` in a
            freshly shuffled order. For each class ``iC``, it picks the
            ``self.per_fold`` classes with the smallest ``self.metric``
            distance to row ``iC`` of ``self.fC_``. From each picked class it
            yields ``self.per_label`` samples as ``{'X': h5_X[i], 'y': y[i]}``.

            ``n_classes``, ``class_generator``, ``h5_X`` and ``y`` come from
            the enclosing scope (not visible here). Presumably
            ``class_generator(jC)`` yields sample indices of class ``jC``
            (confirm against the enclosing method).
            """

            centers = np.arange(n_classes)
            # one index generator per class, addressed by class index
            class_generators = [class_generator(jC) for jC in centers]

            # class of the last yielded batch, used to avoid yielding the
            # same class twice in a row across iterations
            previous_label = None

            while True:

                # loop over all centers in random order
                np.random.shuffle(centers)
                for iC in centers:

                    try:
                        # distances from center iC to every center
                        distances = cdist(self.fC_[iC, np.newaxis],
                                          self.fC_,
                                          metric=self.metric)[0]
                    except AttributeError as e:
                        # when on_train_begin hasn't been called yet,
                        # attribute fC_ doesn't exist --> fake it with random
                        # distances, iC being the closest to itself
                        distances = np.random.rand(len(centers))
                        distances[iC] = 0.

                    # indices of the `per_fold` smallest distances; argpartition
                    # does not sort them. NOTE(review): argpartition raises
                    # if self.per_fold >= n_classes (kth out of bounds).
                    closest_centers = np.argpartition(
                        distances, self.per_fold)[:self.per_fold]

                    # corner case where last center of previous loop
                    # is the same as first center of current loop:
                    # rotate left by one so that center is yielded last
                    if closest_centers[0] == previous_label:
                        closest_centers[:-1] = closest_centers[1:]
                        closest_centers[-1] = previous_label

                    for jC in closest_centers:
                        for _ in range(self.per_label):
                            i = next(class_generators[jC])
                            yield {'X': h5_X[i], 'y': y[i]}
                        previous_label = jC
def speaker_recognition_xp(aggregation, protocol, subset='development',
                           distance='angular', threads=None):
    """Run a speaker recognition experiment on one subset of a protocol.

    Each enrolment and test recording is embedded once with
    ``aggregation.apply``, and its embeddings are summed over time. The
    summed vectors are L2-normalized. Every (enrolment, test) pair is
    compared with ``distance``.

    Parameters
    ----------
    aggregation : object
        Provides ``apply(item)``, returning an object with a ``data`` array.
    protocol : object
        Provides ``<subset>_enroll``, ``<subset>_test`` and ``<subset>_keys``.
    subset : str, optional
        Subset name, defaults to 'development'.
    distance : str, optional
        Metric passed to ``cdist``, defaults to 'angular'.
    threads : optional
        Not used.

    Returns
    -------
    The result of ``det_curve(y_true, y_pred, distances=True)``. Pairs with
    key 1 are positives, pairs with key -1 are negatives, and pairs with
    key 0 are ignored.
    """
    enroll = getattr(protocol, '{subset}_enroll'.format(subset=subset))(
        yield_name=True)
    test = getattr(protocol, '{subset}_test'.format(subset=subset))(
        yield_name=True)

    # TODO parallelize using multiprocessing
    # recordings listed in both enroll and test are embedded only once
    fX = {}
    for name, item in itertools.chain(enroll, test):
        if name not in fX:
            fX[name] = np.sum(aggregation.apply(item).data, axis=0)

    keys = getattr(protocol, '{subset}_keys'.format(subset=subset))()

    # rows follow keys.index (enrolments), columns follow keys (tests)
    enroll_fX = l2_normalize(np.vstack([fX[name] for name in keys.index]))
    test_fX = l2_normalize(np.vstack([fX[name] for name in keys]))

    # compare all possible (enroll, test) pairs at once
    D = cdist(enroll_fX, test_fX, metric=distance)

    by_key = {label: D[np.where(keys == label)] for label in (1, -1)}
    y_pred = np.hstack([by_key[1], by_key[-1]])
    y_true = np.hstack([np.ones(by_key[1].shape[0]),
                        np.zeros(by_key[-1].shape[0])])

    return det_curve(y_true, y_pred, distances=True)
Beispiel #10
0
def speaker_spotting_try_diarization(current_trial):
    """Speaker spotting system based on the oracle clustering system.

    Reference (oracle) speaker turns inside ``try_with`` are pushed one by
    one into ``clustering.OnlineOracleClustering``. Once the clustering is
    non-empty, each embedding window is scored ``2 - d``, where ``d`` is the
    smallest value returned by ``modelClusterDistance`` for the target
    model. Before that, the score is 0.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'``, ``'try_with'``
        and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        Scores of the windows from the first to the last embedding fully
        included in ``try_with``.
    """
    # target model: its id together with its embedding vector
    target_id = current_trial['model_id']
    target_embedding = models[current_trial['model_id']]
    model = {'mid': target_id, 'embedding': target_embedding}

    try_with = current_trial['try_with']
    embeddings = precomputed(current_trial)
    uri = current_trial['uri']

    # oracle speaker turns restricted to the region of interest
    oracle_diarization = REFERENCE[uri].crop(current_trial['try_with'])

    sw = embeddings.sliding_window
    strict = sw.crop(try_with, mode='strict')
    first, last = strict[0], strict[-1]

    oracle_clustering = clustering.OnlineOracleClustering(uri)
    segment_start = sw[0].start
    data = np.zeros((len(embeddings.data), 1))

    for i, (window, _) in enumerate(embeddings):
        if i < first:
            segment_start = window.end
            continue
        if i > last:
            break

        # oracle turns between the previous window's end and this one's end
        turns = oracle_diarization.crop(Segment(segment_start, window.end))
        for segment, _, label in turns.itertracks(label=True):
            segment_embedding = embeddings.crop(segment, mode='center')
            to_model = cdist(segment_embedding, model['embedding'],
                             metric='cosine')
            oracle_clustering.upadateCluster({
                'label': label,
                'segment': segment,
                'embedding': segment_embedding,
                'indice': [i],
                'distances': {model['mid']: list(to_model.flatten())},
            })

        if oracle_clustering.empty():
            data[i] = 0.
        else:
            closest = min(oracle_clustering.modelClusterDistance(model))
            data[i] = max(0., 2 - closest)
        segment_start = window.end

    return SlidingWindowFeature(
        data[first:last + 1],
        SlidingWindow(start=sw[first].start,
                      duration=sw.duration,
                      step=sw.step))
Beispiel #11
0
    def on_epoch_end(self, epoch, logs=None):
        """Evaluate speaker verification EER at the end of an epoch.

        Embeds every enrolment and test recording of ``self.subset`` with the
        current model, then scores (enrolment, test) pairs whose protocol key
        is 1 (target) or -1 (non-target); pairs with key 0 are skipped. The
        resulting equal error rate is appended to
        ``<log_dir>/<subset>.eer.txt``, and ``<log_dir>/<subset>.eer.png``
        (EER as a function of epoch) is re-drawn.

        Parameters
        ----------
        epoch : int
            Index of the epoch that just ended.
        logs : dict, optional
            Not used. Defaults to ``None`` instead of a shared mutable ``{}``.
        """

        # keep track of current time
        now = datetime.datetime.now().isoformat()
        prefix = self.log_dir + '/{subset}.plot.{epoch:04d}'.format(
            epoch=epoch, subset=self.subset)

        from pyannote.audio.embedding.base import SequenceEmbedding
        sequence_embedding = SequenceEmbedding()
        sequence_embedding.embedding_ = self.glue.extract_embedding(self.model)

        from pyannote.audio.embedding.aggregation import \
            SequenceEmbeddingAggregation
        aggregation = SequenceEmbeddingAggregation(
            sequence_embedding,
            self.glue.feature_extractor,
            duration=self.glue.duration,
            min_duration=self.glue.min_duration,
            step=self.glue.step,
            internal=-2)

        # TODO / pass internal as parameter
        aggregation.cache_preprocessed_ = False

        # embed enroll and test recordings

        method = '{subset}_enroll'.format(subset=self.subset)
        enroll = getattr(self.protocol, method)(yield_name=True)

        method = '{subset}_test'.format(subset=self.subset)
        test = getattr(self.protocol, method)(yield_name=True)

        # one summed embedding per recording; duplicates embedded once
        fX = {}
        for name, item in itertools.chain(enroll, test):
            if name in fX:
                continue
            embeddings = aggregation.apply(item)
            fX[name] = np.sum(embeddings.data, axis=0)

        # perform trials

        method = '{subset}_keys'.format(subset=self.subset)
        keys = getattr(self.protocol, method)()

        # rows follow keys.index (enrolments), columns follow keys (tests)
        enroll_fX = l2_normalize(np.vstack([fX[name] for name in keys.index]))
        test_fX = l2_normalize(np.vstack([fX[name] for name in keys]))

        D = cdist(enroll_fX, test_fX, metric=self.glue.distance)

        # key 1 -> target (1), key -1 -> non-target (0), key 0 -> skipped
        y_true = []
        y_pred = []
        key_mapping = {0: None, -1: 0, 1: 1}
        for i, _ in enumerate(keys.index):
            for j, _ in enumerate(keys):
                y = key_mapping[keys.iloc[i, j]]
                if y is None:
                    continue

                y_true.append(y)
                y_pred.append(D[i, j])

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        # plot DET curve once every 20 epochs (and 10 first epochs)
        if (epoch < 10) or (epoch % 20 == 0):
            eer = plot_det_curve(y_true,
                                 y_pred,
                                 prefix,
                                 distances=True,
                                 dpi=75)
        else:
            _, _, _, eer = det_curve(y_true, y_pred, distances=True)

        # store equal error rate in file (truncated on epoch 0)
        mode = 'a' if epoch else 'w'
        path = self.log_dir + '/{subset}.eer.txt'.format(subset=self.subset)
        with open(path, mode=mode) as fp:
            fp.write(self.EER_TEMPLATE_.format(epoch=epoch, eer=eer, now=now))
            fp.flush()

        # plot eer = f(epoch)
        # NOTE(review): best_epoch is a position in self.eer_, which equals the
        # epoch number only if this callback ran for every epoch from 0.
        self.eer_.append(eer)
        best_epoch = np.argmin(self.eer_)
        best_value = np.min(self.eer_)
        fig = plt.figure()
        plt.plot(self.eer_, 'b')
        plt.plot([best_epoch], [best_value], 'bo')
        plt.plot([0, epoch], [best_value, best_value], 'k--')
        plt.grid(True)
        plt.xlabel('epoch')
        plt.ylabel('EER on {subset}'.format(subset=self.subset))
        TITLE = 'EER = {best_value:.5g} on {subset} @ epoch #{best_epoch:d}'
        title = TITLE.format(best_value=best_value,
                             best_epoch=best_epoch,
                             subset=self.subset)
        plt.title(title)
        plt.tight_layout()
        path = self.log_dir + '/{subset}.eer.png'.format(subset=self.subset)
        plt.savefig(path, dpi=75)
        plt.close(fig)
def speaker_spotting_try_system2(current_trial):
    """Speaker spotting system based on the oracle clustering system.

    Reference (oracle) speaker turns inside ``try_with`` are fed
    incrementally to ``clustering.OnlineOracleClustering``. Each embedding
    window is then scored ``2 - d``, where ``d`` is the smallest value
    returned by ``modelDistance`` for the target model. The score is 0
    while no cluster exists.

    Parameters
    ----------
    current_trial : dict
        Trial description. Keys read here: ``'model_id'``, ``'try_with'``
        and ``'uri'``.

    Returns
    -------
    SlidingWindowFeature
        Scores of the windows from the first to the last embedding fully
        included in ``try_with``.
    """
    model_id = current_trial['model_id']
    model = dict(mid=model_id, embedding=models[current_trial['model_id']])

    try_with = current_trial['try_with']
    embeddings = precomputed(current_trial)

    # oracle annotation of the region of interest
    reference = REFERENCE[current_trial['uri']].crop(current_trial['try_with'])

    strict = embeddings.sliding_window.crop(try_with, mode='strict')
    first, last = strict[0], strict[-1]
    oracle = clustering.OnlineOracleClustering(current_trial['uri'])

    previous_end = embeddings.sliding_window[0].start
    data = np.zeros((len(embeddings.data), 1))
    for i, (window, _) in enumerate(embeddings):
        if i < first:
            previous_end = window.end
            continue
        if i > last:
            break

        portion = reference.crop(Segment(previous_end, window.end))
        for turn, _, speaker in portion.itertracks(label=True):
            turn_embedding = embeddings.crop(turn, mode='center')
            example = dict(label=speaker, segment=turn,
                           embedding=turn_embedding, indice=[i])
            example['distances'] = {
                model['mid']: list(cdist(turn_embedding, model['embedding'],
                                         metric='cosine').flatten())}
            oracle.upadateCluster(example)

        score = 0.
        if not oracle.empty():
            score = max(score, 2 - min(oracle.modelDistance(model)))
        data[i] = score
        previous_end = window.end

    sw = embeddings.sliding_window
    return SlidingWindowFeature(
        data[first:last + 1],
        SlidingWindow(start=sw[first].start,
                      duration=sw.duration,
                      step=sw.step))
    def _validate_epoch_verification(self,
                                     epoch,
                                     protocol_name,
                                     subset='development',
                                     validation_data=None):
        """Perform a speaker verification experiment using model at `epoch`

        Each enrolment is embedded and averaged into one model vector. Each
        trial is embedded and averaged too, re-using an enrolment or an
        earlier trial vector when the same (file, region) pair was already
        processed. Each trial vector is then compared to its enrolment model
        with ``self.metric``.

        Parameters
        ----------
        epoch : int
            Epoch to validate.
        protocol_name : str
            Name of speaker verification protocol
        subset : {'train', 'development', 'test'}, optional
            Name of subset.
        validation_data : provided by `validate_init`
            Not used by this method.

        Returns
        -------
        metrics : dict
            ``{'EER': {'minimize': True, 'value': eer}}``.
        """
        network = self.load_model(epoch).to(self.device)
        network.eval()

        # user-provided --duration first, then the training duration;
        # None means the network was trained with variable lengths
        duration = (self.task_.duration if self.duration is None
                    else self.duration)
        min_duration = None
        if duration is None:
            duration = self.task_.max_duration
            min_duration = self.task_.min_duration

        if isinstance(self.feature_extraction_, Precomputed):
            self.feature_extraction_.use_memmap = False

        # consecutive windows overlap by half their duration
        sequence_embedding = SequenceEmbedding(
            network, self.feature_extraction_,
            duration=duration, step=.5 * duration,
            min_duration=min_duration,
            batch_size=self.batch_size, device=self.device)

        protocol = get_protocol(protocol_name, progress=False,
                                preprocessors=self.preprocessors_)

        # enrolment models, plus hash((file, region)) -> model id, so that a
        # trial on an already-enrolled region can reuse its model vector
        enrolment_models = {}
        enrolment_khashes = {}
        for enrolment in getattr(protocol, '{0}_enrolment'.format(subset))():
            region = enrolment['enrol_with']
            embedded = sequence_embedding.apply(enrolment, crop=region)
            model_id = enrolment['model_id']
            enrolment_models[model_id] = np.mean(np.stack(embedded), axis=0,
                                                 keepdims=True)
            key = hash((get_unique_identifier(enrolment), tuple(region)))
            enrolment_khashes[key] = model_id

        trial_models = {}
        y_true, y_pred = [], []
        for trial in getattr(protocol, '{0}_trial'.format(subset))():
            model_id = trial['model_id']
            key = hash((get_unique_identifier(trial),
                        tuple(trial['try_with'])))

            if key in enrolment_khashes:
                trial_model = enrolment_models[enrolment_khashes[key]]
            elif key in trial_models:
                trial_model = trial_models[key]
            else:
                embedded = sequence_embedding.apply(trial,
                                                    crop=trial['try_with'])
                trial_model = trial_models[key] = np.mean(embedded, axis=0,
                                                          keepdims=True)

            y_pred.append(cdist(enrolment_models[model_id], trial_model,
                                metric=self.metric)[0, 0])
            y_true.append(trial['reference'])

        _, _, _, eer = det_curve(np.array(y_true), np.array(y_pred),
                                 distances=True)
        return {'EER': {'minimize': True, 'value': eer}}
Beispiel #14
0
    def cluster_(self, fX):
        """Compute complete dendrogram

        Agglomerative clustering: starting from one cluster per item, the
        closest pair of clusters (according to ``self.metric``) is merged at
        each step. A merged cluster is represented by the sum of its
        members' embeddings.

        Parameters
        ----------
        fX : (n_items, dimension) np.array
            Embeddings. A copy is made; the caller's array is not modified.

        Returns
        -------
        dendrogram : list of (i, j, distance) tuples
            Dendrogram, one entry per merge in merge order: cluster ``j``
            is merged into cluster ``i`` (``i < j``), the two being at
            ``distance`` at that time. Empty when there are fewer than two
            items.
        """

        # Work on a private copy. Representatives are summed and rows are
        # shifted in place below; without a copy, the shift (done before
        # rebinding `fX`) would overwrite the caller's embeddings.
        fX = np.array(fX)
        N = len(fX)
        if N < 2:
            return []

        # clusters contain the identifier of each cluster; row k of fX is
        # the representative of the k-th identifier of `clusters`
        clusters = SortedSet(np.arange(N))

        # pairwise distances keyed by (c_i, c_j) with c_i < c_j, sorted by
        # value so that the closest pair is found at index 0
        squared = squareform(pdist(fX, metric=self.metric))
        distances = ValueSortedDict()
        for i, j in itertools.combinations(range(N), 2):
            distances[i, j] = squared[i, j]

        dendrogram = []

        for _ in range(N - 1):

            # find most similar clusters
            (c_i, c_j), d = distances.peekitem(index=0)

            # keep track of this iteration
            dendrogram.append((c_i, c_j, d))

            # index of clusters in 'clusters' and 'fX'
            # (i < j because c_i < c_j and `clusters` is sorted)
            i = clusters.index(c_i)
            j = clusters.index(c_j)

            # update c_i representative
            fX[i] += fX[j]

            # remove c_j cluster (row i is unaffected since i < j)
            fX[j:-1, :] = fX[j + 1:, :]
            fX = fX[:-1]

            # remove distances to c_j cluster
            for c in clusters[:j]:
                distances.pop((c, c_j))
            for c in clusters[j + 1:]:
                distances.pop((c_j, c))

            clusters.remove(c_j)

            if len(clusters) < 2:
                continue

            # compute distance to new c_i cluster
            new_d = cdist(fX[i, :].reshape((1, -1)), fX,
                          metric=self.metric).squeeze()
            for c_k, d in zip(clusters, new_d):

                if c_k < c_i:
                    distances[c_k, c_i] = d
                elif c_k > c_i:
                    distances[c_i, c_k] = d

        return dendrogram
    def _validate_epoch_verification(self, epoch, protocol_name,
                                     subset='development',
                                     validation_data=None):
        """Perform a speaker verification experiment using model at `epoch`

        Enrolment and trial recordings are embedded with the network saved
        at `epoch`, and their embeddings are averaged into one vector each.
        Every trial vector is compared to its enrolment model with
        ``self.metric``. The equal error rate over all trials is returned.

        Parameters
        ----------
        epoch : int
            Epoch to validate.
        protocol_name : str
            Name of speaker verification protocol
        subset : {'train', 'development', 'test'}, optional
            Name of subset.
        validation_data : provided by `validate_init`
            Not used by this method.

        Returns
        -------
        metrics : dict
            ``{'EER': {'minimize': True, 'value': eer}}``.
        """
        net = self.load_model(epoch).to(self.device)
        net.eval()

        # --duration given by the user wins over the training duration;
        # when both are None, the network was trained with variable lengths
        min_duration = None
        if self.duration is not None:
            duration = self.duration
        else:
            duration = self.task_.duration
        if duration is None:
            duration, min_duration = (self.task_.max_duration,
                                      self.task_.min_duration)

        if isinstance(self.feature_extraction_, Precomputed):
            self.feature_extraction_.use_memmap = False

        half_duration = .5 * duration
        sequence_embedding = SequenceEmbedding(net,
                                               self.feature_extraction_,
                                               duration=duration,
                                               step=half_duration,
                                               min_duration=min_duration,
                                               batch_size=self.batch_size,
                                               device=self.device)

        protocol = get_protocol(protocol_name,
                                progress=False,
                                preprocessors=self.preprocessors_)

        def file_region_hash(item, region):
            # identifies a (file, region) pair across enrolments and trials
            return hash((get_unique_identifier(item), tuple(region)))

        enrolment_models = {}
        enrolment_khashes = {}
        enrolments = getattr(protocol, '{0}_enrolment'.format(subset))()
        for enrolment in enrolments:
            frames = sequence_embedding.apply(enrolment,
                                              crop=enrolment['enrol_with'])
            speaker = enrolment['model_id']
            enrolment_models[speaker] = np.mean(np.stack(frames), axis=0,
                                                keepdims=True)
            h = file_region_hash(enrolment, enrolment['enrol_with'])
            enrolment_khashes[h] = speaker

        trial_models = {}
        trials = getattr(protocol, '{0}_trial'.format(subset))()
        y_true = []
        y_pred = []
        for trial in trials:
            speaker = trial['model_id']
            h = file_region_hash(trial, trial['try_with'])

            # prefer an enrolment vector, then a cached trial vector,
            # and only embed the trial when neither exists
            if h in enrolment_khashes:
                vector = enrolment_models[enrolment_khashes[h]]
            elif h in trial_models:
                vector = trial_models[h]
            else:
                frames = sequence_embedding.apply(trial,
                                                  crop=trial['try_with'])
                vector = np.mean(frames, axis=0, keepdims=True)
                trial_models[h] = vector

            score = cdist(enrolment_models[speaker], vector,
                          metric=self.metric)[0, 0]
            y_pred.append(score)
            y_true.append(trial['reference'])

        eer = det_curve(np.array(y_true), np.array(y_pred),
                        distances=True)[3]
        metrics = {}
        metrics['EER'] = {'minimize': True, 'value': eer}
        return metrics