def objective_function(parameters, beta=1.0):
    """Evaluate the embedding saved at a given epoch; lower is better.

    NOTE(review): relies on names from an enclosing scope (WEIGHTS_H5,
    architecture_yml, X, y, batch_size, distance, alphas) -- presumably a
    closure defined inside a tuning routine; confirm against the caller.

    Parameters
    ----------
    parameters : sequence
        parameters[0] is the epoch whose saved weights are evaluated.
    beta : float, optional
        Beta parameter of the F-measure. Defaults to 1.

    Returns
    -------
    float
        1 - F-measure at the best threshold (minimized by the optimizer).
        Side effect: stores the best threshold in alphas[epoch].
    """
    epoch = parameters[0]
    weights_h5 = WEIGHTS_H5.format(epoch=epoch)
    sequence_embedding = SequenceEmbedding.from_disk(
        architecture_yml, weights_h5)
    fX = sequence_embedding.transform(X, batch_size=batch_size)
    # compute distance between every pair of sequences
    y_distance = pdist(fX, metric=distance)
    # compute same/different groundtruth
    y_true = pdist(y, metric='chebyshev') < 1
    # false positive / true positive
    # distances are negated so that roc_curve sees higher = more similar
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(
        y_true, -y_distance, pos_label=True, drop_intermediate=True)
    fnr = 1. - tpr
    far = fpr
    # negate thresholds back into the (positive) distance domain
    thresholds = -thresholds
    fscore = 1. - f_measure(1. - fnr, 1. - far, beta=beta)
    # nanargmin: f_measure may produce NaN at degenerate operating points
    i = np.nanargmin(fscore)
    alphas[epoch] = float(thresholds[i])
    return fscore[i]
def speaker_diarization_xp(sequence_embedding, X, y, distance='angular'):
    """Run one speaker diarization experiment and return its DET curve.

    Parameters
    ----------
    sequence_embedding : embedding with a `transform` method
    X : sequences to embed
    y : labels (one per sequence)
    distance : str, optional
        Metric used to compare embeddings. Defaults to 'angular'.

    Returns
    -------
    DET curve, as returned by `det_curve`.
    """
    embeddings = sequence_embedding.transform(X)
    # pairwise distance between every two embedded sequences
    predicted = pdist(embeddings, metric=distance)
    # two sequences are "same" iff their labels are identical
    actual = pdist(y, metric='chebyshev') < 1
    return det_curve(actual, predicted, distances=True)
def _validate_init_turn(self, protocol_name, subset='development'):
    """Build a fixed validation set at speech-turn granularity.

    Returns
    -------
    dict
        'X': stacked sub-segment sequences,
        'y': condensed same/different groundtruth (one entry per pair
             of speech turns),
        'z': speech-turn index of each sub-segment.
    """
    # fixed seed so every validation run draws the same batch
    np.random.seed(1337)
    protocol = get_protocol(protocol_name, progress=False,
                            preprocessors=self.preprocessors_)
    batch_generator = SpeechTurnSubSegmentGenerator(
        self.feature_extraction_, self.duration,
        per_label=10, per_turn=5)
    batch = next(batch_generator(protocol, subset=subset))
    X = np.stack(batch['X'])
    y = np.stack(batch['y'])
    z = np.stack(batch['z'])
    # get list of labels from list of repeated labels:
    # z 0 0 0 1 1 1 2 2 2 2 3 3 3 3
    # y A A A A A A B B B B B B B B
    # becomes
    # z 0 0 0 1 1 1 2 2 2 2 3 3 3 3
    # y A B
    # NOTE: groupby requires consecutive equal z values, which the
    # generator's ordering provides (see diagram above).
    yz = np.vstack([y, z]).T
    y = []
    for _, yz_ in itertools.groupby(yz, lambda t: t[1]):
        yz_ = np.stack(yz_)
        # all rows in a group share the same label; keep the first
        y.append(yz_[0, 0])
    y = np.array(y).reshape((-1, 1))
    # precompute same/different groundtruth
    y = pdist(y, metric='equal')
    return {'X': X, 'y': y, 'z': z}
def apply(self, fX):
    """Cluster embeddings with HDBSCAN.

    Parameters
    ----------
    fX : (n_items, dimension) np.ndarray
        Embeddings.

    Returns
    -------
    cluster_labels : (n_items, ) np.ndarray
        0-based cluster index of each embedding. All zeros when HDBSCAN
        finds fewer than 2 clusters. Items HDBSCAN leaves unclustered
        (-1) are re-assigned to the closest cluster centroid.
    """
    from hdbscan import HDBSCAN
    clusterer = HDBSCAN(min_cluster_size=self.min_cluster_size,
                        min_samples=self.min_samples,
                        metric='precomputed')
    distance_matrix = squareform(pdist(fX, metric=self.metric))

    # apply clustering
    cluster_labels = clusterer.fit_predict(distance_matrix)

    # cluster embedding
    n_clusters = np.max(cluster_labels) + 1

    # FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `int` is the documented replacement.
    if n_clusters < 2:
        return np.zeros(fX.shape[0], dtype=int)

    # one (l2-normalized) representative per cluster, obtained by
    # summing the embeddings of its members
    # FIXME don't l2_normalize for any metric
    fC = l2_normalize(
        np.vstack([np.sum(fX[cluster_labels == k, :], axis=0)
                   for k in range(n_clusters)]))

    # tag each undefined embedding to closest cluster
    undefined = cluster_labels == -1
    closest_cluster = np.argmin(
        cdist(fC, fX[undefined, :], metric=self.metric), axis=0)
    cluster_labels[undefined] = closest_cluster

    return cluster_labels
def on_epoch_end(self, epoch, logs=None):
    """Keras callback: evaluate EER at the end of every epoch.

    FIX: the default of `logs` used to be the mutable literal `{}` --
    a shared-across-calls anti-pattern; `None` is backward compatible
    because the body never reads `logs` (Keras passes it explicitly).

    Side effects: appends to `{subset}.eer.txt`, periodically saves
    distribution / DET-curve plots, and rewrites `{subset}.eer.png`.
    """
    # keep track of current time
    now = datetime.datetime.now().isoformat()

    # embed validation sequences with the current model weights
    embedding = self.extract_embedding(self.model)
    fX = embedding.predict(self.X_)
    distance = pdist(fX, metric=self.distance)

    prefix = self.log_dir + '/{subset}.plot.{epoch:04d}'.format(
        subset=self.subset, epoch=epoch)

    # plot distance distribution every 20 epochs (and 10 first epochs)
    xlim = get_range(metric=self.distance)
    if (epoch < 10) or (epoch % 20 == 0):
        plot_distributions(self.y_, distance, prefix,
                           xlim=xlim, ymax=3, nbins=100, dpi=75)

    # plot DET curve once every 20 epochs (and 10 first epochs);
    # otherwise just compute the EER without plotting
    if (epoch < 10) or (epoch % 20 == 0):
        eer = plot_det_curve(self.y_, distance, prefix,
                             distances=True, dpi=75)
    else:
        _, _, _, eer = det_curve(self.y_, distance, distances=True)

    # store equal error rate in file (truncate on very first epoch)
    mode = 'a' if epoch else 'w'
    path = self.log_dir + '/{subset}.eer.txt'.format(subset=self.subset)
    with open(path, mode=mode) as fp:
        fp.write(self.EER_TEMPLATE_.format(epoch=epoch, eer=eer, now=now))
        fp.flush()

    # plot eer = f(epoch)
    self.eer_.append(eer)
    best_epoch = np.argmin(self.eer_)
    best_value = np.min(self.eer_)
    fig = plt.figure()
    plt.plot(self.eer_, 'b')
    plt.plot([best_epoch], [best_value], 'bo')
    plt.plot([0, epoch], [best_value, best_value], 'k--')
    plt.grid(True)
    plt.xlabel('epoch')
    plt.ylabel('EER on {subset}'.format(subset=self.subset))
    TITLE = 'EER = {best_value:.5g} on {subset} @ epoch #{best_epoch:d}'
    title = TITLE.format(best_value=best_value,
                         best_epoch=best_epoch,
                         subset=self.subset)
    plt.title(title)
    plt.tight_layout()
    path = self.log_dir + '/{subset}.eer.png'.format(subset=self.subset)
    plt.savefig(path, dpi=75)
    plt.close(fig)
def __init__(self, glue, protocol, subset, log_dir):
    """Prepare a fixed validation set for speaker diarization.

    Embedding sequences are drawn once (seeded) so that every epoch is
    evaluated on exactly the same data. Sets `self.X_` (sequences) and
    `self.y_` (condensed same/different groundtruth).
    """
    super(SpeakerDiarizationValidation, self).__init__()
    self.subset = subset
    self.distance = glue.distance
    self.extract_embedding = glue.extract_embedding
    self.log_dir = log_dir
    # fixed seed so the per-label subsampling below is reproducible
    np.random.seed(1337)
    # initialize fixed duration sequence generator
    if glue.min_duration is None:
        # initialize fixed duration sequence generator
        generator = FixedDurationSequences(glue.feature_extractor,
                                           duration=glue.duration,
                                           step=glue.step,
                                           batch_size=-1)
    else:
        # initialize variable duration sequence generator
        generator = VariableDurationSequences(
            glue.feature_extractor, max_duration=glue.duration,
            min_duration=glue.min_duration, batch_size=-1)
    # randomly select (at most) 100 sequences from each label to ensure
    # all labels have (more or less) the same weight in the evaluation
    file_generator = getattr(protocol, subset)()
    X, y = zip(*generator(file_generator))
    X = np.vstack(X)
    y = np.hstack(y)
    # map labels to integer indices; counts drive the subsampling size
    unique, y, counts = np.unique(y, return_inverse=True,
                                  return_counts=True)
    n_labels = len(unique)
    indices = []
    for label in range(n_labels):
        i = np.random.choice(np.where(y == label)[0],
                             size=min(100, counts[label]),
                             replace=False)
        indices.append(i)
    indices = np.hstack(indices)
    X, y = X[indices], y[indices, np.newaxis]
    # precompute same/different groundtruth
    self.y_ = pdist(y, metric='chebyshev') < 1
    self.X_ = X
    self.EER_TEMPLATE_ = '{epoch:04d} {now} {eer:5f}\n'
    self.eer_ = []
def apply(self, current_file):
    """Segment `current_file` and cluster speech turns by speaker.

    Returns an Annotation mapping each speech turn to a cluster label,
    or None when clustering is skipped (too many predicted turns during
    pipeline training, or MemoryError in affinity propagation).
    """
    # initial segmentation
    speech_turns = super().apply(current_file)

    # initialize the hypothesized annotation
    hypothesis = Annotation(uri=current_file['uri'])
    if len(speech_turns) < 1:
        return hypothesis

    # this only happens during pipeline training
    if 'annotation' in current_file:
        # number of speech turns in reference
        reference = current_file['annotation']
        n_turns_true = len(list(reference.itertracks()))
        # number of speech turns in hypothesis
        uem = get_annotated(current_file)
        n_turns_pred = len(speech_turns.crop(uem))
        # don't even bother trying to cluster those speech turns
        # as there are too many of those...
        if n_turns_pred > 20 * n_turns_true:
            return None

    # get raw (sliding window) embeddings
    emb = self.emb_(current_file)

    # get one embedding per speech turn (sum over the turn's frames)
    # FIXME don't l2_normalize for any metric
    fX = l2_normalize(
        np.vstack([
            np.sum(emb.crop(t, mode='loose'), axis=0)
            for t in speech_turns
        ]))

    # apply clustering (affinity = negated pairwise distance)
    try:
        affinity = -squareform(pdist(fX, metric=self.metric))
        clusters = self.cls_.fit_predict(affinity)
    except MemoryError as e:
        # cannot compute affinity propagation
        return None

    for speech_turn, cluster in zip(speech_turns, clusters):
        # HACK find why fit_predict returns NaN sometimes and fix it.
        cluster = -1 if np.isnan(cluster) else cluster
        hypothesis[speech_turn] = cluster
    return hypothesis
def compute_similarity_matrix(self, parent=None):
    """Return pairwise cluster similarities in a ValueSortedDict.

    Similarity is the negated distance between l2-normalized cluster
    models, so values closer to zero mean "more similar". Both (i, j)
    and (j, i) orderings are stored.
    """
    cluster_ids = list(self._models)
    # stack one model vector per cluster, then l2-normalize the lot
    stacked = np.vstack([self[cluster_id][0]
                         for cluster_id in cluster_ids])
    normalized = l2_normalize(stacked)
    similarities = -squareform(pdist(normalized, metric=self.distance))

    matrix = ValueSortedDict()
    for i, j in itertools.combinations(range(len(cluster_ids)), 2):
        matrix[cluster_ids[i], cluster_ids[j]] = similarities[i, j]
        matrix[cluster_ids[j], cluster_ids[i]] = similarities[j, i]
    return matrix
def apply(self, current_file):
    """Segment `current_file` then group speech turns by speaker.

    Returns an Annotation (speech turn -> cluster label). Returns None
    either when there are 20x more predicted turns than reference turns
    (pipeline-training shortcut) or when affinity propagation runs out
    of memory.
    """
    # initial segmentation
    speech_turns = super().apply(current_file)

    # initialize the hypothesized annotation
    hypothesis = Annotation(uri=current_file['uri'])
    if len(speech_turns) < 1:
        return hypothesis

    # this only happens during pipeline training
    if 'annotation' in current_file:
        # number of speech turns in reference
        reference = current_file['annotation']
        n_turns_true = len(list(reference.itertracks()))
        # number of speech turns in hypothesis
        uem = get_annotated(current_file)
        n_turns_pred = len(speech_turns.crop(uem))
        # don't even bother trying to cluster those speech turns
        # as there are too many of those...
        if n_turns_pred > 20 * n_turns_true:
            return None

    # get raw (sliding window) embeddings
    emb = self.emb_(current_file)

    # get one embedding per speech turn (summed over its frames)
    # FIXME don't l2_normalize for any metric
    fX = l2_normalize(np.vstack(
        [np.sum(emb.crop(t, mode='loose'), axis=0)
         for t in speech_turns]))

    # apply clustering on the negated distance matrix (affinity)
    try:
        affinity = -squareform(pdist(fX, metric=self.metric))
        clusters = self.cls_.fit_predict(affinity)
    except MemoryError as e:
        # cannot compute affinity propagation
        return None

    for speech_turn, cluster in zip(speech_turns, clusters):
        # HACK find why fit_predict returns NaN sometimes and fix it.
        cluster = -1 if np.isnan(cluster) else cluster
        hypothesis[speech_turn] = cluster
    return hypothesis
def _validate_init_segment(self, protocol_name, subset='development'):
    """Build the fixed segment-level validation set.

    Returns
    -------
    dict
        'X': stacked segment sequences,
        'y': condensed same/different groundtruth (one entry per pair).
    """
    # fixed seed: every validation run sees the same batch
    np.random.seed(1337)
    protocol = get_protocol(protocol_name, progress=False,
                            preprocessors=self.preprocessors_)
    generator = SpeechSegmentGenerator(
        self.feature_extraction_, per_label=10, duration=self.duration)
    first_batch = next(generator(protocol, subset=subset))
    sequences = np.stack(first_batch['X'])
    labels = np.stack(first_batch['y']).reshape((-1, 1))
    # precompute same/different groundtruth
    same_speaker = pdist(labels, metric='equal')
    return {'X': sequences, 'y': same_speaker}
def _validate_epoch_segment(self, epoch, protocol_name,
                            subset='development', validation_data=None):
    """Compute the segment-level EER for the checkpoint of `epoch`.

    Returns
    -------
    dict
        {'EER.<duration>s': {'minimize': True, 'value': eer}}
    """
    model = self.load_model(epoch).to(self.device)
    model.eval()
    embedder = SequenceEmbedding(
        model, self.feature_extraction_, batch_size=self.batch_size,
        device=self.device)
    # embed validation segments and compare every pair
    embeddings = embedder.apply(validation_data['X'])
    distances = pdist(embeddings, metric=self.metric)
    _, _, _, eer = det_curve(validation_data['y'], distances,
                             distances=True)
    metric_name = 'EER.{0:g}s'.format(self.duration)
    return {metric_name: {'minimize': True, 'value': eer}}
def _validate_init_segment(self, protocol_name, subset='development'):
    """Prepare the (deterministic) validation set used at segment level."""
    np.random.seed(1337)  # deterministic validation batch
    protocol = get_protocol(protocol_name, progress=False,
                            preprocessors=self.preprocessors_)
    batch_generator = SpeechSegmentGenerator(self.feature_extraction_,
                                             per_label=10,
                                             duration=self.duration)
    batch = next(batch_generator(protocol, subset=subset))
    stacked_X = np.stack(batch['X'])
    stacked_y = np.stack(batch['y']).reshape((-1, 1))
    # same/different groundtruth for every pair of segments
    pairwise_y = pdist(stacked_y, metric='equal')
    return {'X': stacked_X, 'y': pairwise_y}
def _validate_epoch_turn(self, epoch, protocol_name,
                         subset='development', validation_data=None):
    """Compute speech-turn-level EER for the checkpoint of `epoch`.

    Segment embeddings are averaged per speech turn (using the 'z'
    grouping indices precomputed by `_validate_init_turn`) before
    pairwise comparison.
    """
    model = self.load_model(epoch).to(self.device)
    model.eval()
    sequence_embedding = SequenceEmbedding(model,
                                           self.feature_extraction_,
                                           batch_size=self.batch_size,
                                           device=self.device)
    fX = sequence_embedding.apply(validation_data['X'])
    z = validation_data['z']
    # iterate over segments, speech turn by speech turn
    fX_avg = []
    nz = np.vstack([np.arange(len(z)), z]).T
    for _, nz_ in itertools.groupby(nz, lambda t: t[1]):
        # (n, 2) numpy array where
        # * n is the number of segments in current speech turn
        # * dim #0 is the index of segment in original batch
        # * dim #1 is the index of speech turn (used for grouping)
        nz_ = np.stack(nz_)
        # compute (and stack) average embedding over all segments
        # of current speech turn
        indices = nz_[:, 0]
        fX_avg.append(np.mean(fX[indices], axis=0))
    fX = np.vstack(fX_avg)
    y_pred = pdist(fX, metric=self.metric)
    _, _, _, eer = det_curve(validation_data['y'], y_pred,
                             distances=True)
    metrics = {}
    metrics['EER.turn'] = {'minimize': True, 'value': eer}
    return metrics
def _validate_epoch_turn(self, epoch, protocol_name,
                         subset='development', validation_data=None):
    """Evaluate the `epoch` checkpoint at speech-turn granularity.

    Averages segment embeddings within each speech turn (grouped by the
    'z' indices from validation_data) and reports the EER of pairwise
    turn comparison.
    """
    model = self.load_model(epoch).to(self.device)
    model.eval()
    sequence_embedding = SequenceEmbedding(
        model, self.feature_extraction_, batch_size=self.batch_size,
        device=self.device)
    fX = sequence_embedding.apply(validation_data['X'])
    z = validation_data['z']
    # iterate over segments, speech turn by speech turn
    fX_avg = []
    nz = np.vstack([np.arange(len(z)), z]).T
    for _, nz_ in itertools.groupby(nz, lambda t: t[1]):
        # (n, 2) numpy array where
        # * n is the number of segments in current speech turn
        # * dim #0 is the index of segment in original batch
        # * dim #1 is the index of speech turn (used for grouping)
        nz_ = np.stack(nz_)
        # compute (and stack) average embedding over all segments
        # of current speech turn
        indices = nz_[:, 0]
        fX_avg.append(np.mean(fX[indices], axis=0))
    fX = np.vstack(fX_avg)
    y_pred = pdist(fX, metric=self.metric)
    _, _, _, eer = det_curve(validation_data['y'], y_pred,
                             distances=True)
    metrics = {}
    metrics['EER.turn'] = {'minimize': True, 'value': eer}
    return metrics
def _validate_epoch_segment(self, epoch, protocol_name,
                            subset='development', validation_data=None):
    """Evaluate the `epoch` checkpoint: segment-level EER."""
    model = self.load_model(epoch).to(self.device)
    model.eval()
    sequence_embedding = SequenceEmbedding(model,
                                           self.feature_extraction_,
                                           batch_size=self.batch_size,
                                           device=self.device)
    # embed validation segments, then measure pairwise distances
    fX = sequence_embedding.apply(validation_data['X'])
    y_pred = pdist(fX, metric=self.metric)
    # equal error rate from the DET curve
    _, _, _, eer = det_curve(validation_data['y'], y_pred,
                             distances=True)
    key = 'EER.{0:g}s'.format(self.duration)
    result = {key: {'minimize': True, 'value': eer}}
    return result
def batch_loss(self, batch, model, device, writer=None, **kwargs):
    """Compute the mean triplet loss for one batch.

    Variable-length sequences are sorted by length, padded and packed
    for the model, then unsorted back to original batch order. When
    `writer` is provided, embedding norms and positive/negative
    distance distributions are accumulated for logging.
    """
    lengths = torch.tensor([len(x) for x in batch['X']])
    variable_lengths = len(set(lengths)) > 1

    if variable_lengths:
        # pack_padded_sequence requires sequences sorted by decreasing
        # length; keep `unsort` to restore the original order afterwards
        sorted_lengths, sort = torch.sort(lengths, descending=True)
        _, unsort = torch.sort(sort)
        sequences = [torch.tensor(batch['X'][i],
                                  dtype=torch.float32,
                                  device=device) for i in sort]
        padded = pad_sequence(sequences, batch_first=True,
                              padding_value=0)
        packed = pack_padded_sequence(padded, sorted_lengths,
                                      batch_first=True)
        batch['X'] = packed
    else:
        batch['X'] = torch.tensor(np.stack(batch['X']),
                                  dtype=torch.float32,
                                  device=device)

    # forward pass
    fX = model(batch['X'])
    if variable_lengths:
        fX = fX[unsort]

    # log embedding norms
    if writer is not None:
        norm_npy = np.linalg.norm(self.to_numpy(fX), axis=1)
        self.log_norm_.append(norm_npy)

    batch['fX'] = fX
    batch = self.aggregate(batch)
    fX = batch['fX']
    y = batch['y']

    # pre-compute pairwise distances
    distances = self.pdist(fX)

    # sample triplets (strategy chosen by self.sampling)
    triplets = getattr(self, 'batch_{0}'.format(self.sampling))
    anchors, positives, negatives = triplets(y, distances)

    # compute loss for each triplet
    losses, deltas, _, _ = self.triplet_loss(
        distances, anchors, positives, negatives,
        return_delta=True)

    if writer is not None:
        pdist_npy = self.to_numpy(distances)
        delta_npy = self.to_numpy(deltas)
        same_speaker = pdist(y.reshape((-1, 1)), metric='equal')
        self.log_positive_.append(pdist_npy[np.where(same_speaker)])
        self.log_negative_.append(pdist_npy[np.where(~same_speaker)])
        self.log_delta_.append(delta_npy)

    # average over all triplets
    return torch.mean(losses)
def computeDistMat(self, X, metric='angular'):
    """Return the condensed pairwise distance vector between rows of X."""
    return pdist(X, metric=metric)
def computeLogDistMat(self, X, metric='angular'):
    """Return the negated square-form pairwise distance matrix of X."""
    condensed = pdist(X, metric=metric)
    return -squareform(condensed)
def test(protocol, tune_dir, test_dir, subset, beta=1.0):
    """Evaluate tuned embedding on `subset`, writing plots and a report.

    Loads config/tuning results from disk, embeds test sequences, then
    reports FAR/FRR/F-score at both the optimal threshold and the tuned
    ('actual') threshold into `<test_dir>/<subset>.txt`.
    """
    batch_size = 32

    # best-effort directory creation (already-exists is fine)
    try:
        os.makedirs(test_dir)
    except Exception as e:
        pass

    train_dir = os.path.dirname(os.path.dirname(tune_dir))

    # -- DURATIONS --
    duration, min_duration, step, heterogeneous = \
        path_to_duration(os.path.basename(train_dir))

    config_dir = os.path.dirname(os.path.dirname(
        os.path.dirname(train_dir)))
    config_yml = config_dir + '/config.yml'
    # NOTE(review): yaml.load without an explicit Loader is deprecated
    # and unsafe on untrusted input -- consider yaml.safe_load.
    with open(config_yml, 'r') as fp:
        config = yaml.load(fp)

    # -- PREPROCESSORS --
    for key, preprocessor in config.get('preprocessors', {}).items():
        preprocessor_name = preprocessor['name']
        preprocessor_params = preprocessor.get('params', {})
        preprocessors = __import__('pyannote.audio.preprocessors',
                                   fromlist=[preprocessor_name])
        Preprocessor = getattr(preprocessors, preprocessor_name)
        protocol.preprocessors[key] = Preprocessor(**preprocessor_params)

    # -- FEATURE EXTRACTION --
    feature_extraction_name = config['feature_extraction']['name']
    features = __import__('pyannote.audio.features',
                          fromlist=[feature_extraction_name])
    FeatureExtraction = getattr(features, feature_extraction_name)
    feature_extraction = FeatureExtraction(
        **config['feature_extraction'].get('params', {}))

    distance = config['glue'].get('params', {}).get('distance',
                                                    'sqeuclidean')

    # -- HYPER-PARAMETERS --
    tune_yml = tune_dir + '/tune.yml'
    with open(tune_yml, 'r') as fp:
        tune = yaml.load(fp)

    architecture_yml = train_dir + '/architecture.yml'
    WEIGHTS_H5 = train_dir + '/weights/{epoch:04d}.h5'
    weights_h5 = WEIGHTS_H5.format(epoch=tune['epoch'])

    sequence_embedding = SequenceEmbedding.from_disk(
        architecture_yml, weights_h5)

    X, y = generate_test(protocol, subset, feature_extraction,
                         duration, min_duration=min_duration, step=step)
    fX = sequence_embedding.transform(X, batch_size=batch_size)
    # pairwise embedding distances vs. same/different groundtruth
    y_distance = pdist(fX, metric=distance)
    y_true = pdist(y, metric='chebyshev') < 1

    # distances negated so roc_curve treats higher as "more positive"
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(
        y_true, -y_distance, pos_label=True, drop_intermediate=True)

    frr = 1. - tpr
    far = fpr
    # negate thresholds back into the distance domain
    thresholds = -thresholds

    # rough EER estimate: average of the 4 rates bracketing far > frr
    eer_index = np.where(far > frr)[0][0]
    eer = .25 * (far[eer_index-1] + far[eer_index] +
                 frr[eer_index-1] + frr[eer_index])

    fscore = 1. - f_measure(1. - frr, 1. - far, beta=beta)

    opt_i = np.nanargmin(fscore)
    opt_alpha = float(thresholds[opt_i])
    opt_far = far[opt_i]
    opt_frr = frr[opt_i]
    opt_fscore = fscore[opt_i]

    # operating point at the threshold chosen during tuning
    alpha = tune['alpha']
    actual_i = np.searchsorted(thresholds, alpha)
    actual_far = far[actual_i]
    actual_frr = frr[actual_i]
    actual_fscore = fscore[actual_i]

    save_to = test_dir + '/' + subset
    plot_distributions(y_true, y_distance, save_to)
    # NOTE(review): this overwrites the ROC-based `eer` computed above
    # with plot_det_curve's return value -- confirm this is intended.
    eer = plot_det_curve(y_true, -y_distance, save_to)
    plot_precision_recall_curve(y_true, -y_distance, save_to)

    with open(save_to + '.txt', 'w') as fp:
        fp.write('# cond. thresh far frr fscore eer\n')
        TEMPLATE = '{condition} {alpha:.5f} {far:.5f} {frr:.5f} {fscore:.5f} {eer:.5f}\n'
        fp.write(TEMPLATE.format(condition='optimal', alpha=opt_alpha,
                                 far=opt_far, frr=opt_frr,
                                 fscore=opt_fscore, eer=eer))
        fp.write(TEMPLATE.format(condition='actual ', alpha=alpha,
                                 far=actual_far, frr=actual_frr,
                                 fscore=actual_fscore, eer=eer))
def cluster_(self, fX):
    """Compute complete dendrogram

    Parameters
    ----------
    fX : (n_items, dimension) np.array
        Embeddings. WARNING: modified in place (rows are summed and
        removed as clusters merge).

    Returns
    -------
    dendrogram : list of (i, j, distance) tuples
        Dendrogram.
    """
    N = len(fX)

    # clusters contain the identifier of each cluster
    clusters = SortedSet(np.arange(N))

    # labels[i] = c means ith item belongs to cluster c
    labels = np.array(np.arange(N))

    squared = squareform(pdist(fX, metric=self.metric))
    distances = ValueSortedDict()
    # keys are always ordered (small id, large id)
    for i, j in itertools.combinations(range(N), 2):
        distances[i, j] = squared[i, j]

    dendrogram = []

    # N items require exactly N-1 merges
    for _ in range(N-1):

        # find most similar clusters (ValueSortedDict keeps the
        # smallest distance first)
        (c_i, c_j), d = distances.peekitem(index=0)

        # keep track of this iteration
        dendrogram.append((c_i, c_j, d))

        # index of clusters in 'clusters' and 'fX'
        i = clusters.index(c_i)
        j = clusters.index(c_j)

        # merge items of cluster c_j into cluster c_i
        labels[labels == c_j] = c_i

        # update c_i representative (running sum of member embeddings)
        fX[i] += fX[j]

        # remove c_j cluster: shift rows up and drop the last one
        # (safe because c_i < c_j implies i < j, so row i is untouched)
        fX[j:-1, :] = fX[j+1:, :]
        fX = fX[:-1]

        # remove distances to c_j cluster
        for c in clusters[:j]:
            distances.pop((c, c_j))
        for c in clusters[j+1:]:
            distances.pop((c_j, c))

        clusters.remove(c_j)

        if len(clusters) < 2:
            continue

        # compute distance to new c_i cluster
        new_d = cdist(fX[i, :].reshape((1, -1)), fX,
                      metric=self.metric).squeeze()
        for c_k, d in zip(clusters, new_d):
            # keep the (small id, large id) key ordering; c_k == c_i
            # is skipped (distance to itself)
            if c_k < c_i:
                distances[c_k, c_i] = d
            elif c_k > c_i:
                distances[c_i, c_k] = d

    return dendrogram
def validate(self, protocol_name, subset='development',
             aggregate=False, every=1, start=0):
    """Continuously watch a training directory and validate checkpoints.

    For each completed epoch (every `every` epochs, starting at
    `start`), loads the saved model, computes the EER on a fixed
    validation set, appends it to a text file, and refreshes EER plots.
    Runs forever (infinite watch loop).

    Parameters
    ----------
    protocol_name : str
    subset : {'development', ...}, optional
    aggregate : bool, optional
        When True, sum embeddings per group (using the 'internal'
        layer) and l2-normalize before comparison.
    every : int, optional
        Validate one epoch out of `every`.
    start : int, optional
        First epoch to validate.
    """
    # prepare paths
    validate_dir = self.VALIDATE_DIR.format(train_dir=self.train_dir_,
                                            protocol=protocol_name)
    validate_txt = self.VALIDATE_TXT.format(
        validate_dir=validate_dir, subset=subset,
        aggregate='aggregate.' if aggregate else '')
    validate_png = self.VALIDATE_PNG.format(
        validate_dir=validate_dir, subset=subset,
        aggregate='aggregate.' if aggregate else '')
    validate_eps = self.VALIDATE_EPS.format(
        validate_dir=validate_dir, subset=subset,
        aggregate='aggregate.' if aggregate else '')

    # create validation directory
    mkdir_p(validate_dir)

    # Build validation set
    if aggregate:
        X, n, y = self._validation_set_z(protocol_name, subset=subset)
    else:
        X, y = self._validation_set_y(protocol_name, subset=subset)

    # list of equal error rates, and epoch to process
    eers, epoch = SortedDict(), start

    desc_format = ('Best EER = {best_eer:.2f}% @ epoch #{best_epoch:d} ::'
                   ' EER = {eer:.2f}% @ epoch #{epoch:d} :')

    progress_bar = tqdm(unit='epoch')

    with open(validate_txt, mode='w') as fp:

        # watch and evaluate forever
        while True:

            # last completed epochs
            completed_epochs = self.get_epochs(self.train_dir_) - 1

            if completed_epochs < epoch:
                # nothing new to validate yet; poll every minute
                time.sleep(60)
                continue

            # if last completed epoch has already been processed
            # go back to first epoch that hasn't been processed yet
            process_epoch = epoch if completed_epochs in eers \
                else completed_epochs

            # do not validate this epoch if it has been done before...
            if process_epoch == epoch and epoch in eers:
                epoch += every
                progress_bar.update(every)
                continue

            weights_h5 = LoggingCallback.WEIGHTS_H5.format(
                log_dir=self.train_dir_, epoch=process_epoch)

            # this is needed for corner case when training is started from
            # an epoch > 0
            if not isfile(weights_h5):
                time.sleep(60)
                continue

            # sleep 5 seconds to let the checkpoint callback finish
            time.sleep(5)

            embedding = keras.models.load_model(
                weights_h5, custom_objects=CUSTOM_OBJECTS,
                compile=False)

            if aggregate:
                # embed with the 'internal' layer (pre-aggregation)
                # output instead of the model's final output
                def embed(X):
                    func = K.function([
                        embedding.get_layer(name='input').input,
                        K.learning_phase()
                    ], [embedding.get_layer(name='internal').output])
                    return func([X, 0])[0]
            else:
                embed = embedding.predict

            # embed all validation sequences
            fX = embed(X)

            if aggregate:
                # sum embeddings within each group delimited by n,
                # then l2-normalize
                indices = np.hstack([[0], np.cumsum(n)])
                fX = np.stack([
                    np.sum(np.sum(fX[i:j], axis=0), axis=0)
                    for i, j in pairwise(indices)
                ])
                fX = l2_normalize(fX)

            # compute pairwise distances
            y_pred = pdist(fX, metric=self.approach_.metric)

            # compute pairwise groundtruth
            y_true = pdist(y, metric='chebyshev') < 1

            # estimate equal error rate
            _, _, _, eer = det_curve(y_true, y_pred, distances=True)
            eers[process_epoch] = eer

            # save equal error rate to file
            fp.write(
                self.VALIDATE_TXT_TEMPLATE.format(epoch=process_epoch,
                                                  eer=eer))
            fp.flush()

            # keep track of best epoch so far
            best_epoch = eers.iloc[np.argmin(eers.values())]
            best_eer = eers[best_epoch]

            progress_bar.set_description(
                desc_format.format(epoch=process_epoch, eer=100 * eer,
                                   best_epoch=best_epoch,
                                   best_eer=100 * best_eer))

            # plot EER as a function of epoch
            fig = plt.figure()
            plt.plot(eers.keys(), eers.values(), 'b')
            plt.plot([best_epoch], [best_eer], 'bo')
            plt.plot([eers.iloc[0], eers.iloc[-1]],
                     [best_eer, best_eer], 'k--')
            plt.grid(True)
            plt.xlabel('epoch')
            plt.ylabel('EER on {subset}'.format(subset=subset))
            TITLE = '{best_eer:.5g} @ epoch #{best_epoch:d}'
            title = TITLE.format(best_eer=best_eer,
                                 best_epoch=best_epoch,
                                 subset=subset)
            plt.title(title)
            plt.tight_layout()
            plt.savefig(validate_png, dpi=75)
            plt.savefig(validate_eps)
            plt.close(fig)

            # go to next epoch
            if epoch == process_epoch:
                epoch += every
                progress_bar.update(every)
            else:
                progress_bar.update(0)

    # NOTE(review): unreachable given the `while True` loop above --
    # confirm whether a break/termination condition was intended.
    progress_bar.close()
def batch_loss(self, batch, model, device, writer=None, **kwargs):
    """Return the mean triplet loss over one batch of sequences.

    Handles variable-length batches by sorting, padding and packing
    before the forward pass (then restoring the original order).
    Optionally accumulates norm/distance statistics when `writer` is
    provided.
    """
    lengths = torch.tensor([len(x) for x in batch['X']])
    variable_lengths = len(set(lengths)) > 1

    if variable_lengths:
        # pack_padded_sequence needs decreasing lengths; `unsort`
        # restores the original batch order after the forward pass
        sorted_lengths, sort = torch.sort(lengths, descending=True)
        _, unsort = torch.sort(sort)
        sequences = [
            torch.tensor(batch['X'][i], dtype=torch.float32,
                         device=device) for i in sort
        ]
        padded = pad_sequence(sequences, batch_first=True,
                              padding_value=0)
        packed = pack_padded_sequence(padded, sorted_lengths,
                                      batch_first=True)
        batch['X'] = packed
    else:
        batch['X'] = torch.tensor(np.stack(batch['X']),
                                  dtype=torch.float32,
                                  device=device)

    # forward pass
    fX = model(batch['X'])
    if variable_lengths:
        fX = fX[unsort]

    # log embedding norms
    if writer is not None:
        norm_npy = np.linalg.norm(self.to_numpy(fX), axis=1)
        self.log_norm_.append(norm_npy)

    batch['fX'] = fX
    batch = self.aggregate(batch)
    fX = batch['fX']
    y = batch['y']

    # pre-compute pairwise distances
    distances = self.pdist(fX)

    # sample triplets according to the configured strategy
    triplets = getattr(self, 'batch_{0}'.format(self.sampling))
    anchors, positives, negatives = triplets(y, distances)

    # compute loss for each triplet
    losses, deltas, _, _ = self.triplet_loss(distances, anchors,
                                             positives, negatives,
                                             return_delta=True)

    if writer is not None:
        pdist_npy = self.to_numpy(distances)
        delta_npy = self.to_numpy(deltas)
        same_speaker = pdist(y.reshape((-1, 1)), metric='equal')
        self.log_positive_.append(pdist_npy[np.where(same_speaker)])
        self.log_negative_.append(pdist_npy[np.where(~same_speaker)])
        self.log_delta_.append(delta_npy)

    # average over all triplets
    return torch.mean(losses)