def phone_level_cluster(self, out_prefix='predictions'):
    self.load_checkpoint()
    X_a = np.zeros((self.n_class, self.K))
    norm = np.zeros((self.n_class, 1))
    audio_files = []
    encodings = []

    # Find the centroid of each phone-level cluster
    B = self.data_loader['test'].batch_size
    testset = self.data_loader['test'].dataset
    for b_idx, (audios, _, _, audio_masks, _) in enumerate(self.data_loader['test']):
        if b_idx > 2 and self.debug:
            break
        audios = cuda(audios, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        _, _, encoding, embedding = self.audio_net(audios,
                                                   mask=audio_masks,
                                                   return_feat=True)
        encoding = encoding.permute(0, 2, 1).cpu().detach().numpy()
        embedding = embedding.cpu().detach().numpy()
        # Accumulate soft counts over the batch dimension
        X_a += (encoding @ embedding).sum(0)
        norm += encoding.sum(axis=-1, keepdims=True).sum(0)
        audio_files.extend([testset.dataset[b_idx * B + i][0]
                            for i in range(audios.size(0))])
        encodings.append(encoding.transpose(0, 2, 1))
    encodings = np.concatenate(encodings)
    X_a /= norm

    X_s = (self.audio_net.bottleneck.weight
           + self.audio_net.bottleneck.bias).detach().cpu().numpy()
    X = np.concatenate([X_a, X_s], axis=1)
    kmeans = KMeans(n_clusters=50).fit(X)
    phoneme_labels = kmeans.labels_

    out_file = os.path.join(self.ckpt_dir, f'{out_prefix}_phone_level_clustering.txt')
    out_f = open(out_file, 'w')
    pred_phones = encodings.argmax(-1)
    for idx, audio_file in enumerate(audio_files):
        audio_id = os.path.splitext(os.path.split(audio_file)[1])[0]
        pred_phonemes = ','.join([str(phoneme_labels[phn]) for phn in pred_phones[idx]])
        out_f.write(f'{audio_id} {pred_phonemes}\n')
    out_f.close()

    gold_path = os.path.join(testset.data_path, 'test/')
    compute_token_f1(out_file, gold_path,
                     os.path.join(self.ckpt_dir, 'confusion_phone_level_cluster.png'))
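# --- Illustrative sketch (not part of the pipeline) -------------------------
# The accumulation above builds one centroid per phone class as a soft-count
# average: X_a[c] = sum_t q_t[c] * h_t / sum_t q_t[c], with q_t the phone
# posterior (`encoding`) and h_t the frame embedding.  A minimal standalone
# version of that reduction; the helper name and shapes are illustrative.
@staticmethod
def _soft_centroids_sketch(posteriors, frames, eps=1e-10):
    """posteriors: (T, n_class); frames: (T, K); returns (n_class, K) centroids."""
    import numpy as np
    num = posteriors.T @ frames                              # weighted sums
    denom = posteriors.sum(axis=0, keepdims=True).T + eps    # soft counts
    return num / denom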
def cluster(self, n_clusters=50, out_prefix='quantized_outputs'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    temp = self.history['temp']
    us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
    with torch.no_grad():
        B = 0
        utt_ids = []
        X = []
        for b_idx, (audios, phoneme_labels, word_labels,
                    audio_masks, phone_masks, word_masks)\
                in enumerate(self.data_loader['test']):
            if b_idx > 2 and self.debug:
                break
            if b_idx == 0:
                B = audios.size(0)
            audios = cuda(audios, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            if self.audio_feature == 'wav2vec2':
                x = self.audio_feature_net.feature_extractor(audios)
            else:
                x = audios
            audio_lens = audio_masks.sum(-1).long()
            outputs = self.audio_net(x, return_feat=True)
            embedding = outputs[-1]
            for idx in range(audios.size(0)):
                global_idx = b_idx * B + idx
                utt_id = os.path.splitext(
                    os.path.basename(testset.dataset[global_idx][0]))[0]
                embed = embedding[idx, :audio_lens[idx]].cpu().detach().numpy()
                X.extend(embed.tolist())
                utt_ids.extend([utt_id] * embed.shape[0])

    X = np.asarray(X)
    begin_time = time.time()
    clusterer = KMeans(n_clusters=n_clusters).fit(X)
    print(f'KMeans take {time.time()-begin_time} s to finish')
    np.save(self.ckpt_dir.joinpath('cluster_means.npy'), clusterer.cluster_centers_)

    ys = clusterer.predict(X)
    filename = self.ckpt_dir.joinpath(out_prefix + '.txt')
    out_f = open(filename, 'w')
    for utt_id, group in groupby(list(zip(utt_ids, ys)), lambda x: x[0]):
        y = ','.join([str(g[1]) for g in group for _ in range(us_ratio)])
        out_f.write(f'{utt_id} {y}\n')
    out_f.close()

    gold_path = os.path.join(testset.data_path, testset.splits[0])
    token_f1, token_prec, token_recall = compute_token_f1(
        filename, gold_path,
        self.ckpt_dir.joinpath('confusion.png'),
    )
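# --- Illustrative sketch ----------------------------------------------------
# The writer above groups (utt_id, cluster_id) pairs by utterance and repeats
# every cluster id `us_ratio` times, so the output lines are back at the 10 ms
# frame rate expected by the token-F1 scorer.  The formatting step in
# isolation (helper name is illustrative):
@staticmethod
def _format_quantized_line_sketch(utt_id, cluster_ids, us_ratio):
    """Return one '<utt_id> <id,id,...>' line at the upsampled frame rate."""
    ids = [str(c) for c in cluster_ids for _ in range(us_ratio)]
    return f"{utt_id} {','.join(ids)}"
# e.g. _format_quantized_line_sketch('utt1', [3, 7], 2) -> 'utt1 3,3,7,7'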
def test(self, save_embedding=False, out_prefix='predictions'): self.set_mode('eval') testset = self.data_loader['test'].dataset preprocessor = testset.preprocessor total_loss = 0. total_step = 0. pred_word_labels = [] pred_word_labels_quantized = [] gold_word_labels = [] if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir(): os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean')) gold_phone_file = os.path.join(testset.data_path, f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item') word_readable_f = open(self.ckpt_dir.joinpath(f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w') phone_file = self.ckpt_dir.joinpath(f'{out_prefix}_phoneme.{self.global_epoch}.txt') phone_f = open(phone_file, 'w') with torch.no_grad(): B = 0 for b_idx, batch in enumerate(self.data_loader['test']): audios = batch[0] word_labels = batch[2].squeeze(-1) audio_masks = batch[3] word_masks = batch[5] if b_idx > 2 and self.debug: break if b_idx == 0: B = audios.size(0) # (batch size, max segment num, feat dim) or (batch size, max segment num, max segment len, feat dim) x = cuda(audios, self.cuda) # (batch size,) word_labels = cuda(word_labels, self.cuda) # (batch size, max segment num) or (batch size, max segment num, max segment len) audio_masks = cuda(audio_masks, self.cuda) if self.audio_feature == "wav2vec2": B = x.size(0) T = x.size(1) x = self.audio_feature_net.feature_extractor(x.view(B*T, -1)).permute(0, 2, 1) x = x.view(B, T, x.size(-2), x.size(-1)) x = (x * audio_masks.unsqueeze(-1)).sum(-2) / (audio_masks.sum(-1, keepdim=True) + torch.tensor(1e-10, device=x.device)) audio_masks = torch.where(audio_masks.sum(-1) > 0, torch.tensor(1, device=x.device), torch.tensor(0, device=x.device)) # (batch size, max segment num) if audio_masks.dim() == 3: segment_masks = torch.where(audio_masks.sum(-1) > 0, torch.tensor(1., device=audio_masks.device), torch.tensor(0., device=audio_masks.device)) else: segment_masks = audio_masks.clone() if self.audio_net.ds_ratio > 1: audio_masks = audio_masks[:, ::self.audio_net.ds_ratio] segment_masks = segment_masks[:, ::self.audio_net.ds_ratio] # (batch size, max segment num, n visual class) word_logits, quantized, phone_loss = self.audio_net(x, masks=audio_masks) # segment_word_labels = word_labels.unsqueeze(-1)\ # .expand(-1, self.max_segment_num) # segment_word_labels = (segment_word_labels * segment_masks).flatten().long() if self.use_logsoftmax: segment_word_logits = (F.log_softmax(word_logits, dim=-1)\ * segment_masks.unsqueeze(-1)).sum(-2) else: segment_word_logits = (word_logits\ * segment_masks.unsqueeze(-1)).sum(-2) word_loss = F.cross_entropy(segment_word_logits, word_labels, ignore_index=self.ignore_index) loss = phone_loss + word_loss total_loss += loss.cpu().detach().numpy() total_step += 1. 
_, _, phone_indices = self.audio_net.encode(x, masks=audio_masks) for idx in range(audios.size(0)): global_idx = b_idx * B + idx audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0] segments = testset.dataset[global_idx][3] pred_phone_label = testset.unsegment(phone_indices[idx], segments).long() if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1: us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio pred_phone_label = pred_phone_label.unsqueeze(-1)\ .expand(-1, us_ratio).flatten() pred_phone_label_list = pred_phone_label.cpu().detach().numpy().tolist() pred_phone_names = ','.join([str(p) for p in pred_phone_label_list]) phone_f.write(f'{audio_id} {pred_phone_names}\n') gold_word_label = word_labels[idx].cpu().detach().numpy().tolist() pred_word_label = segment_word_logits[idx].max(-1)[1].cpu().detach().numpy().tolist() pred_word_label_quantized = quantized[idx].prod(-2).max(-1)[1].cpu().detach().numpy().tolist() gold_word_labels.append(gold_word_label) pred_word_labels.append(pred_word_label) pred_word_labels_quantized.append(pred_word_label_quantized) pred_word_name = preprocessor.to_word_text([pred_word_label])[0] pred_word_name_quantized = preprocessor.to_word_text([pred_word_label_quantized])[0] gold_word_name = preprocessor.to_word_text([gold_word_label])[0] word_readable_f.write(f'Utterance id: {audio_id}\n' f'Gold word label: {gold_word_name}\n' f'Pred word label: {pred_word_name}\n' f'Pred word label by quantizer: {pred_word_name_quantized}\n\n') phone_f.close() word_readable_f.close() avg_loss = total_loss / total_step # Compute word accuracy and word token F1 print('[TEST RESULT]') word_acc = compute_accuracy(gold_word_labels, pred_word_labels) word_prec,\ word_rec,\ word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels), np.asarray(pred_word_labels), average='macro') word_prec_quantized,\ word_rec_quantized,\ word_f1_quantized, _ = precision_recall_fscore_support(np.asarray(gold_word_labels), np.asarray(pred_word_labels_quantized), average='macro') token_f1,\ token_prec,\ token_recall = compute_token_f1(phone_file, gold_phone_file, self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png')) info = f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\n'\ f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'\ f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}\n'\ f'(By Quantizer) Word Precision: {word_prec_quantized:.3f}\tWord Recall: {word_rec_quantized:.3f}\tWord F1: {word_f1_quantized:.3f}\n'\ f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}\n' print(info) save_path = self.ckpt_dir.joinpath('results_file.txt') with open(save_path, 'a') as file: file.write(info) if self.history['word_acc'] < word_acc: self.history['token_result'] = [token_prec, token_recall, token_f1] self.history['word_acc'] = word_acc self.history['loss'] = avg_loss self.history['iter'] = self.global_iter self.history['epoch'] = self.global_epoch self.save_checkpoint(f'best_acc_{self.config.seed}.tar') self.set_mode('train')
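# --- Illustrative sketch ----------------------------------------------------
# `testset.unsegment` is assumed to expand segment-level predictions back to a
# per-frame label sequence using the stored segment boundaries.  A minimal
# analogue of that behaviour; the {'begin': b, 'end': e} segment layout is an
# assumption for illustration, not the dataset's documented format.
@staticmethod
def _unsegment_sketch(segment_labels, segments):
    """Repeat each segment label over the frames that segment covers."""
    import torch
    frame_labels = []
    for label, seg in zip(segment_labels.tolist(), segments):
        frame_labels.extend([label] * (seg['end'] - seg['begin']))
    return torch.tensor(frame_labels, dtype=torch.long)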
def cluster(self, test_loader=None, out_prefix='predictions', save_embedding=False):
    self.load_checkpoint()
    X = []
    audio_files = []
    if test_loader is not None:
        B = test_loader.batch_size
        testset = test_loader.dataset
        split = testset.splits[0]
        gold_path = os.path.join(testset.data_path, split)
    else:
        test_loader = self.data_loader['test']
        B = test_loader.batch_size
        testset = test_loader.dataset
        split = testset.splits[0]
        gold_path = os.path.join(testset.data_path, split)

    embed_path = os.path.join(testset.data_path, f'{split}_embeddings')
    if save_embedding and not os.path.exists(embed_path):
        os.makedirs(embed_path)

    for b_idx, (audios, _, _, audio_masks, _) in enumerate(test_loader):
        if b_idx > 2 and self.debug:
            break
        audios = cuda(audios, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        _, _, _, embedding = self.audio_net(audios,
                                            masks=audio_masks,
                                            return_feat=True)
        # Concatenate the hidden vector with the input feature
        concat_embedding = torch.cat([audios.permute(0, 2, 1), embedding], axis=-1)
        if save_embedding:
            for idx in range(audios.size(0)):
                audio_id = os.path.splitext(
                    os.path.split(testset.dataset[b_idx * B + idx][0])[1])[0]
                np.savetxt(os.path.join(embed_path, f'{audio_id}.txt'),
                           concat_embedding[idx].cpu().detach().numpy())
        X.append(concat_embedding.cpu().detach().numpy())
        audio_files.extend([testset.dataset[b_idx * B + i][0]
                            for i in range(audios.size(0))])

    X = np.concatenate(X, axis=0)
    shape = X.shape
    kmeans = KMeans(n_clusters=50).fit(X.reshape(shape[0] * shape[1], -1))
    np.save(os.path.join(self.ckpt_dir, 'kmeans_centroids.npy'),
            kmeans.cluster_centers_)
    encodings = kmeans.labels_.reshape(shape[0], shape[1])

    out_file = os.path.join(self.ckpt_dir, f'{out_prefix}_clustering.txt')
    out_f = open(out_file, 'w')
    for idx, (audio_file, encoding) in enumerate(zip(audio_files, encodings)):
        audio_id = os.path.splitext(os.path.split(audio_file)[1])[0]
        pred_phonemes = ','.join([str(phn) for phn in encodings[idx]])
        out_f.write(f'{audio_id} {pred_phonemes}\n')
    out_f.close()

    compute_token_f1(out_file, gold_path,
                     os.path.join(self.ckpt_dir, 'confusion_cluster.png'))
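# --- Illustrative sketch ----------------------------------------------------
# Clustering above runs on the concatenation of the raw input features and
# the model's hidden states, so the k-means space keeps both acoustic detail
# and learned structure.  The shape bookkeeping in isolation (names and
# shapes are illustrative):
@staticmethod
def _concat_features_sketch(audios, embedding):
    """audios: (B, F, T) input features; embedding: (B, T, H) hidden states."""
    import torch
    # (B, T, F) concatenated with (B, T, H) -> (B, T, F + H)
    return torch.cat([audios.permute(0, 2, 1), embedding], dim=-1)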
def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    total_loss = 0
    izy_bound = 0
    izx_bound = 0
    total_num = 0
    seq_list = []
    pred_labels = []
    gold_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
        os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))
    gold_path = os.path.join(testset.data_path, 'test/')
    out_word_file = os.path.join(
        self.ckpt_dir, f'{out_prefix}_word.{self.global_epoch}.readable')
    out_phone_file = os.path.join(
        self.ckpt_dir, f'{out_prefix}_phoneme.{self.global_epoch}.txt')
    word_f = open(out_word_file, 'w')
    word_f.write('Image ID\tGold label\tPredicted label\n')
    phone_f = open(out_phone_file, 'w')

    with torch.no_grad():
        B = 0
        for b_idx, (audios, images, labels,
                    audio_masks, image_masks) in enumerate(self.data_loader['test']):
            if b_idx > 2 and self.debug:
                break
            if b_idx == 0:
                B = audios.size(0)
            audios = cuda(audios, self.cuda)
            labels = cuda(labels, self.cuda)
            images = cuda(images, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            image_masks = cuda(image_masks.unsqueeze(-1), self.cuda)
            image_size = images.size()
            '''
            if images.ndim in [2, 4]:
                image_logit, _ = self.image_net(images, return_score=self.image_feature)
            else:
                image_logit_flat, _ = self.image_net(
                    images.view(-1, *image_size[2:]), return_score=True)
                image_logit = image_logit_flat.view(*image_size[:2], -1)
            if self.image_feature == 'label':
                if image_logit.ndim == 2:
                    y = image_logit.max(-1)[1]
                else:
                    y = F.one_hot(image_logit.max(-1)[1], self.n_class)
                    y = (y * image_masks).max(1)[0]
            elif self.image_feature == 'multi_label':
                if image_logit.ndim == 2:
                    y = (image_logit > 0).float()
                else:
                    y = ((image_logit > 0).float() * image_masks).max(1)[0]
            '''
            y = labels
            in_logits, logits, encoding, embedding = self.audio_net(
                audios, masks=audio_masks, return_feat=True)

            # Word prediction
            in_logit = in_logits.sum(0)
            logit = (logits * audio_masks.unsqueeze(-1)).sum(dim=1)
            for idx in range(audios.size(0)):
                global_idx = b_idx * B + idx
                audio_id = os.path.splitext(
                    os.path.split(testset.dataset[global_idx][0])[1])[0]
                golds = [y[idx].cpu().detach().numpy()]
                preds = [logit[idx].max(-1)[1].cpu().detach().numpy()]
                pred_phones = encoding[idx].max(-1)[1]
                if self.use_segment:
                    pred_phones = testset.unsegment(
                        pred_phones, testset.dataset[global_idx][3]).long()
                pred_phones = pred_phones.cpu().detach().numpy().tolist()
                gold_names = ','.join([self.class_names[c] for c in golds])
                pred_names = ','.join([self.class_names[c] for c in preds])
                pred_phones_str = ','.join([str(phn) for phn in pred_phones])
                word_f.write(f'{audio_id}\t{gold_names}\t{pred_names}\n')
                phone_f.write(f'{audio_id} {pred_phones_str}\n')

            # One-hot predictions vs. one-hot gold labels for the word metrics
            pred_labels.append(F.one_hot(logit.max(-1)[1], self.n_class).cpu())
            gold_labels.append(F.one_hot(y, self.n_class).cpu())

            cur_class_loss = F.cross_entropy(logit, y).div(math.log(2))
            cur_info_loss = (F.softmax(in_logit, dim=-1)
                             * F.log_softmax(in_logit, dim=-1)
                             ).sum(1).mean().div(math.log(2))
            cur_ib_loss = cur_class_loss + self.beta * cur_info_loss
            izy_bound = izy_bound + y.size(0) * math.log(self.n_class, 2) - cur_class_loss
            izx_bound = izx_bound + cur_info_loss
            total_loss += cur_ib_loss.item()
            total_num += audios.size(0)

            for idx in range(audios.size(0)):
                global_idx = b_idx * B + idx
                example_id = testset.dataset[global_idx][0].split('/')[-1]
                text = testset.dataset[global_idx][1]
                feat_id = example_id
                units = encoding[idx].max(-1)[1]
                if save_embedding:
                    feat_fn = self.ckpt_dir.joinpath(
                        f'outputs/phonetic/dev-clean/{feat_id}')
                    np.savetxt(feat_fn, audios[idx].cpu().detach().numpy())
                    seq_list.append((feat_id, feat_fn))

    word_f.close()
    phone_f.close()

    pred_labels = torch.cat(pred_labels).detach().numpy()
    gold_labels = torch.cat(gold_labels).detach().numpy()
    ps, rs, f1s, _ = precision_recall_fscore_support(gold_labels.flatten(),
                                                     pred_labels.flatten())
    p, r, f1 = ps[1], rs[1], f1s[1]
    class_f1s = np.zeros(self.n_class)
    for c in range(self.n_class):
        _, _, class_f1, _ = precision_recall_fscore_support(gold_labels[:, c],
                                                            pred_labels[:, c])
        class_f1s[c] = class_f1[-1]

    izy_bound /= total_num
    izx_bound /= total_num
    avg_loss = total_loss / total_num
    print('[TEST RESULT]')
    print('Epoch {}\tLoss: {:.2f}\tPrecision: {:.2f}\tRecall: {:.2f}\tF1: {:.2f}'
          .format(self.global_epoch, avg_loss, p, r, f1))

    token_f1, token_prec, token_recall = compute_token_f1(
        out_phone_file, gold_path,
        os.path.join(self.ckpt_dir, f'confusion.{self.global_epoch}.png'),
    )
    if self.history['f1'] < f1:
        self.history['f1'] = f1
        self.history['loss'] = avg_loss
        self.history['epoch'] = self.global_epoch
        self.history['iter'] = self.global_iter
        self.history['token_f1'] = token_f1
        self.save_checkpoint('best_acc.tar')
    self.set_mode('train')
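# --- Illustrative sketch ----------------------------------------------------
# The two running sums above follow the usual variational information-
# bottleneck bounds in bits: I(Z;Y) is lower-bounded via log2(n_class) minus
# the classification cross entropy, and I(Z;X) is tracked through the
# (negative-entropy) rate term that also enters the loss weighted by beta.
# A mean-reduced, per-batch version of the two quantities (illustrative):
@staticmethod
def _ib_bounds_sketch(logits, targets, n_class):
    """logits: (B, n_class); targets: (B,) class indices."""
    import math
    import torch.nn.functional as F
    class_loss = F.cross_entropy(logits, targets).div(math.log(2))
    info_loss = (F.softmax(logits, dim=-1)
                 * F.log_softmax(logits, dim=-1)).sum(1).mean().div(math.log(2))
    izy_lower = math.log(n_class, 2) - class_loss.item()
    izx_upper = info_loss.item()
    return izy_lower, izx_upper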
def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    preprocessor = testset.preprocessor
    total_loss = 0.
    total_step = 0.
    gold_word_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
        os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))
    gold_phone_file = os.path.join(
        testset.data_path,
        f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
    gold_visual_phone_file = os.path.join(
        testset.data_path,
        f'{testset.splits[0]}/{testset.splits[0]}_visual.item')
    phone_file = self.ckpt_dir.joinpath(
        f'{out_prefix}_phoneme.{self.global_epoch}.txt')
    visual_phone_file = self.ckpt_dir.joinpath(
        f'{out_prefix}_visual_phoneme.{self.global_epoch}.txt')
    phone_f = open(phone_file, 'w')
    visual_phone_f = open(visual_phone_file, 'w')
    phone_readable_f = open(
        self.ckpt_dir.joinpath(
            f'{out_prefix}_phoneme.{self.global_epoch}.readable'), 'w')
    visual_phone_readable_f = open(
        self.ckpt_dir.joinpath(
            f'{out_prefix}_visual_phoneme.{self.global_epoch}.readable'), 'w')

    with torch.no_grad():
        B = 0
        for b_idx, (audios, phoneme_labels, word_labels,
                    audio_masks, phone_masks, word_masks)\
                in enumerate(self.data_loader['test']):
            if b_idx > 2 and self.debug:
                break
            if b_idx == 0:
                B = audios.size(0)
            audios = cuda(audios, self.cuda)
            if self.audio_feature == 'wav2vec2':
                x = self.audio_feature_net.feature_extractor(audios)
            else:
                x = audios
            word_labels = cuda(word_labels, self.cuda)
            phoneme_labels = cuda(phoneme_labels, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            phone_masks = cuda(phone_masks, self.cuda)
            word_masks = cuda(word_masks, self.cuda)

            audio_lens = audio_masks.sum(-1).long()
            sent_lens = phone_masks.sum(-1).long()
            word_lens = word_masks.sum(dim=(-2, -1)).long()
            word_num = (word_labels >= 0).long().sum(-1)

            cluster_logits, embedding = self.audio_net(x, return_feat=True)
            phoneme_labels_aligned = self.align_net(
                F.one_hot(phoneme_labels * phone_masks.long(), self.n_phone_class),
                phone_masks, audio_masks)
            cluster_probs = F.softmax(cluster_logits, dim=-1)\
                .view(-1, self.max_feat_len, self.n_phone_class)
            if self.max_normalize:
                cluster_probs = cluster_probs / cluster_probs.max(-1, keepdim=True)[0]
            phone_label_mask = F.one_hot(
                phoneme_labels * phone_masks.long(),
                self.n_phone_class).sum(1, keepdim=True)
            phone_label_mask = (phone_label_mask > 0).long()

            if self.loss_type == 'binary_cross_entropy':
                loss = self.weight_phone_loss * self.criterion(
                    cluster_probs,
                    phoneme_labels_aligned * audio_masks.unsqueeze(-1))
            else:
                loss = self.weight_phone_loss * self.criterion(
                    cluster_probs, phoneme_labels_aligned, audio_masks)

            word_cluster_logits = torch.matmul(word_masks, cluster_logits.unsqueeze(1))
            word_cluster_probs = F.softmax(word_cluster_logits, dim=-1)\
                .view(-1, self.max_word_len, self.n_phone_class)
            if self.max_normalize:
                word_cluster_probs = word_cluster_probs\
                    / word_cluster_probs.max(-1, keepdim=True)[0]
            word_phone_probs = self.phone_net(word_labels.flatten())
            word_phone_probs = word_phone_probs\
                .unsqueeze(1).expand(-1, self.max_word_len, -1)
            if self.max_normalize:
                word_phone_probs = word_phone_probs\
                    / (word_phone_probs.max(-1, keepdim=True)[0] + EPS)

            if self.loss_type == 'binary_cross_entropy':
                loss = loss + self.weight_word_loss * self.criterion(
                    word_cluster_probs,
                    word_phone_probs * word_masks.sum(-1).view(-1, self.max_word_len, 1))
            else:
                loss = loss + self.weight_word_loss * self.criterion(
                    word_cluster_probs, word_phone_probs,
                    word_masks.sum(-1).view(-1, self.max_word_len))
            total_loss += loss.cpu().detach().numpy()
            total_step += 1.

            for idx in range(audios.size(0)):
                global_idx = b_idx * B + idx
                audio_id = os.path.splitext(
                    os.path.split(testset.dataset[global_idx][0])[1])[0]
                gold_phone_label = phoneme_labels[idx, :sent_lens[idx]]
                pred_phone_label = cluster_logits[idx, :audio_lens[idx]].max(-1)[1]
                gold_phone_names = ','.join(preprocessor.to_text(gold_phone_label))
                pred_phone_names = ','.join(preprocessor.tokens_to_text(pred_phone_label))
                phone_readable_f.write(
                    f'Utterance id: {audio_id}\n'
                    f'Gold transcript: {gold_phone_names}\n'
                    f'Pred transcript: {pred_phone_names}\n\n')

                # Upsample back to the 10 ms frame rate if needed
                us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
                if us_ratio > 1:
                    pred_phone_label = pred_phone_label.unsqueeze(-1)\
                        .expand(-1, us_ratio).flatten()
                pred_phone_label_list = pred_phone_label.cpu().detach().numpy().tolist()
                pred_phone_names = ','.join([str(p) for p in pred_phone_label_list])
                phone_f.write(f'{audio_id} {pred_phone_names}\n')

                if word_num[idx] > 0:
                    pred_phone_word_masked_label = pred_phone_label\
                        * word_masks[idx, :, :, :audio_lens[idx]].long().sum(dim=(0, 1))
                    pred_phone_word_masked_label_list =\
                        pred_phone_word_masked_label.detach().cpu().numpy().tolist()
                    pred_phone_word_masked_names = ','.join(
                        [str(p) for p in pred_phone_word_masked_label_list])
                    visual_phone_f.write(
                        f'{audio_id} {pred_phone_word_masked_names}\n')

                    gold_word_label = word_labels[idx, :word_num[idx]]\
                        .cpu().detach().numpy().tolist()
                    gold_word_names = preprocessor.to_word_text(gold_word_label)
                    pred_visual_phone_label = word_cluster_logits[
                        idx, :word_num[idx]].max(-1)[1]
                    for word_idx in range(word_num[idx]):
                        gold_word_name = gold_word_names[word_idx]
                        pred_label = pred_visual_phone_label[
                            word_idx, :word_lens[idx, word_idx]]
                        pred_label = pred_label.cpu().detach().numpy().tolist()
                        pred_visual_phone_names = ','.join(
                            preprocessor.tokens_to_text(pred_label))
                        visual_phone_readable_f.write(
                            f'Utterance id: {audio_id}\n'
                            f'Gold word label: {gold_word_name}\n'
                            f'Pred transcript: {pred_visual_phone_names}\n\n')

    phone_f.close()
    visual_phone_f.close()
    phone_readable_f.close()
    visual_phone_readable_f.close()
    avg_loss = total_loss / total_step

    # Token F1
    token_f1, token_prec, token_recall = compute_token_f1(
        phone_file, gold_phone_file,
        self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))
    visual_token_f1, visual_token_prec, visual_token_recall = compute_token_f1(
        visual_phone_file, gold_visual_phone_file,
        self.ckpt_dir.joinpath(f'visual_confusion.{self.global_epoch}.png'))

    print('[TEST RESULT]')
    print(f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\t'
          f'Token F1: {token_f1:.3f}\tVisual Token F1: {visual_token_f1:.3f}')
    if self.history['visual_token_f1'] < visual_token_f1:
        self.history['token_f1'] = token_f1
        self.history['visual_token_f1'] = visual_token_f1
        self.history['loss'] = avg_loss
        self.history['iter'] = self.global_iter
        self.history['epoch'] = self.global_epoch
        self.save_checkpoint('best_acc.tar')
    self.set_mode('train')
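# --- Illustrative sketch ----------------------------------------------------
# Frame-level logits are lifted to word level above with a batched matmul:
# word_masks of shape (B, n_words, word_len, n_frames) acts as a 0/1 frame
# selection matrix, so masks @ logits gathers each word's frames.  A shape
# check in isolation (sizes are illustrative):
@staticmethod
def _word_pool_sketch():
    import torch
    B, W, L, T, C = 2, 3, 5, 40, 49
    word_masks = torch.zeros(B, W, L, T)          # 1 where a frame belongs to a word slot
    frame_logits = torch.randn(B, T, C)
    word_logits = torch.matmul(word_masks, frame_logits.unsqueeze(1))
    assert word_logits.shape == (B, W, L, C)
    return word_logits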
def cluster(self, n_clusters=44, out_prefix='quantized_outputs'): self.set_mode('eval') testset = self.data_loader['test'].dataset us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio with torch.no_grad(): B = 0 utt_ids_dict = dict() X_dict = dict() segment_dict = dict() for split in self.data_loader: X_dict[split] = [] utt_ids_dict[split] = [] segment_dict[split] = dict() for b_idx, batch in enumerate(self.data_loader[split]): audios = batch[0] audio_masks = batch[3] audios = cuda(audios, self.cuda) audio_masks = cuda(audio_masks, self.cuda) audio_lens = audio_masks.sum(-1) if b_idx > 2 and self.debug: break if b_idx == 0: B = audios.size(0) if self.audio_feature == 'wav2vec2': x = self.audio_feature_net.feature_extractor(audios) else: x = audios audio_lens = audio_masks.sum(-1).long() # XXX _, embedding = self.audio_net(x, # masks=audio_masks, # return_feat=True) embedding = x for idx in range(audios.size(0)): global_idx = b_idx * B + idx utt_id = os.path.splitext( os.path.split(self.data_loader[split].dataset. dataset[global_idx][0])[1])[0] embed = embedding[ idx, :audio_lens[idx]].cpu().detach().numpy() X_dict[split].extend(embed.tolist()) utt_ids_dict[split].extend([utt_id] * embed.shape[0]) segment_dict[split][utt_id] = self.data_loader[ split].dataset.dataset[global_idx][3] X_dict[split] = np.asarray(X_dict[split]) X = np.concatenate([X_dict[split] for split in self.data_loader]) begin_time = time.time() clusterer = KMeans(n_clusters=n_clusters).fit(X) print(f'KMeans take {time.time()-begin_time} s to finish') np.save(self.ckpt_dir.joinpath('cluster_means.npy'), clusterer.cluster_centers_) ys = clusterer.predict(X_dict['test']) utt_ids = utt_ids_dict['test'] segments = segment_dict['test'] filename = self.ckpt_dir.joinpath(out_prefix + '.txt') out_f = open(filename, 'w') for utt_id, group in groupby(list(zip(utt_ids, ys)), lambda x: x[0]): y = torch.LongTensor([g[1] for g in group]) y_unseg = testset.unsegment( y, segments[utt_id]).long().cpu().detach().numpy().tolist() y_str = ','.join([str(l) for l in y_unseg]) out_f.write(f'{utt_id} {y_str}\n') out_f.close() gold_path = os.path.join( os.path.join(testset.data_path, f'{testset.splits[0]}')) token_f1, token_prec, token_recall = compute_token_f1( filename, gold_path, self.ckpt_dir.joinpath(f'confusion.png'), ) info = f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}\n' save_path = self.ckpt_dir.joinpath( f'results_file_{self.config.seed}.txt') with open(save_path, 'a') as file: file.write(info) if self.history['token_f1'] < token_f1: self.history['token_f1'] = token_f1
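# --- Illustrative sketch ----------------------------------------------------
# The `cluster_means.npy` file saved above can be reused to quantize new
# utterances without refitting k-means: assign each frame to its nearest
# centroid.  The helper itself is illustrative (only the file name follows
# the save above).
@staticmethod
def _assign_to_saved_centroids_sketch(frames, centroid_path='cluster_means.npy'):
    """frames: (T, D) array; returns (T,) array of cluster indices."""
    import numpy as np
    centroids = np.load(centroid_path)                               # (K, D)
    dists = ((frames[:, None, :] - centroids[None]) ** 2).sum(-1)    # (T, K)
    return dists.argmin(-1)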
def test(self, save_embedding=False, out_prefix='predictions'): self.set_mode('eval') testset = self.data_loader['test'].dataset preprocessor = testset.preprocessor total_loss = 0. total_num = 0. gold_labels = [] pred_labels = [] if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir(): os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean')) gold_path = os.path.join(os.path.join(testset.data_path, f'{testset.splits[0]}')) out_word_file = os.path.join( self.ckpt_dir, f'{out_prefix}_word.{self.global_epoch}.readable' ) out_phone_readable_file = os.path.join( self.ckpt_dir, f'{out_prefix}_phoneme.{self.global_epoch}.readable' ) out_phone_file = os.path.join( self.ckpt_dir, f'{out_prefix}_phoneme.{self.global_epoch}.txt' ) word_f = open(out_word_file, 'w') word_f.write('Image ID\tGold label\tPredicted label\n') phone_readable_f = open(out_phone_readable_file, 'w') phone_f = open(out_phone_file, 'w') gold_word_labels = [] gold_phone_labels = [] pred_word_labels = [] pred_phone_labels = [] with torch.no_grad(): B = 0 for b_idx, (audios, phoneme_labels, word_labels,\ audio_masks, phone_masks, word_masks)\ in enumerate(self.data_loader['test']): if b_idx > 2 and self.debug: break if b_idx == 0: B = audios.size(0) audios = cuda(audios, self.cuda) if self.audio_feature == 'wav2vec2': audios = self.audio_feature_net.feature_extractor(audios) phoneme_labels = cuda(phoneme_labels, self.cuda) word_labels = cuda(word_labels, self.cuda) audio_masks = cuda(audio_masks, self.cuda) phone_masks = cuda(phone_masks, self.cuda) word_masks = cuda(word_masks, self.cuda).sum(-2) # Use averaged frame embedding by default if self.audio_net.ds_ratio > 1: audio_masks = audio_masks[:, ::self.audio_net.ds_ratio] word_masks = word_masks[:, :, ::self.audio_net.ds_ratio] audio_lens = audio_masks.sum(-1).long() sent_lens = phone_masks.sum(-1).long() word_lens = (word_labels >= 0).long().sum(-1) if self.model_type == 'blstm': out_logits, embedding = self.audio_net(audios, return_feat=True) word_logits = out_logits[:, :, :self.n_visual_class] phone_logits = out_logits[:, :, self.n_visual_class:] else: gumbel_logits, out_logits, encoding, embedding = self.audio_net( audios, masks=audio_masks, return_feat=True) phone_logits = gumbel_logits word_logits = out_logits if self.model_type == 'vq-mlp': word_logits = out_logits[:, :, :self.n_visual_class] word_logits = torch.matmul(word_masks, word_logits) word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), word_labels, ignore_index=-100)\ .div(math.log(2)) phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\ .permute(1, 0, 2), phoneme_labels, audio_lens, sent_lens) info_loss = (F.softmax(phone_logits, dim=-1)\ * F.log_softmax(phone_logits, dim=-1) ).sum().div(sent_lens.sum() * math.log(2)) total_loss += word_loss + phone_loss + self.beta * info_loss total_num += 1. 
for idx in range(audios.size(0)): global_idx = b_idx * B + idx audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0] if word_lens[idx] > 0: gold_words = word_labels[idx, :word_lens[idx]] pred_words = word_logits[idx, :word_lens[idx]].max(-1)[1] gold_words = gold_words.cpu().detach().numpy().tolist() pred_words = pred_words.cpu().detach().numpy().tolist() gold_word_labels.append(gold_words) pred_word_labels.append(pred_words) gold_word_names = ','.join(preprocessor.to_word_text( gold_words) ) pred_word_names = ','.join(preprocessor.to_word_text( pred_words) ) word_f.write(f'{audio_id}\t{gold_word_names}\t{pred_word_names}\n') feat_fn = self.ckpt_dir.joinpath(f'outputs/phonetic/dev-clean/{audio_id}.txt') if save_embedding: np.savetxt(feat_fn, embedding[idx, :audio_lens[idx]][::2].cpu().detach().numpy()) # XXX gold_phone_label = phoneme_labels[idx, :sent_lens[idx]] pred_phone_label = phone_logits[idx, :audio_lens[idx]].max(-1)[1] gold_phone_names = ','.join(preprocessor.to_text(gold_phone_label)) pred_phone_names = ','.join(preprocessor.tokens_to_text(pred_phone_label)) phone_readable_f.write(f'Utterance id: {audio_id}\n' f'Gold transcript: {gold_phone_names}\n' f'Pred transcript: {pred_phone_names}\n\n') gold_phone_label = gold_phone_label.cpu().detach().numpy().tolist() if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1: us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio pred_phone_label = pred_phone_label.unsqueeze(-1)\ .expand(-1, us_ratio).flatten() pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist() gold_phone_labels.append(gold_phone_label) pred_phone_labels.append(pred_phone_label) pred_phone_label = preprocessor.to_index(preprocessor.to_text(pred_phone_label)) pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist() pred_phone_names = ','.join([str(p) for p in pred_phone_label]) phone_f.write(f'{audio_id} {pred_phone_names}\n') word_f.close() phone_f.close() avg_loss = total_loss / total_num acc = compute_accuracy(gold_word_labels, pred_word_labels) dist, n_tokens = compute_edit_distance(pred_phone_labels, gold_phone_labels, preprocessor) pter = float(dist) / float(n_tokens) print('[TEST RESULT]') print('Epoch {}\tLoss: {:.4f}\tWord Acc.: {:.3f}\tPTER: {:.3f}'\ .format(self.global_epoch, avg_loss, acc, pter)) token_f1, token_prec, token_recall = compute_token_f1( out_phone_file, gold_path, os.path.join( self.ckpt_dir, f'confusion.{self.global_epoch}.png' ) ) if self.history['acc'] < acc: self.history['acc'] = acc self.history['loss'] = avg_loss self.history['epoch'] = self.global_epoch self.history['iter'] = self.global_iter self.history['token_f1'] = token_f1 self.save_checkpoint('best_acc.tar') self.set_mode('train')
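# --- Illustrative sketch ----------------------------------------------------
# F.ctc_loss expects log-probabilities shaped (T, B, C) plus per-utterance
# input and target lengths, which is why the phone logits are permuted before
# the call above.  A minimal standalone call with the same conventions
# (sizes are illustrative):
@staticmethod
def _ctc_loss_sketch():
    import torch
    import torch.nn.functional as F
    B, T, C, S = 2, 50, 49, 12
    phone_logits = torch.randn(B, T, C)
    targets = torch.randint(1, C, (B, S))            # index 0 is the CTC blank
    input_lens = torch.full((B,), T, dtype=torch.long)
    target_lens = torch.full((B,), S, dtype=torch.long)
    log_probs = F.log_softmax(phone_logits, dim=-1).permute(1, 0, 2)  # (T, B, C)
    return F.ctc_loss(log_probs, targets, input_lens, target_lens)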
def cluster(self, n_clusters=44, out_prefix='quantized_outputs'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    preprocessor = testset.preprocessor
    save_path = self.ckpt_dir.joinpath(f'results_file_{self.config.seed}.txt')
    us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio

    with torch.no_grad():
        B = 0
        utt_ids_dict = dict()
        X_dict = dict()
        segment_dict = dict()
        for split in self.data_loader:
            X_dict[split] = []
            utt_ids_dict[split] = []
            segment_dict[split] = dict()
            for b_idx, batch in enumerate(self.data_loader[split]):
                audios = batch[0]
                audio_masks = batch[3]
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)
                x = cuda(audios, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                if self.audio_feature == 'wav2vec2':
                    B = x.size(0)
                    T = x.size(1)
                    x = self.audio_feature_net.feature_extractor(
                        x.view(B * T, -1)).permute(0, 2, 1)
                    x = x.view(B, T, x.size(-2), x.size(-1))
                    x = (x * audio_masks.unsqueeze(-1)).sum(-2)\
                        / (audio_masks.sum(-1, keepdim=True)
                           + torch.tensor(1e-10, device=x.device))
                    audio_masks = torch.where(audio_masks.sum(-1) > 0,
                                              torch.tensor(1, device=x.device),
                                              torch.tensor(0, device=x.device))

                # (batch size, max segment num)
                if audio_masks.dim() == 3:
                    segment_masks = torch.where(
                        audio_masks.sum(-1) > 0,
                        torch.tensor(1., device=audio_masks.device),
                        torch.tensor(0., device=audio_masks.device))
                else:
                    segment_masks = audio_masks.clone()
                audio_lens = segment_masks.sum(-1).long()
                if self.audio_net.ds_ratio > 1:
                    audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]

                # (batch size, max segment num, n visual class)
                word_logits, quantized, phone_loss = self.audio_net(x, masks=audio_masks)
                word_probs = F.softmax(word_logits, dim=-1)
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    utt_id = os.path.splitext(
                        os.path.split(self.data_loader[split].dataset
                                      .dataset[global_idx][0])[1])[0]
                    embed = word_probs[idx, :audio_lens[idx]].cpu().detach().numpy()
                    X_dict[split].extend(embed.tolist())
                    utt_ids_dict[split].extend([utt_id] * embed.shape[0])
                    segment_dict[split][utt_id] =\
                        self.data_loader[split].dataset.dataset[global_idx][3]
            X_dict[split] = np.asarray(X_dict[split])

    X = np.concatenate([X_dict[split] for split in self.data_loader])
    begin_time = time.time()
    self.clusterer.fit(X)
    info = f'Epoch {self.global_epoch}\tClustering Time: {time.time()-begin_time:.2f} s'
    print(info)
    with open(save_path, 'a') as file:
        file.write(info + '\n')

    ys = self.clusterer.predict(X_dict['test'])
    utt_ids = utt_ids_dict['test']
    segments = segment_dict['test']
    filename = self.ckpt_dir.joinpath(out_prefix + '.txt')
    out_f = open(filename, 'w')
    for utt_id, group in groupby(list(zip(utt_ids, ys)), lambda x: x[0]):
        y = torch.LongTensor([g[1] for g in group])
        y_unseg = testset.unsegment(y, segments[utt_id])\
            .long().cpu().detach().numpy().tolist()
        y_str = ','.join([str(l) for l in y_unseg])
        out_f.write(f'{utt_id} {y_str}\n')
    out_f.close()
    np.save(self.ckpt_dir.joinpath('cluster_means.npy'), self.clusterer.cluster_centers_)

    gold_path = os.path.join(testset.data_path, testset.splits[0])
    print('[CLUSTERING RESULT]')
    token_f1, token_prec, token_recall = compute_token_f1(
        filename, gold_path,
        self.ckpt_dir.joinpath('confusion.png'),
    )
    info = f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}'
    with open(save_path, 'a') as file:
        file.write(info + '\n')

    if self.history['token_f1'] < token_f1:
        self.history['token_f1'] = token_f1
        self.history['best_cluster_epoch'] = self.global_epoch
        self.history['best_cluster_iter'] = self.global_iter
        self.save_checkpoint(f'best_token_f1_{self.config.seed}.tar')
    self.set_mode('train')
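# --- Illustrative sketch ----------------------------------------------------
# The clustering above fits on features pooled over every split but only
# evaluates the assignments of the 'test' split.  The same pattern in
# isolation, with sklearn's KMeans standing in for self.clusterer
# (an assumption for illustration):
@staticmethod
def _fit_all_predict_test_sketch(X_dict, n_clusters=44):
    """X_dict: {split: (N_split, D) array}; returns test assignments and centers."""
    import numpy as np
    from sklearn.cluster import KMeans
    X_all = np.concatenate([X_dict[split] for split in X_dict])
    clusterer = KMeans(n_clusters=n_clusters).fit(X_all)
    return clusterer.predict(X_dict['test']), clusterer.cluster_centers_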
def test(self, save_embedding=False, out_prefix='predictions'): self.set_mode('eval') testset = self.data_loader['test'].dataset preprocessor = testset.preprocessor total_loss = 0. total_word_loss = 0. total_phone_loss = 0. total_step = 0. pred_word_labels = [] gold_word_labels = [] if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir(): os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean')) gold_phone_file = os.path.join( testset.data_path, f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item') word_readable_f = open( self.ckpt_dir.joinpath( f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w') phone_file = self.ckpt_dir.joinpath( f'{out_prefix}_phoneme.{self.global_epoch}.txt') phone_f = open(phone_file, 'w') with torch.no_grad(): B = 0 for b_idx, batch in enumerate(self.data_loader['test']): audios = batch[0] word_labels = batch[2] audio_masks = batch[3] word_masks = batch[5] if b_idx > 2 and self.debug: break if b_idx == 0: B = audios.size(0) x = cuda(audios, self.cuda) if self.audio_feature == "wav2vec2": x = self.audio_feature_net.feature_extractor(x) word_labels = cuda(word_labels, self.cuda) audio_masks = cuda(audio_masks, self.cuda) # (batch size, max word num, max word len, max audio len) word_masks = cuda(word_masks, self.cuda) # (batch size x max word num, max word len) word_masks_2d = torch.where( word_masks.sum(-1).view(-1, self.max_word_len) > 0, torch.tensor(1, device=x.device), torch.tensor(0, device=x.device)) # (batch size x max word num) word_masks_1d = torch.where( word_masks_2d.sum(-1) > 0, torch.tensor(1, device=x.device), torch.tensor(0, device=x.device)) # images = cuda(images, self.cuda) if self.audio_net.ds_ratio > 1: audio_masks = audio_masks[:, ::self.audio_net.ds_ratio] word_masks = word_masks[:, :, ::self.audio_net.ds_ratio] audio_lens = audio_masks.sum(-1).long() # (batch size, max word num) word_lens = word_masks.sum(dim=(-1, -2)).long() # (batch size,) word_nums = torch.where(word_lens > 0, torch.tensor(1, device=x.device), torch.tensor(0, device=x.device)).sum(-1) sent_phone_logits,\ word_logits,\ _,\ embedding = self.audio_net(x, masks=audio_masks, return_feat=True) # (batch size, max word num, max word len, n phone class) phone_logits = torch.matmul(word_masks, sent_phone_logits.unsqueeze(1)) # (batch size x max word num, max word len, n phone class) phone_logits = phone_logits.view(-1, self.max_word_len, self.n_phone_class) phone_probs = F.softmax(phone_logits, dim=-1) # (batch size x max word num, n phone class) pm_phone_probs = self.audio_net.reverse_forward( word_labels.flatten(), ignore_index=self.ignore_index) # (batch size x max word num, max word len, n phone class) pm_phone_probs = pm_phone_probs\ .unsqueeze(1).expand(-1, self.max_word_len, -1) # (batch size, max word num, max word len, n word class) word_logits = torch.matmul(word_masks, word_logits.unsqueeze(1)) # (batch size, max word num, n word class) word_logits = ( word_logits * word_masks.sum(-1, keepdim=True) ).sum( -2 ) # / (word_lens.unsqueeze(-1) + EPS) # Average over frames # (batch size x max word num, n word class) word_logits_2d = word_logits.view(-1, self.n_visual_class) word_probs = F.softmax(word_logits_2d, dim=-1) if self.loss_type == 'cross_entropy': # phone_loss = -((pm_phone_probs * word_masks_2d.unsqueeze(-1)) *\ # torch.log_softmax(phone_logits, dim=-1)).mean() phone_loss = self.criterion(phone_probs, pm_phone_probs, word_masks_2d) word_loss = F.cross_entropy(word_logits_2d, word_labels.flatten(), ignore_index=self.ignore_index) 
else: phone_loss = self.criterion(phone_probs, pm_phone_probs, word_masks_2d) if self.max_normalize: word_probs = word_probs / word_probs.max( -1, keepdim=True)[0] # (batch size x max word num, n word class) word_labels_onehot = F.one_hot(word_labels, self.n_visual_class) word_labels_onehot = word_masks_1d.unsqueeze( -1) * word_labels_onehot word_loss = self.criterion( word_probs, word_labels_onehot, torch.ones(audio_lens.size(0), device=x.device)) loss = word_loss + phone_loss total_loss += loss.cpu().detach().numpy() total_word_loss += word_loss.cpu().detach().numpy() total_phone_loss += phone_loss.cpu().detach().numpy() total_step += 1. for idx in range(audios.size(0)): global_idx = b_idx * B + idx audio_id = os.path.splitext( os.path.split(testset.dataset[global_idx][0])[1])[0] pred_phone_label = sent_phone_logits[ idx, :audio_lens[idx]].max(-1)[1] pred_phone_label_list = pred_phone_label.cpu().detach( ).numpy().tolist() pred_phone_names = ','.join( [str(p) for p in pred_phone_label_list]) phone_f.write(f'{audio_id} {pred_phone_names}\n') if word_nums[idx] > 0: gold_word_label = word_labels[ idx, :word_nums[idx]].cpu().detach().numpy( ).tolist() pred_word_label = word_logits[ idx, :word_nums[idx]].max( -1)[1].cpu().detach().numpy().tolist() gold_word_labels.extend(gold_word_label) pred_word_labels.extend(pred_word_label) pred_word_names = preprocessor.to_word_text( pred_word_label) gold_word_names = preprocessor.to_word_text( gold_word_label) for word_idx in range(word_nums[idx]): pred_word_name = pred_word_names[word_idx] gold_word_name = gold_word_names[word_idx] word_readable_f.write( f'Utterance id: {audio_id}\n' f'Gold word label: {gold_word_name}\n' f'Pred word label: {pred_word_name}\n\n') phone_f.close() word_readable_f.close() avg_loss = total_loss / total_step avg_word_loss = total_word_loss / total_step avg_phone_loss = total_phone_loss / total_step # Compute word accuracy and word token F1 print('[TEST RESULT]') word_acc = compute_accuracy(gold_word_labels, pred_word_labels) word_prec,\ word_rec,\ word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels), np.asarray(pred_word_labels), average='macro') print( f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\tWord Loss: {avg_word_loss}\tPhone Loss: {avg_phone_loss}\n' f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n' f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}' ) token_f1,\ token_prec,\ token_recall = compute_token_f1(phone_file, gold_phone_file, self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png')) if self.history['word_acc'] < word_acc: self.history['token_f1'] = token_f1 self.history['word_acc'] = word_acc self.history['loss'] = avg_loss self.history['iter'] = self.global_iter self.history['epoch'] = self.global_epoch self.save_checkpoint(f'best_acc_{self.config.seed}.tar') self.set_mode('train')
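# --- Illustrative sketch ----------------------------------------------------
# The word metrics above come from sklearn's precision_recall_fscore_support
# with macro averaging over the visual word classes, while accuracy is
# computed separately.  The same computation on plain label lists (the helper
# and its simple accuracy are illustrative):
@staticmethod
def _word_metrics_sketch(gold_word_labels, pred_word_labels):
    import numpy as np
    from sklearn.metrics import precision_recall_fscore_support
    gold = np.asarray(gold_word_labels)
    pred = np.asarray(pred_word_labels)
    prec, rec, f1, _ = precision_recall_fscore_support(
        gold, pred, average='macro', zero_division=0)
    acc = float(np.mean(gold == pred))
    return acc, prec, rec, f1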