Example #1
    def phone_level_cluster(self, out_prefix='predictions'):
        self.load_checkpoint()
        X_a = np.zeros((self.n_class, self.K))
        norm = np.zeros((self.n_class, 1))
        audio_files = []
        encodings = []
        # Find the centroid of each phone-level cluster
        B = self.data_loader['test'].batch_size
        testset = self.data_loader['test'].dataset
        for b_idx, (audios, _, _, audio_masks,
                    _) in enumerate(self.data_loader['test']):
            if b_idx > 2 and self.debug:
                break
            audios = cuda(audios, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            _, _, encoding, embedding = self.audio_net(audios,
                                                       mask=audio_masks,
                                                       return_feat=True)
            encoding = encoding.permute(0, 2, 1).cpu().detach().numpy()
            embedding = embedding.cpu().detach().numpy()
            X_a += encoding @ embedding
            norm += encoding.sum(axis=-1, keepdims=True)
            audio_files.extend([
                testset.dataset[b_idx * B + i][0]
                for i in range(audios.size(0))
            ])
            encodings.append(encoding.transpose(0, 2, 1))  # back to (batch, frames, units)
        encodings = np.concatenate(encodings)

        X_a /= norm
        X_s = (self.audio_net.bottleneck.weight +
               self.audio_net.bottleneck.bias).cpu().detach().numpy()
        X = np.concatenate([X_a, X_s], axis=1)

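        # Cluster the combined acoustic/bottleneck representations with k-means (50 clusters)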
        kmeans = KMeans(n_clusters=50).fit(X)
        phoneme_labels = kmeans.labels_

        out_file = os.path.join(self.ckpt_dir,
                                f'{out_prefix}_phone_level_clustering.txt')
        out_f = open(out_file, 'w')
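        # Map each frame's predicted unit to its k-means cluster label and write one line per utterance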
        pred_phones = encodings.argmax(-1)
        for idx, (audio_file,
                  encoding) in enumerate(zip(audio_files, encodings)):
            audio_id = os.path.splitext(os.path.split(audio_file)[1])[0]
            pred_phonemes = ','.join(
                [str(phoneme_labels[phn]) for phn in pred_phones[idx]])
            out_f.write(f'{audio_id} {pred_phonemes}\n')
        out_f.close()

        gold_path = os.path.join(testset.data_path, 'test/')
        compute_token_f1(
            out_file, gold_path,
            os.path.join(self.ckpt_dir, 'confusion_phone_level_cluster.png'))
Example #2
  def cluster(self,
              n_clusters=50,
              out_prefix='quantized_outputs'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    temp = self.history['temp']

    us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio 
    with torch.no_grad():
      B = 0
      utt_ids = []
      X = []      
      for b_idx, (audios, phoneme_labels, word_labels,\
                  audio_masks, phone_masks, word_masks)\
                  in enumerate(self.data_loader['test']):
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0:
          B = audios.size(0)

        
        audios = cuda(audios, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        if self.audio_feature == 'wav2vec2':
          x = self.audio_feature_net.feature_extractor(audios)
        else:
          x = audios
        
        audio_lens = audio_masks.sum(-1).long()
        outputs = self.audio_net(x, return_feat=True)
        embedding = outputs[-1]

        for idx in range(audios.size(0)): 
          global_idx = b_idx * B + idx
          utt_id = os.path.splitext(os.path.basename(testset.dataset[global_idx][0]))[0] 
          embed = embedding[idx, :audio_lens[idx]].cpu().detach().numpy()
          X.extend(embed.tolist())
          utt_ids.extend([utt_id]*embed.shape[0])

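      # Pool frame embeddings from all test utterances and cluster them with k-means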
      X = np.asarray(X)
      begin_time = time.time()
      clusterer = KMeans(n_clusters=n_clusters).fit(X) 
      print(f'KMeans took {time.time()-begin_time:.2f} s to finish')
      np.save(self.ckpt_dir.joinpath('cluster_means.npy'), clusterer.cluster_centers_)
      
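      # Predict a cluster id per frame and write the upsampled sequence for each utterance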
      ys = clusterer.predict(X)
      filename = self.ckpt_dir.joinpath(out_prefix+'.txt')
      out_f = open(filename, 'w')
      for utt_id, group in groupby(list(zip(utt_ids, ys)), lambda x:x[0]):
        y = ','.join([str(g[1]) for g in group for _ in range(us_ratio)])
        out_f.write(f'{utt_id} {y}\n') 
      out_f.close()
      gold_path = os.path.join(testset.data_path, testset.splits[0])
      token_f1, token_prec, token_recall = compute_token_f1(
                                             filename,
                                             gold_path,
                                             self.ckpt_dir.joinpath(f'confusion.png'),
                                           )
Example #3
  def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    preprocessor = testset.preprocessor

    total_loss = 0.
    total_step = 0.

    pred_word_labels = []
    pred_word_labels_quantized = []
    gold_word_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
      os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

    gold_phone_file = os.path.join(testset.data_path, f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
    word_readable_f = open(self.ckpt_dir.joinpath(f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w') 
    phone_file = self.ckpt_dir.joinpath(f'{out_prefix}_phoneme.{self.global_epoch}.txt')
    phone_f = open(phone_file, 'w')

    with torch.no_grad():
      B = 0
      for b_idx, batch in enumerate(self.data_loader['test']):        
        audios = batch[0]
        word_labels = batch[2].squeeze(-1)
        audio_masks = batch[3]
        word_masks = batch[5]
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0: 
          B = audios.size(0)
 
        # (batch size, max segment num, feat dim) or (batch size, max segment num, max segment len, feat dim)
        x = cuda(audios, self.cuda)

        # (batch size,)
        word_labels = cuda(word_labels, self.cuda)

        # (batch size, max segment num) or (batch size, max segment num, max segment len)
        audio_masks = cuda(audio_masks, self.cuda)

        if self.audio_feature == "wav2vec2":
          B = x.size(0)
          T = x.size(1) 
          x = self.audio_feature_net.feature_extractor(x.view(B*T, -1)).permute(0, 2, 1)
          x = x.view(B, T, x.size(-2), x.size(-1))
          x = (x * audio_masks.unsqueeze(-1)).sum(-2) / (audio_masks.sum(-1, keepdim=True) + torch.tensor(1e-10, device=x.device))
          audio_masks = torch.where(audio_masks.sum(-1) > 0,
                                    torch.tensor(1, device=x.device),
                                    torch.tensor(0, device=x.device))
          
        # (batch size, max segment num)
        if audio_masks.dim() == 3: 
          segment_masks = torch.where(audio_masks.sum(-1) > 0,
                                      torch.tensor(1., device=audio_masks.device),
                                      torch.tensor(0., device=audio_masks.device))
        else:
          segment_masks = audio_masks.clone()
             
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          segment_masks = segment_masks[:, ::self.audio_net.ds_ratio]

        # (batch size, max segment num, n visual class)
        word_logits, quantized, phone_loss = self.audio_net(x, masks=audio_masks)

        # segment_word_labels = word_labels.unsqueeze(-1)\
        #                                  .expand(-1, self.max_segment_num)
        # segment_word_labels = (segment_word_labels * segment_masks).flatten().long()
        if self.use_logsoftmax:
          segment_word_logits = (F.log_softmax(word_logits, dim=-1)\
                                * segment_masks.unsqueeze(-1)).sum(-2)
        else:
          segment_word_logits = (word_logits\
                                * segment_masks.unsqueeze(-1)).sum(-2)

        word_loss = F.cross_entropy(segment_word_logits,
                               word_labels,
                               ignore_index=self.ignore_index)
        loss = phone_loss + word_loss
        total_loss += loss.cpu().detach().numpy()
        total_step += 1.
        
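        # Decode discrete phone indices, undo segmentation, and write them per utterance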
        _, _, phone_indices = self.audio_net.encode(x, masks=audio_masks)
        for idx in range(audios.size(0)):
          global_idx = b_idx * B + idx
          audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0]
          segments = testset.dataset[global_idx][3]
          pred_phone_label = testset.unsegment(phone_indices[idx], segments).long()

          if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1:
            us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
            pred_phone_label = pred_phone_label.unsqueeze(-1)\
                               .expand(-1, us_ratio).flatten()

          pred_phone_label_list = pred_phone_label.cpu().detach().numpy().tolist()
          pred_phone_names = ','.join([str(p) for p in pred_phone_label_list])
          phone_f.write(f'{audio_id} {pred_phone_names}\n')
          
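          # Collect gold and predicted word labels (classifier and quantizer) and log readable output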
          gold_word_label = word_labels[idx].cpu().detach().numpy().tolist()
          pred_word_label = segment_word_logits[idx].max(-1)[1].cpu().detach().numpy().tolist()
          pred_word_label_quantized = quantized[idx].prod(-2).max(-1)[1].cpu().detach().numpy().tolist()
           
          gold_word_labels.append(gold_word_label)
          pred_word_labels.append(pred_word_label)
          pred_word_labels_quantized.append(pred_word_label_quantized)
          pred_word_name = preprocessor.to_word_text([pred_word_label])[0]
          pred_word_name_quantized = preprocessor.to_word_text([pred_word_label_quantized])[0]
          gold_word_name = preprocessor.to_word_text([gold_word_label])[0]
          word_readable_f.write(f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n'
                                f'Pred word label by quantizer: {pred_word_name_quantized}\n\n') 
      phone_f.close()
      word_readable_f.close()
      avg_loss = total_loss / total_step
      # Compute word accuracy and word token F1
      print('[TEST RESULT]')
      word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
      word_prec,\
      word_rec,\
      word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                   np.asarray(pred_word_labels),
                                                   average='macro')

      word_prec_quantized,\
      word_rec_quantized,\
      word_f1_quantized, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                             np.asarray(pred_word_labels_quantized),
                                                             average='macro') 

      token_f1,\
      token_prec,\
      token_recall = compute_token_f1(phone_file,
                                      gold_phone_file,
                                      self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))
      info = f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\n'\
             f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'\
             f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}\n'\
             f'(By Quantizer) Word Precision: {word_prec_quantized:.3f}\tWord Recall: {word_rec_quantized:.3f}\tWord F1: {word_f1_quantized:.3f}\n'\
             f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}\n'
      print(info) 

      save_path = self.ckpt_dir.joinpath('results_file.txt')
      with open(save_path, 'a') as file:
        file.write(info)

      if self.history['word_acc'] < word_acc:
        self.history['token_result'] = [token_prec, token_recall, token_f1]
        self.history['word_acc'] = word_acc
        self.history['loss'] = avg_loss
        self.history['iter'] = self.global_iter
        self.history['epoch'] = self.global_epoch
        self.save_checkpoint(f'best_acc_{self.config.seed}.tar')
      self.set_mode('train') 
Example #4
    def cluster(self,
                test_loader=None,
                out_prefix='predictions',
                save_embedding=False):
        self.load_checkpoint()
        X = []
        audio_files = []

        if test_loader is None:
            test_loader = self.data_loader['test']
        B = test_loader.batch_size
        testset = test_loader.dataset
        split = testset.splits[0]
        gold_path = os.path.join(testset.data_path, split)

        embed_path = os.path.join(testset.data_path, f'{split}_embeddings')
        if save_embedding and not os.path.exists(embed_path):
            os.makedirs(embed_path)

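        # Run the audio network and collect input+hidden feature concatenations for clustering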
        for b_idx, (audios, _, _, audio_masks, _) in enumerate(test_loader):
            if b_idx > 2 and self.debug:
                break
            audios = cuda(audios, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            _, _, _, embedding = self.audio_net(audios,
                                                masks=audio_masks,
                                                return_feat=True)
            # Concatenate the hidden vector with the input feature
            concat_embedding = torch.cat([audios.permute(0, 2, 1), embedding],
                                         axis=-1)
            if save_embedding:
                for idx in range(audios.size(0)):
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[b_idx * B +
                                                      idx][0])[1])[0]
                    np.savetxt(os.path.join(embed_path, f'{audio_id}.txt'),
                               concat_embedding[idx].cpu().detach().numpy())

            X.append(concat_embedding.cpu().detach().numpy())
            audio_files.extend([
                testset.dataset[b_idx * B + i][0]
                for i in range(audios.size(0))
            ])
        X = np.concatenate(X, axis=0)

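        # Flatten (utterances, frames, feats) to 2-D and run k-means over all frames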
        shape = X.shape
        kmeans = KMeans(n_clusters=50).fit(X.reshape(shape[0] * shape[1], -1))
        np.save(os.path.join(self.ckpt_dir, f'kmeans_centroids.npy'),
                kmeans.cluster_centers_)
        encodings = kmeans.labels_.reshape(shape[0], shape[1])

        out_file = os.path.join(self.ckpt_dir, f'{out_prefix}_clustering.txt')
        out_f = open(out_file, 'w')
        for audio_file, encoding in zip(audio_files, encodings):
            audio_id = os.path.splitext(os.path.split(audio_file)[1])[0]
            pred_phonemes = ','.join([str(phn) for phn in encoding])
            out_f.write(f'{audio_id} {pred_phonemes}\n')
        out_f.close()

        compute_token_f1(out_file, gold_path,
                         os.path.join(self.ckpt_dir, 'confusion_cluster.png'))
Example #5
    def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset

        total_loss = 0
        izy_bound = 0
        izx_bound = 0
        total_num = 0
        seq_list = []
        pred_labels = []
        gold_labels = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_path = os.path.join(testset.data_path, 'test/')
        out_word_file = os.path.join(
            self.ckpt_dir, f'{out_prefix}_word.{self.global_epoch}.readable')
        out_phone_file = os.path.join(
            self.ckpt_dir, f'{out_prefix}_phoneme.{self.global_epoch}.txt')

        word_f = open(out_word_file, 'w')
        word_f.write('Image ID\tGold label\tPredicted label\n')
        phone_f = open(out_phone_file, 'w')
        with torch.no_grad():
            B = 0
            for b_idx, (audios, images, labels, audio_masks,
                        image_masks) in enumerate(self.data_loader['test']):
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)
                audios = cuda(audios, self.cuda)
                labels = cuda(labels, self.cuda)
                images = cuda(images, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                image_masks = cuda(image_masks.unsqueeze(-1), self.cuda)
                image_size = images.size()
                '''
          if images.ndim in [2, 4]:
            image_logit, _ = self.image_net(images, return_score=self.image_feature)
          else:
            image_logit_flat, _ = self.image_net(
                                      images.view(-1, *image_size[2:]),
                                      return_score=True
                                      )
            image_logit = image_logit_flat.view(*image_size[:2], -1)
          
          if self.image_feature == 'label':
            if image_logit.ndim == 2:
              y = image_logit.max(-1)[1]
            else:
              y = F.one_hot(image_logit.max(-1)[1], self.n_class)
              y = (y * image_masks).max(1)[0]
          elif self.image_feature == 'multi_label':
            if image_logit.ndim == 2:
              y = (image_logit > 0).float()          
            else:
              y = ((image_logit > 0).float() * image_masks).max(1)[0]
          '''
                y = labels
                in_logits, logits, encoding, embedding = self.audio_net(
                    audios, masks=audio_masks, return_feat=True)

                # Word prediction
                in_logit = in_logits.sum(0)
                logit = (logits * audio_masks.unsqueeze(-1)).sum(dim=1)
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]
                    golds = [y[idx].cpu().detach().numpy()]
                    preds = [logit[idx].max(-1)[1].cpu().detach().numpy()]
                    pred_phones = encoding[idx].max(-1)[1]
                    if self.use_segment:
                        pred_phones = testset.unsegment(
                            pred_phones,
                            testset.dataset[global_idx][3]).long()
                    pred_phones = pred_phones.cpu().detach().numpy().tolist()
                    gold_names = ','.join([self.class_names[c] for c in golds])
                    pred_names = ','.join([self.class_names[c] for c in preds])
                    pred_phones_str = ','.join(
                        [str(phn) for phn in pred_phones])
                    word_f.write(f'{audio_id}\t{gold_names}\t{pred_names}\n')
                    phone_f.write(f'{audio_id} {pred_phones_str}\n')

                pred_labels.append(
                    F.one_hot(logit.max(-1)[1], self.n_class).cpu())
                gold_labels.append(F.one_hot(y, self.n_class).cpu())
                cur_class_loss = F.cross_entropy(logit, y).div(math.log(2))
                cur_info_loss = (F.softmax(in_logit, dim=-1)\
                                  * F.log_softmax(in_logit, dim=-1)
                                ).sum(1).mean().div(math.log(2))
                cur_ib_loss = cur_class_loss + self.beta * cur_info_loss
                izy_bound = izy_bound + y.size(0) * math.log(
                    self.n_class, 2) - cur_class_loss
                izx_bound = izx_bound + cur_info_loss

                total_loss += cur_ib_loss.item()
                total_num += audios.size(0)
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    example_id = testset.dataset[global_idx][0].split('/')[-1]
                    text = testset.dataset[global_idx][1]
                    feat_id = example_id
                    units = encoding[idx].max(-1)[1]

                    if save_embedding:
                        feat_fn = self.ckpt_dir.joinpath(
                            f'outputs/phonetic/dev-clean/{feat_id}')
                        np.savetxt(feat_fn, audios[idx].cpu().detach().numpy())
                        seq_list.append((feat_id, feat_fn))
        word_f.close()
        phone_f.close()

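        # Compute precision/recall/F1 over the flattened one-hot labels, plus per-class F1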
        pred_labels = torch.cat(pred_labels).detach().numpy()
        gold_labels = torch.cat(gold_labels).detach().numpy()
        ps, rs, f1s, _ = precision_recall_fscore_support(
            gold_labels.flatten(), pred_labels.flatten())
        p, r, f1 = ps[1], rs[1], f1s[1]

        class_f1s = np.zeros(self.n_class)
        for c in range(self.n_class):
            _, _, class_f1, _ = precision_recall_fscore_support(
                gold_labels[:, c], pred_labels[:, c])
            class_f1s[c] = class_f1[-1]
        izy_bound /= total_num
        izx_bound /= total_num

        avg_loss = total_loss / total_num
        print('[TEST RESULT]')
        print('Epoch {}\tLoss: {:.2f}\tPrecision: {:.2f}\tRecall: {:.2f}\tF1: {:.2f}'\
              .format(self.global_epoch, avg_loss, p, r, f1))
        token_f1, token_prec, token_recall = compute_token_f1(
            out_phone_file,
            gold_path,
            os.path.join(self.ckpt_dir, f'confusion.{self.global_epoch}.png'),
        )
        if self.history['f1'] < f1:
            self.history['f1'] = f1
            self.history['loss'] = avg_loss
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            self.history['token_f1'] = token_f1
            self.save_checkpoint('best_acc.tar')
        self.set_mode('train')
    def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset
        preprocessor = testset.preprocessor

        total_loss = 0.
        total_step = 0.

        gold_word_labels = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_phone_file = os.path.join(
            testset.data_path,
            f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
        gold_visual_phone_file = os.path.join(
            testset.data_path,
            f'{testset.splits[0]}/{testset.splits[0]}_visual.item')

        phone_file = self.ckpt_dir.joinpath(
            f'{out_prefix}_phoneme.{self.global_epoch}.txt')
        visual_phone_file = self.ckpt_dir.joinpath(
            f'{out_prefix}_visual_phoneme.{self.global_epoch}.txt')
        phone_f = open(phone_file, 'w')
        visual_phone_f = open(visual_phone_file, 'w')
        phone_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_phoneme.{self.global_epoch}.readable'), 'w')
        visual_phone_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_visual_phoneme.{self.global_epoch}.readable'),
            'w')

        with torch.no_grad():
            B = 0
            for b_idx, (audios, phoneme_labels, word_labels,\
                        audio_masks, phone_masks, word_masks)\
                        in enumerate(self.data_loader['test']):
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)

                audios = cuda(audios, self.cuda)
                if self.audio_feature == 'wav2vec2':
                    x = self.audio_feature_net.feature_extractor(audios)
                else:
                    x = audios

                word_labels = cuda(word_labels, self.cuda)
                phoneme_labels = cuda(phoneme_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                phone_masks = cuda(phone_masks, self.cuda)
                word_masks = cuda(word_masks, self.cuda)

                audio_lens = audio_masks.sum(-1).long()
                sent_lens = phone_masks.sum(-1).long()
                word_lens = word_masks.sum(dim=(-2, -1)).long()
                word_num = (word_labels >= 0).long().sum(-1)
                cluster_logits, embedding = self.audio_net(x, return_feat=True)
                phoneme_labels_aligned = self.align_net(
                    F.one_hot(phoneme_labels * phone_masks.long(),
                              self.n_phone_class), phone_masks, audio_masks)

                cluster_probs = F.softmax(cluster_logits, dim=-1)\
                                .view(-1, self.max_feat_len, self.n_phone_class)
                if self.max_normalize:
                    cluster_probs = cluster_probs / cluster_probs.max(
                        -1, keepdim=True)[0]

                phone_label_mask = F.one_hot(
                    phoneme_labels * phone_masks.long(),
                    self.n_phone_class).sum(1, keepdim=True)
                phone_label_mask = (phone_label_mask > 0).long()

                if self.loss_type == 'binary_cross_entropy':
                    loss = self.weight_phone_loss * self.criterion(
                        cluster_probs,
                        phoneme_labels_aligned * audio_masks.unsqueeze(-1))
                else:
                    loss = self.weight_phone_loss * self.criterion(
                        cluster_probs, phoneme_labels_aligned, audio_masks)

                word_cluster_logits = torch.matmul(word_masks,
                                                   cluster_logits.unsqueeze(1))
                word_cluster_probs = F.softmax(word_cluster_logits, dim=-1)\
                                     .view(-1, self.max_word_len, self.n_phone_class)
                if self.max_normalize:
                    word_cluster_probs = word_cluster_probs / word_cluster_probs.max(
                        -1, keepdim=True)[0]

                word_phone_probs = self.phone_net(word_labels.flatten())
                word_phone_probs = word_phone_probs\
                                   .unsqueeze(1).expand(-1, self.max_word_len, -1)
                if self.max_normalize:
                    word_phone_probs = word_phone_probs / (
                        word_phone_probs.max(-1, keepdim=True)[0] + EPS)

                if self.loss_type == 'binary_cross_entropy':
                    loss = loss + self.weight_word_loss * self.criterion(
                        word_cluster_probs,
                        word_phone_probs *
                        word_masks.sum(-1).view(-1, self.max_word_len, 1))
                else:
                    loss = loss + self.weight_word_loss * self.criterion(
                        word_cluster_probs, word_phone_probs,
                        word_masks.sum(-1).view(-1, self.max_word_len))
                total_loss += loss.cpu().detach().numpy()
                total_step += 1.

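                # Write readable and index-format phone (and visual-phone) predictions for each utterance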
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]
                    gold_phone_label = phoneme_labels[idx, :sent_lens[idx]]
                    pred_phone_label = cluster_logits[
                        idx, :audio_lens[idx]].max(-1)[1]
                    gold_phone_names = ','.join(
                        preprocessor.to_text(gold_phone_label))
                    pred_phone_names = ','.join(
                        preprocessor.tokens_to_text(pred_phone_label))
                    phone_readable_f.write(
                        f'Utterance id: {audio_id}\n'
                        f'Gold transcript: {gold_phone_names}\n'
                        f'Pred transcript: {pred_phone_names}\n\n')
                    us_ratio = int(
                        self.hop_len_ms / 10) * self.audio_net.ds_ratio
                    if us_ratio > 1:
                        pred_phone_label = pred_phone_label.unsqueeze(-1)\
                                           .expand(-1, us_ratio).flatten()
                    pred_phone_label_list = pred_phone_label.cpu().detach(
                    ).numpy().tolist()
                    pred_phone_names = ','.join(
                        [str(p) for p in pred_phone_label_list])
                    phone_f.write(f'{audio_id} {pred_phone_names}\n')

                    if word_num[idx] > 0:
                        pred_phone_word_masked_label = pred_phone_label\
                                                       * word_masks[idx, :, :, :audio_lens[idx]].long().sum(dim=(0, 1))
                        pred_phone_word_masked_label_list = pred_phone_word_masked_label.detach(
                        ).cpu().numpy().tolist()
                        pred_phone_word_masked_names = ','.join([
                            str(p) for p in pred_phone_word_masked_label_list
                        ])
                        visual_phone_f.write(
                            f'{audio_id} {pred_phone_word_masked_names}\n')

                        gold_word_label = word_labels[idx, :word_num[idx]].cpu(
                        ).detach().numpy().tolist()
                        gold_word_names = preprocessor.to_word_text(
                            gold_word_label)
                        pred_visual_phone_label = word_cluster_logits[
                            idx, :word_num[idx]].max(-1)[1]
                        for word_idx in range(word_num[idx]):
                            gold_word_name = gold_word_names[word_idx]
                            pred_label = pred_visual_phone_label[
                                word_idx, :word_lens[idx, word_idx]]
                            pred_label = pred_label.cpu().detach().numpy(
                            ).tolist()
                            pred_visual_phone_names = ','.join(
                                preprocessor.tokens_to_text(pred_label))
                            visual_phone_readable_f.write(
                                f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred transcript: {pred_visual_phone_names}\n\n'
                            )
            phone_f.close()
            visual_phone_f.close()
            phone_readable_f.close()
            visual_phone_readable_f.close()

            avg_loss = total_loss / total_step

            # Token F1
            token_f1,\
            token_prec,\
            token_recall = compute_token_f1(phone_file,
                                            gold_phone_file,
                                            self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))

            visual_token_f1,\
            visual_token_prec,\
            visual_token_recall = compute_token_f1(visual_phone_file,
                                                   gold_visual_phone_file,
                                                   self.ckpt_dir.joinpath(f'visual_confusion.{self.global_epoch}.png'))
            print('[TEST RESULT]')
            print(
                f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\tToken F1: {token_f1:.3f}\tVisual Token F1: {visual_token_f1:.3f}'
            )

            if self.history['visual_token_f1'] < visual_token_f1:
                self.history['token_f1'] = token_f1
                self.history['visual_token_f1'] = visual_token_f1
                self.history['loss'] = avg_loss
                self.history['iter'] = self.global_iter
                self.history['epoch'] = self.global_epoch
                self.save_checkpoint('best_acc.tar')
            self.set_mode('train')
    def cluster(self, n_clusters=44, out_prefix='quantized_outputs'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset

        us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
        with torch.no_grad():
            B = 0
            utt_ids_dict = dict()
            X_dict = dict()
            segment_dict = dict()
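            # Collect frame features, utterance ids and segment info for every split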
            for split in self.data_loader:
                X_dict[split] = []
                utt_ids_dict[split] = []
                segment_dict[split] = dict()
                for b_idx, batch in enumerate(self.data_loader[split]):
                    audios = batch[0]
                    audio_masks = batch[3]
                    audios = cuda(audios, self.cuda)
                    audio_masks = cuda(audio_masks, self.cuda)
                    audio_lens = audio_masks.sum(-1)
                    if b_idx > 2 and self.debug:
                        break
                    if b_idx == 0:
                        B = audios.size(0)

                    if self.audio_feature == 'wav2vec2':
                        x = self.audio_feature_net.feature_extractor(audios)
                    else:
                        x = audios

                    audio_lens = audio_masks.sum(-1).long()
                    # XXX _, embedding = self.audio_net(x,
                    #                               masks=audio_masks,
                    #                               return_feat=True)
                    embedding = x

                    for idx in range(audios.size(0)):
                        global_idx = b_idx * B + idx
                        utt_id = os.path.splitext(
                            os.path.split(self.data_loader[split].dataset.
                                          dataset[global_idx][0])[1])[0]

                        embed = embedding[
                            idx, :audio_lens[idx]].cpu().detach().numpy()
                        X_dict[split].extend(embed.tolist())
                        utt_ids_dict[split].extend([utt_id] * embed.shape[0])
                        segment_dict[split][utt_id] = self.data_loader[
                            split].dataset.dataset[global_idx][3]
                X_dict[split] = np.asarray(X_dict[split])

            X = np.concatenate([X_dict[split] for split in self.data_loader])
            begin_time = time.time()
            clusterer = KMeans(n_clusters=n_clusters).fit(X)
            print(f'KMeans took {time.time()-begin_time:.2f} s to finish')
            np.save(self.ckpt_dir.joinpath('cluster_means.npy'),
                    clusterer.cluster_centers_)

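            # Assign test frames to clusters, undo segment pooling, and evaluate token F1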
            ys = clusterer.predict(X_dict['test'])
            utt_ids = utt_ids_dict['test']
            segments = segment_dict['test']
            filename = self.ckpt_dir.joinpath(out_prefix + '.txt')
            out_f = open(filename, 'w')
            for utt_id, group in groupby(list(zip(utt_ids, ys)),
                                         lambda x: x[0]):
                y = torch.LongTensor([g[1] for g in group])
                y_unseg = testset.unsegment(
                    y,
                    segments[utt_id]).long().cpu().detach().numpy().tolist()
                y_str = ','.join([str(l) for l in y_unseg])
                out_f.write(f'{utt_id} {y_str}\n')
            out_f.close()
            gold_path = os.path.join(testset.data_path, testset.splits[0])
            token_f1, token_prec, token_recall = compute_token_f1(
                filename,
                gold_path,
                self.ckpt_dir.joinpath(f'confusion.png'),
            )
            info = f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}\n'
            save_path = self.ckpt_dir.joinpath(
                f'results_file_{self.config.seed}.txt')
            with open(save_path, 'a') as file:
                file.write(info)

            if self.history['token_f1'] < token_f1:
                self.history['token_f1'] = token_f1
Example #8
  def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset 
    preprocessor = testset.preprocessor
    
    total_loss = 0.
    total_num = 0.
    gold_labels = []
    pred_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
      os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

    gold_path = os.path.join(testset.data_path, testset.splits[0])
    out_word_file = os.path.join(
                      self.ckpt_dir,
                      f'{out_prefix}_word.{self.global_epoch}.readable'
                    )
    out_phone_readable_file = os.path.join(
                      self.ckpt_dir,
                      f'{out_prefix}_phoneme.{self.global_epoch}.readable'
                    )
    out_phone_file = os.path.join(
                       self.ckpt_dir,
                       f'{out_prefix}_phoneme.{self.global_epoch}.txt'
                     )

    word_f = open(out_word_file, 'w')
    word_f.write('Image ID\tGold label\tPredicted label\n')
    phone_readable_f = open(out_phone_readable_file, 'w')
    phone_f = open(out_phone_file, 'w')
    
    gold_word_labels = []
    gold_phone_labels = []
    pred_word_labels = []
    pred_phone_labels = []
     
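    # Accumulate word, CTC phone, and information losses, and write per-utterance predictions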
    with torch.no_grad():
      B = 0
      for b_idx, (audios, phoneme_labels, word_labels,\
                  audio_masks, phone_masks, word_masks)\
                  in enumerate(self.data_loader['test']):
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0:
          B = audios.size(0)
          
        audios = cuda(audios, self.cuda)
        if self.audio_feature == 'wav2vec2':
          audios = self.audio_feature_net.feature_extractor(audios)
        phoneme_labels = cuda(phoneme_labels, self.cuda)
        word_labels = cuda(word_labels, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        phone_masks = cuda(phone_masks, self.cuda)
        word_masks = cuda(word_masks, self.cuda).sum(-2) # Use averaged frame embedding by default
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]
        
        audio_lens = audio_masks.sum(-1).long()
        sent_lens = phone_masks.sum(-1).long()
        word_lens = (word_labels >= 0).long().sum(-1)

        if self.model_type == 'blstm':
          out_logits, embedding = self.audio_net(audios, return_feat=True)
          word_logits = out_logits[:, :, :self.n_visual_class]
          phone_logits = out_logits[:, :, self.n_visual_class:]
        else:
          gumbel_logits, out_logits, encoding, embedding = self.audio_net(
            audios, masks=audio_masks,
            return_feat=True)
          phone_logits = gumbel_logits
          word_logits = out_logits

        if self.model_type == 'vq-mlp':
          word_logits = out_logits[:, :, :self.n_visual_class]

        word_logits = torch.matmul(word_masks, word_logits)
        word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), 
                                    word_labels,
                                    ignore_index=-100)\
                                    .div(math.log(2)) 
        phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                  .permute(1, 0, 2),
                                phoneme_labels,
                                audio_lens,
                                sent_lens)
        info_loss = (F.softmax(phone_logits, dim=-1)\
                      * F.log_softmax(phone_logits, dim=-1)
                    ).sum().div(sent_lens.sum() * math.log(2))
        total_loss += (word_loss + phone_loss + self.beta * info_loss).item()
        total_num += 1. 

        for idx in range(audios.size(0)):
          global_idx = b_idx * B + idx
          audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0]
          if word_lens[idx] > 0:
            gold_words = word_labels[idx, :word_lens[idx]]
            pred_words = word_logits[idx, :word_lens[idx]].max(-1)[1]
            gold_words = gold_words.cpu().detach().numpy().tolist()
            pred_words = pred_words.cpu().detach().numpy().tolist()
            gold_word_labels.append(gold_words)
            pred_word_labels.append(pred_words)

            gold_word_names = ','.join(preprocessor.to_word_text(
                                    gold_words)
                                  )
            pred_word_names = ','.join(preprocessor.to_word_text(
                                    pred_words)
                                  )
            word_f.write(f'{audio_id}\t{gold_word_names}\t{pred_word_names}\n')

          feat_fn = self.ckpt_dir.joinpath(f'outputs/phonetic/dev-clean/{audio_id}.txt')
          if save_embedding:
            np.savetxt(feat_fn, embedding[idx, :audio_lens[idx]][::2].cpu().detach().numpy()) # XXX
           
          gold_phone_label = phoneme_labels[idx, :sent_lens[idx]]
          pred_phone_label = phone_logits[idx, :audio_lens[idx]].max(-1)[1]
          gold_phone_names = ','.join(preprocessor.to_text(gold_phone_label))
          pred_phone_names = ','.join(preprocessor.tokens_to_text(pred_phone_label))
          phone_readable_f.write(f'Utterance id: {audio_id}\n'
                                 f'Gold transcript: {gold_phone_names}\n'
                                 f'Pred transcript: {pred_phone_names}\n\n')

          gold_phone_label = gold_phone_label.cpu().detach().numpy().tolist()
          if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1:
            us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
            pred_phone_label = pred_phone_label.unsqueeze(-1)\
                                 .expand(-1, us_ratio).flatten()
          pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist()
          gold_phone_labels.append(gold_phone_label)
          pred_phone_labels.append(pred_phone_label) 
          
          pred_phone_label = preprocessor.to_index(preprocessor.to_text(pred_phone_label))
          pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist()
          pred_phone_names = ','.join([str(p) for p in pred_phone_label])
          phone_f.write(f'{audio_id} {pred_phone_names}\n')  
   
    word_f.close()
    phone_f.close()              
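    # Report average loss, word accuracy, phone token error rate (PTER), and token F1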
   
    avg_loss = total_loss / total_num 
    acc = compute_accuracy(gold_word_labels, pred_word_labels)
    dist, n_tokens = compute_edit_distance(pred_phone_labels, gold_phone_labels, preprocessor)
    pter = float(dist) / float(n_tokens)
    print('[TEST RESULT]')
    print('Epoch {}\tLoss: {:.4f}\tWord Acc.: {:.3f}\tPTER: {:.3f}'\
          .format(self.global_epoch, avg_loss, acc, pter))
    token_f1, token_prec, token_recall = compute_token_f1(
                                           out_phone_file,
                                           gold_path,
                                           os.path.join(
                                             self.ckpt_dir,
                                             f'confusion.{self.global_epoch}.png'
                                           )
                                         )
    if self.history['acc'] < acc:
      self.history['acc'] = acc
      self.history['loss'] = avg_loss
      self.history['epoch'] = self.global_epoch
      self.history['iter'] = self.global_iter
      self.history['token_f1'] = token_f1
      self.save_checkpoint('best_acc.tar')
    self.set_mode('train')
  def cluster(self,
              n_clusters=44,
              out_prefix='quantized_outputs'):
      self.set_mode('eval')
      testset = self.data_loader['test'].dataset
      preprocessor = testset.preprocessor
      save_path = self.ckpt_dir.joinpath(f'results_file_{self.config.seed}.txt')

      us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
      with torch.no_grad():
        B = 0
        utt_ids_dict = dict()
        X_dict = dict()
        segment_dict = dict()
        for split in self.data_loader:
          X_dict[split] = []
          utt_ids_dict[split] = []
          segment_dict[split] = dict()
          for b_idx, batch in enumerate(self.data_loader[split]):
            audios = batch[0]
            audio_masks = batch[3]
            if b_idx > 2 and self.debug:
              break
            if b_idx == 0:
              B = audios.size(0)

            x = cuda(audios, self.cuda)
            audio_masks = cuda(audio_masks, self.cuda)
            if self.audio_feature == 'wav2vec2':
              B = x.size(0)
              T = x.size(1)
              x = self.audio_feature_net.feature_extractor(x.view(B*T, -1)).permute(0, 2, 1)
              x = x.view(B, T, x.size(-2), x.size(-1))
              x = (x * audio_masks.unsqueeze(-1)).sum(-2) / (audio_masks.sum(-1, keepdim=True) + torch.tensor(1e-10, device=x.device))
              audio_masks = torch.where(audio_masks.sum(-1) > 0,
                                        torch.tensor(1, device=x.device),
                                        torch.tensor(0, device=x.device))

            # (batch size, max segment num)
            if audio_masks.dim() == 3:
              segment_masks = torch.where(audio_masks.sum(-1) > 0,
                                          torch.tensor(1., device=audio_masks.device),
                                          torch.tensor(0., device=audio_masks.device))
            else:
              segment_masks = audio_masks.clone()
            audio_lens = segment_masks.sum(-1).long()

            if self.audio_net.ds_ratio > 1:
              audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
            # (batch size, max segment num, n visual class)
            word_logits, quantized, phone_loss = self.audio_net(x, masks=audio_masks)
            word_probs = F.softmax(word_logits, dim=-1)

            for idx in range(audios.size(0)):
              global_idx = b_idx * B + idx
              utt_id = os.path.splitext(os.path.split(self.data_loader[split].dataset.dataset[global_idx][0])[1])[0]
              
              embed = word_probs[idx, :audio_lens[idx]].cpu().detach().numpy()
              X_dict[split].extend(embed.tolist())
              utt_ids_dict[split].extend([utt_id]*embed.shape[0])
              segment_dict[split][utt_id] = self.data_loader[split].dataset.dataset[global_idx][3]
          X_dict[split] = np.asarray(X_dict[split])
        
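        # Fit the clusterer on word-probability features pooled from all splits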
        X = np.concatenate([X_dict[split] for split in self.data_loader])
        begin_time = time.time()
        self.clusterer.fit(X)
        info =  f'Epoch {self.global_epoch}\tClustering Time: {time.time()-begin_time:.2f} s'
        print(info)
        with open(save_path, 'a') as file:
          file.write(info+'\n')

        ys = self.clusterer.predict(X_dict['test'])
        utt_ids = utt_ids_dict['test']
        segments = segment_dict['test']
        filename = self.ckpt_dir.joinpath(out_prefix+'.txt')
        out_f = open(filename, 'w')
        for utt_id, group in groupby(list(zip(utt_ids, ys)), lambda x:x[0]):
          y = torch.LongTensor([g[1] for g in group])
          y_unseg = testset.unsegment(y, segments[utt_id]).long().cpu().detach().numpy().tolist()
          y_str = ','.join([str(l) for l in y_unseg]) 
          out_f.write(f'{utt_id} {y_str}\n') 
        out_f.close()
        np.save(self.ckpt_dir.joinpath('cluster_means.npy'), self.clusterer.cluster_centers_) 
        gold_path = os.path.join(testset.data_path, testset.splits[0])
        print('[CLUSTERING RESULT]')
        token_f1, token_prec, token_recall = compute_token_f1(
                                               filename,
                                               gold_path,
                                               self.ckpt_dir.joinpath(f'confusion.png'),
                                             ) 
        info = f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}'
        with open(save_path, 'a') as file:
          file.write(info+'\n')
  
        if self.history['token_f1'] < token_f1:
          self.history['token_f1'] = token_f1
          self.history['best_cluster_epoch'] = self.global_epoch
          self.history['best_cluster_iter'] = self.global_iter
          self.save_checkpoint(f'best_token_f1_{self.config.seed}.tar')
        self.set_mode('train')
Example #10
    def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset
        preprocessor = testset.preprocessor

        total_loss = 0.
        total_word_loss = 0.
        total_phone_loss = 0.
        total_step = 0.

        pred_word_labels = []
        gold_word_labels = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_phone_file = os.path.join(
            testset.data_path,
            f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
        word_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w')
        phone_file = self.ckpt_dir.joinpath(
            f'{out_prefix}_phoneme.{self.global_epoch}.txt')
        phone_f = open(phone_file, 'w')

        with torch.no_grad():
            B = 0
            for b_idx, batch in enumerate(self.data_loader['test']):
                audios = batch[0]
                word_labels = batch[2]
                audio_masks = batch[3]
                word_masks = batch[5]
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)

                x = cuda(audios, self.cuda)
                if self.audio_feature == "wav2vec2":
                    x = self.audio_feature_net.feature_extractor(x)
                word_labels = cuda(word_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                # (batch size, max word num, max word len, max audio len)
                word_masks = cuda(word_masks, self.cuda)
                # (batch size x max word num, max word len)
                word_masks_2d = torch.where(
                    word_masks.sum(-1).view(-1, self.max_word_len) > 0,
                    torch.tensor(1, device=x.device),
                    torch.tensor(0, device=x.device))

                # (batch size x max word num)
                word_masks_1d = torch.where(
                    word_masks_2d.sum(-1) > 0, torch.tensor(1,
                                                            device=x.device),
                    torch.tensor(0, device=x.device))
                # images = cuda(images, self.cuda)

                if self.audio_net.ds_ratio > 1:
                    audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
                    word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]
                audio_lens = audio_masks.sum(-1).long()
                # (batch size, max word num)
                word_lens = word_masks.sum(dim=(-1, -2)).long()
                # (batch size,)
                word_nums = torch.where(word_lens > 0,
                                        torch.tensor(1, device=x.device),
                                        torch.tensor(0,
                                                     device=x.device)).sum(-1)

                sent_phone_logits,\
                word_logits,\
                _,\
                embedding = self.audio_net(x,
                                           masks=audio_masks,
                                           return_feat=True)

                # (batch size, max word num, max word len, n phone class)
                phone_logits = torch.matmul(word_masks,
                                            sent_phone_logits.unsqueeze(1))
                # (batch size x max word num, max word len, n phone class)
                phone_logits = phone_logits.view(-1, self.max_word_len,
                                                 self.n_phone_class)
                phone_probs = F.softmax(phone_logits, dim=-1)

                # (batch size x max word num, n phone class)
                pm_phone_probs = self.audio_net.reverse_forward(
                    word_labels.flatten(), ignore_index=self.ignore_index)
                # (batch size x max word num, max word len, n phone class)
                pm_phone_probs = pm_phone_probs\
                                 .unsqueeze(1).expand(-1, self.max_word_len, -1)

                # (batch size, max word num, max word len, n word class)
                word_logits = torch.matmul(word_masks,
                                           word_logits.unsqueeze(1))
                # (batch size, max word num, n word class)
                word_logits = (word_logits
                               * word_masks.sum(-1, keepdim=True)).sum(-2)
                # Dividing by (word_lens.unsqueeze(-1) + EPS) would average over frames instead
                # (batch size x max word num, n word class)
                word_logits_2d = word_logits.view(-1, self.n_visual_class)
                word_probs = F.softmax(word_logits_2d, dim=-1)

                if self.loss_type == 'cross_entropy':
                    # phone_loss = -((pm_phone_probs * word_masks_2d.unsqueeze(-1)) *\
                    #              torch.log_softmax(phone_logits, dim=-1)).mean()
                    phone_loss = self.criterion(phone_probs, pm_phone_probs,
                                                word_masks_2d)
                    word_loss = F.cross_entropy(word_logits_2d,
                                                word_labels.flatten(),
                                                ignore_index=self.ignore_index)
                else:
                    phone_loss = self.criterion(phone_probs, pm_phone_probs,
                                                word_masks_2d)
                    if self.max_normalize:
                        word_probs = word_probs / word_probs.max(
                            -1, keepdim=True)[0]
                    # (batch size x max word num, n word class)
                    word_labels_onehot = F.one_hot(word_labels,
                                                   self.n_visual_class)
                    word_labels_onehot = word_masks_1d.unsqueeze(
                        -1) * word_labels_onehot
                    word_loss = self.criterion(
                        word_probs, word_labels_onehot,
                        torch.ones(audio_lens.size(0), device=x.device))
                loss = word_loss + phone_loss
                total_loss += loss.cpu().detach().numpy()
                total_word_loss += word_loss.cpu().detach().numpy()
                total_phone_loss += phone_loss.cpu().detach().numpy()
                total_step += 1.

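                # Dump frame-level phone predictions and collect word predictions for scoring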
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]
                    pred_phone_label = sent_phone_logits[
                        idx, :audio_lens[idx]].max(-1)[1]
                    pred_phone_label_list = pred_phone_label.cpu().detach(
                    ).numpy().tolist()
                    pred_phone_names = ','.join(
                        [str(p) for p in pred_phone_label_list])
                    phone_f.write(f'{audio_id} {pred_phone_names}\n')

                    if word_nums[idx] > 0:
                        gold_word_label = word_labels[
                            idx, :word_nums[idx]].cpu().detach().numpy(
                            ).tolist()
                        pred_word_label = word_logits[
                            idx, :word_nums[idx]].max(
                                -1)[1].cpu().detach().numpy().tolist()

                        gold_word_labels.extend(gold_word_label)
                        pred_word_labels.extend(pred_word_label)
                        pred_word_names = preprocessor.to_word_text(
                            pred_word_label)
                        gold_word_names = preprocessor.to_word_text(
                            gold_word_label)

                        for word_idx in range(word_nums[idx]):
                            pred_word_name = pred_word_names[word_idx]
                            gold_word_name = gold_word_names[word_idx]
                            word_readable_f.write(
                                f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n\n')
            phone_f.close()
            word_readable_f.close()
            avg_loss = total_loss / total_step
            avg_word_loss = total_word_loss / total_step
            avg_phone_loss = total_phone_loss / total_step
            # Compute word accuracy and word token F1
            print('[TEST RESULT]')
            word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
            word_prec,\
            word_rec,\
            word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                         np.asarray(pred_word_labels),
                                                         average='macro')
            print(
                f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\tWord Loss: {avg_word_loss:.4f}\tPhone Loss: {avg_phone_loss:.4f}\n'
                f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'
                f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}'
            )
            token_f1,\
            token_prec,\
            token_recall = compute_token_f1(phone_file,
                                            gold_phone_file,
                                            self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))

            if self.history['word_acc'] < word_acc:
                self.history['token_f1'] = token_f1
                self.history['word_acc'] = word_acc
                self.history['loss'] = avg_loss
                self.history['iter'] = self.global_iter
                self.history['epoch'] = self.global_epoch
                self.save_checkpoint(f'best_acc_{self.config.seed}.tar')
            self.set_mode('train')