def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset
        preprocessor = testset.preprocessor
        temp = self.history['temp']

        total_loss = 0.
        total_neg_evidence = 0.
        total_num = 0.
        gold_word_labels = []
        pred_word_labels = []
        gold_word_names = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_path = os.path.join(testset.data_path, f'{testset.splits[0]}')
        out_word_file = os.path.join(
            self.ckpt_dir, f'{out_prefix}_word.{self.global_epoch}.readable')
        word_f = open(out_word_file, 'w')
        word_f.write('Image ID\tGold label\tPredicted label\n')
        phone_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_phoneme.{self.global_epoch}.txt'), 'w')

        with torch.no_grad():
            B = 0
            for b_idx, (audios, phoneme_labels, word_labels,\
                        audio_masks, phone_masks, word_masks)\
                        in enumerate(self.data_loader['test']):
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)

                audios = cuda(audios, self.cuda)
                if self.audio_feature == 'wav2vec2':
                    x = self.audio_feature_net.feature_extractor(audios)
                else:
                    x = audios

                word_labels = cuda(word_labels, self.cuda)
                phoneme_labels = cuda(phoneme_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                phone_masks = cuda(phone_masks, self.cuda)
                word_masks = cuda(word_masks, self.cuda)

                audio_lens = audio_masks.sum(-1).long()
                sent_lens = phone_masks.sum(-1).long()
                word_lens = (word_labels >= 0).long().sum(-1)
                (mu, std),\
                outputs,\
                embedding = self.audio_net(x,
                                           masks=audio_masks,
                                           temp=temp,
                                           num_sample=self.num_sample,
                                           return_feat=True)

                word_logits = outputs[:, :, :self.n_visual_class]
                word_logits = torch.matmul(word_masks, word_logits)
                mu_x = outputs[:, :, self.n_visual_class:self.n_visual_class +
                               self.input_size]
                std_x = outputs[:, :, self.n_visual_class +
                                self.input_size:self.n_visual_class +
                                2 * self.input_size]
                std_x = F.softplus(std_x - 5, beta=1)

                word_loss = F.cross_entropy(word_logits.permute(0, 2, 1),
                                            word_labels,
                                            ignore_index=-100)\
                                            .div(math.log(2))

                if self.weight_phone_loss > 0:
                    phone_logits = outputs[:, :, self.n_visual_class +
                                           2 * self.input_size:]
                    phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                            .permute(1, 0, 2),
                                            phoneme_labels,
                                            audio_lens,
                                            sent_lens)
                else:
                    phone_logits = None
                    phone_loss = 0.
                info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) -
                                    std.pow(2)).sum(-1).mean().div(math.log(2))
                evidence_loss = self.weight_evidence * F.mse_loss(
                    x.permute(0, 2, 1), mu_x).div(math.log(2))

                total_loss += (self.weight_phone_loss * phone_loss +\
                               self.weight_word_loss * word_loss +\
                               evidence_loss + self.beta * info_loss).cpu().detach().numpy()
                total_neg_evidence += evidence_loss.cpu().detach().numpy()
                total_num += 1.

                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]
                    if word_lens[idx] > 0:
                        gold_words = word_labels[idx, :word_lens[idx]]
                        pred_words = word_logits[idx, :word_lens[idx]].max(
                            -1)[1]

                        gold_words = gold_words.cpu().detach().numpy().tolist()
                        pred_words = pred_words.cpu().detach().numpy().tolist()
                        gold_word_labels.append(gold_words)
                        pred_word_labels.append(pred_words)
                        gold_word_names.append(
                            preprocessor.to_word_text(gold_words))
                        gold_word_str = ','.join(gold_word_names[-1])
                        pred_word_str = ','.join(
                            preprocessor.to_word_text(pred_words))

                        word_f.write(
                            f'{audio_id}\t{gold_word_str}\t{pred_word_str}\n')
                    feat_fn = self.ckpt_dir.joinpath(
                        f'outputs/phonetic/dev-clean/{audio_id}.txt')

                    if self.weight_phone_loss > 0:
                        gold_phone_label = phoneme_labels[idx, :sent_lens[idx]]
                        pred_phone_label = phone_logits[
                            idx, :audio_lens[idx]].max(-1)[1]
                        gold_phone_names = ','.join(
                            preprocessor.to_text(gold_phone_label))
                        pred_phone_names = ','.join(
                            preprocessor.tokens_to_text(pred_phone_label))
                        phone_readable_f.write(
                            f'Utterance id: {audio_id}\n'
                            f'Gold transcript: {gold_phone_names}\n'
                            f'Pred transcript: {pred_phone_names}\n\n')

                    if save_embedding:
                        np.savetxt(feat_fn, embedding[idx, :audio_lens[idx]]
                                   [::2].cpu().detach().numpy())  # XXX

        word_f.close()
        phone_readable_f.close()
        avg_loss = total_loss / total_num
        avg_neg_evidence = total_neg_evidence / total_num
        acc = compute_accuracy(gold_word_labels, pred_word_labels)

        print('[TEST RESULT]')
        print('Epoch {}\tLoss: {:.4f}\tEvidence Loss: {:.4f}\tWord Acc.: {:.3f}'\
              .format(self.global_epoch, avg_loss, avg_neg_evidence, acc))
        if self.history['acc'] < acc:
            self.history['acc'] = acc
            self.history['loss'] = avg_loss
            self.history['epoch'] = self.global_epoch
            self.history['iter'] = self.global_iter
            self.save_checkpoint()
        else:
            self.save_checkpoint(filename='latest.tar')
        self.set_mode('train')
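
To clarify how the objective above is assembled, here is a minimal, self-contained sketch (not the project code) of the three loss terms: word cross-entropy measured in bits, the Gaussian KL "information" term on (mu, std), and the MSE "evidence" term on the reconstructed features. All shapes and the beta weight below are made-up assumptions.

import math
import torch
import torch.nn.functional as F

B, T, V, D = 2, 5, 10, 8                     # batch, time, visual classes, feature dim
word_logits = torch.randn(B, T, V)
word_labels = torch.randint(0, V, (B, T))
mu, std = torch.randn(B, D), torch.rand(B, D) + 0.1
x = torch.randn(B, D, T)                     # features, channels first as in the MSE above
mu_x = torch.randn(B, T, D)

# Cross-entropy divided by log(2) expresses the loss in bits.
word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), word_labels,
                            ignore_index=-100).div(math.log(2))
# KL divergence of N(mu, std) from a standard Gaussian, also in bits.
info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(-1).mean().div(math.log(2))
evidence_loss = F.mse_loss(x.permute(0, 2, 1), mu_x).div(math.log(2))
loss = word_loss + evidence_loss + 0.1 * info_loss   # beta = 0.1 is an assumption
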
Example #2
  def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    preprocessor = testset.preprocessor

    total_loss = 0.
    total_step = 0.

    pred_word_labels = []
    pred_word_labels_quantized = []
    gold_word_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
      os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

    gold_phone_file = os.path.join(testset.data_path, f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
    word_readable_f = open(self.ckpt_dir.joinpath(f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w') 
    phone_file = self.ckpt_dir.joinpath(f'{out_prefix}_phoneme.{self.global_epoch}.txt')
    phone_f = open(phone_file, 'w')

    with torch.no_grad():
      B = 0
      for b_idx, batch in enumerate(self.data_loader['test']):        
        audios = batch[0]
        word_labels = batch[2].squeeze(-1)
        audio_masks = batch[3]
        word_masks = batch[5]
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0: 
          B = audios.size(0)
 
        # (batch size, max segment num, feat dim) or (batch size, max segment num, max segment len, feat dim)
        x = cuda(audios, self.cuda)

        # (batch size,)
        word_labels = cuda(word_labels, self.cuda)

        # (batch size, max segment num) or (batch size, max segment num, max segment len)
        audio_masks = cuda(audio_masks, self.cuda)

        if self.audio_feature == "wav2vec2":
          B = x.size(0)
          T = x.size(1) 
          x = self.audio_feature_net.feature_extractor(x.view(B*T, -1)).permute(0, 2, 1)
          x = x.view(B, T, x.size(-2), x.size(-1))
          x = (x * audio_masks.unsqueeze(-1)).sum(-2) / (audio_masks.sum(-1, keepdim=True) + torch.tensor(1e-10, device=x.device))
          audio_masks = torch.where(audio_masks.sum(-1) > 0,
                                    torch.tensor(1, device=x.device),
                                    torch.tensor(0, device=x.device))
          
        # (batch size, max segment num)
        if audio_masks.dim() == 3: 
          segment_masks = torch.where(audio_masks.sum(-1) > 0,
                                      torch.tensor(1., device=audio_masks.device),
                                      torch.tensor(0., device=audio_masks.device))
        else:
          segment_masks = audio_masks.clone()
             
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          segment_masks = segment_masks[:, ::self.audio_net.ds_ratio]

        # (batch size, max segment num, n visual class)
        word_logits, quantized, phone_loss = self.audio_net(x, masks=audio_masks)

        # segment_word_labels = word_labels.unsqueeze(-1)\
        #                                  .expand(-1, self.max_segment_num)
        # segment_word_labels = (segment_word_labels * segment_masks).flatten().long()
        if self.use_logsoftmax:
          segment_word_logits = (F.log_softmax(word_logits, dim=-1)\
                                * segment_masks.unsqueeze(-1)).sum(-2)
        else:
          segment_word_logits = (word_logits\
                                * segment_masks.unsqueeze(-1)).sum(-2)

        word_loss = F.cross_entropy(segment_word_logits,
                               word_labels,
                               ignore_index=self.ignore_index)
        loss = phone_loss + word_loss
        total_loss += loss.cpu().detach().numpy()
        total_step += 1.
        
        _, _, phone_indices = self.audio_net.encode(x, masks=audio_masks)
        for idx in range(audios.size(0)):
          global_idx = b_idx * B + idx
          audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0]
          segments = testset.dataset[global_idx][3]
          pred_phone_label = testset.unsegment(phone_indices[idx], segments).long()

          if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1:
            us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
            pred_phone_label = pred_phone_label.unsqueeze(-1)\
                               .expand(-1, us_ratio).flatten()

          pred_phone_label_list = pred_phone_label.cpu().detach().numpy().tolist()
          pred_phone_names = ','.join([str(p) for p in pred_phone_label_list])
          phone_f.write(f'{audio_id} {pred_phone_names}\n')
          
          gold_word_label = word_labels[idx].cpu().detach().numpy().tolist()
          pred_word_label = segment_word_logits[idx].max(-1)[1].cpu().detach().numpy().tolist()
          pred_word_label_quantized = quantized[idx].prod(-2).max(-1)[1].cpu().detach().numpy().tolist()
           
          gold_word_labels.append(gold_word_label)
          pred_word_labels.append(pred_word_label)
          pred_word_labels_quantized.append(pred_word_label_quantized)
          pred_word_name = preprocessor.to_word_text([pred_word_label])[0]
          pred_word_name_quantized = preprocessor.to_word_text([pred_word_label_quantized])[0]
          gold_word_name = preprocessor.to_word_text([gold_word_label])[0]
          word_readable_f.write(f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n'
                                f'Pred word label by quantizer: {pred_word_name_quantized}\n\n') 
      phone_f.close()
      word_readable_f.close()
      avg_loss = total_loss / total_step
      # Compute word accuracy and word token F1
      print('[TEST RESULT]')
      word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
      word_prec,\
      word_rec,\
      word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                   np.asarray(pred_word_labels),
                                                   average='macro')

      word_prec_quantized,\
      word_rec_quantized,\
      word_f1_quantized, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                             np.asarray(pred_word_labels_quantized),
                                                             average='macro') 

      token_f1,\
      token_prec,\
      token_recall = compute_token_f1(phone_file,
                                      gold_phone_file,
                                      self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))
      info = f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\n'\
             f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'\
             f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}\n'\
             f'(By Quantizer) Word Precision: {word_prec_quantized:.3f}\tWord Recall: {word_rec_quantized:.3f}\tWord F1: {word_f1_quantized:.3f}\n'\
             f'Token Precision: {token_prec:.3f}\tToken Recall: {token_recall:.3f}\tToken F1: {token_f1:.3f}\n'
      print(info) 

      save_path = self.ckpt_dir.joinpath('results_file.txt')
      with open(save_path, 'a') as file:
        file.write(info)

      if self.history['word_acc'] < word_acc:
        self.history['token_result'] = [token_prec, token_recall, token_f1]
        self.history['word_acc'] = word_acc
        self.history['loss'] = avg_loss
        self.history['iter'] = self.global_iter
        self.history['epoch'] = self.global_epoch
        self.save_checkpoint(f'best_acc_{self.config.seed}.tar')
      self.set_mode('train') 
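
The word-level metrics above rely on compute_accuracy and compute_token_f1, which are project utilities; the macro precision/recall/F1 call is plain scikit-learn, as this toy example illustrates (labels are made up).

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

gold = np.asarray([0, 1, 2, 1])
pred = np.asarray([0, 2, 2, 1])
word_prec, word_rec, word_f1, _ = precision_recall_fscore_support(gold, pred, average='macro')
print(f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}')
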
    def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset
        preprocessor = testset.preprocessor

        total_loss = 0.
        total_step = 0.

        pred_word_labels = []
        gold_word_labels = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_phone_file = os.path.join(
            testset.data_path,
            f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
        word_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w')

        with torch.no_grad():
            B = 0
            for b_idx, batch in enumerate(self.data_loader['test']):
                audios = batch[0]
                # Assumption: phoneme labels occupy batch[1], as in the other data
                # loaders; they are needed for the phone cross-entropy below.
                phone_labels = batch[1]
                word_labels = batch[2]
                audio_masks = batch[3]
                word_masks = batch[5]
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)

                x = cuda(audios, self.cuda)
                if self.audio_feature == "wav2vec2":
                    x = self.audio_feature_net.feature_extractor(x)
                word_labels = cuda(word_labels, self.cuda)
                phone_labels = cuda(phone_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                # (batch size, max word num, max word len, max audio len)
                word_masks = cuda(word_masks, self.cuda)
                # images = cuda(images, self.cuda)
                word_lens = word_masks.sum(dim=(-1, -2)).long()
                word_nums = torch.where(word_lens > 0,
                                        torch.tensor(1, device=x.device),
                                        torch.tensor(0,
                                                     device=x.device)).sum(-1)

                # (batch size, max word num, max word len, feat dim)
                x = torch.matmul(word_masks, x.unsqueeze(1))
                # (batch size x max word num, feat dim, max word len)
                if testset.use_segment:
                    x = x.view(-1, self.max_segment_num, x.size(-1))
                else:
                    x = x.view(-1, self.max_word_len, x.size(-1))

                if self.audio_net.ds_ratio > 1:
                    audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
                    word_masks = word_masks[:, :, ::self.audio_net.ds_ratio,
                                            ::self.audio_net.ds_ratio]

                word_logits,\
                embedding = self.audio_net(x,
                                           masks=audio_masks,
                                           return_feat=True)
                phone_logits = self.phone_net(embedding)

                word_loss = F.cross_entropy(word_logits,
                                            word_labels.flatten(),
                                            ignore_index=self.ignore_index)
                phone_loss = F.cross_entropy(
                    phone_logits.view(-1, phone_logits.size(-1)),
                    phone_labels.flatten(),
                    ignore_index=self.ignore_index)
                loss = word_loss + phone_loss
                total_loss += loss.cpu().detach().numpy()
                total_step += 1.

                word_logits = word_logits.view(word_labels.size(0), -1,
                                               self.n_visual_class)
                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]

                    if word_nums[idx] > 0:
                        gold_word_label = word_labels[
                            idx, :word_nums[idx]].cpu().detach().numpy(
                            ).tolist()
                        pred_word_label = word_logits[
                            idx, :word_nums[idx]].max(
                                -1)[1].cpu().detach().numpy().tolist()
                        gold_word_labels.extend(gold_word_label)
                        pred_word_labels.extend(pred_word_label)
                        pred_word_names = preprocessor.to_word_text(
                            pred_word_label)
                        gold_word_names = preprocessor.to_word_text(
                            gold_word_label)

                        for word_idx in range(word_nums[idx]):
                            pred_word_name = pred_word_names[word_idx]
                            gold_word_name = gold_word_names[word_idx]
                            word_readable_f.write(
                                f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n\n')
            word_readable_f.close()
            avg_loss = total_loss / total_step
            # Compute word accuracy and word token F1
            print('[TEST RESULT]')
            word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
            word_prec,\
            word_rec,\
            word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                         np.asarray(pred_word_labels),
                                                         average='macro')
            info = f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\n'\
                   f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'\
                   f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}\n'
            print(info)

            save_path = self.ckpt_dir.joinpath(
                f'results_file_{self.config.seed}.txt')
            with open(save_path, 'a') as file:
                file.write(info)

            if self.history['word_acc'] < word_acc:
                self.history['word_acc'] = word_acc
                self.history['loss'] = avg_loss
                self.history['iter'] = self.global_iter
                self.history['epoch'] = self.global_epoch
                self.save_checkpoint(f'best_acc_{self.config.seed}.tar')
            self.set_mode('train')
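
The torch.matmul(word_masks, x.unsqueeze(1)) step above gathers the frames belonging to each word through a binary alignment mask. Here is an illustrative sketch with made-up sizes showing the shape arithmetic.

import torch

T, D, W, L = 6, 4, 2, 3                      # audio len, feat dim, word num, word len
frames = torch.randn(T, D)
word_masks = torch.zeros(W, L, T)
word_masks[0, 0, 0] = word_masks[0, 1, 1] = 1.                # word 0 covers frames 0-1
word_masks[1, 0, 3] = word_masks[1, 1, 4] = word_masks[1, 2, 5] = 1.
per_word_frames = torch.matmul(word_masks, frames)            # (W, L, D)
print(per_word_frames.shape)
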
    def train(self, save_embedding=False):
        self.set_mode('train')
        preprocessor = self.data_loader['train'].dataset.preprocessor
        temp_min = 0.1
        anneal_rate = self.anneal_rate
        temp = self.history['temp']

        total_loss = 0.
        total_step = 0.
        for e in range(self.epoch):
            self.global_epoch += 1
            pred_word_labels = []
            gold_word_labels = []
            for idx, (audios, phoneme_labels, word_labels,\
                      audio_masks, phone_masks, word_masks)\
                in enumerate(self.data_loader['train']):
                if idx > 2 and self.debug:
                    break
                self.global_iter += 1

                audios = cuda(audios, self.cuda)

                if self.audio_feature == 'wav2vec2':
                    x = self.audio_feature_net.feature_extractor(audios)
                else:
                    x = audios
                word_labels = cuda(word_labels, self.cuda)
                phoneme_labels = cuda(phoneme_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                phone_masks = cuda(phone_masks, self.cuda)
                word_masks = cuda(word_masks, self.cuda)
                audio_lens = audio_masks.sum(-1).long()
                sent_lens = phone_masks.sum(-1).long()
                word_lens = (word_labels > 0).long().sum(-1)

                (mu, std),\
                outputs,\
                embedding = self.audio_net(x,
                                           masks=audio_masks,
                                           temp=temp,
                                           num_sample=self.num_sample,
                                           return_feat=True)
                word_logits = outputs[:, :, :self.n_visual_class]
                word_logits = torch.matmul(word_masks, word_logits)
                mu_x = outputs[:, :, self.n_visual_class:self.n_visual_class +
                               self.input_size]
                std_x = outputs[:, :, self.n_visual_class +
                                self.input_size:self.n_visual_class +
                                2 * self.input_size]
                std_x = F.softplus(std_x - 5, beta=1)

                word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), word_labels,\
                                             ignore_index=-100,
                                             ).div(math.log(2))
                if self.weight_phone_loss > 0:
                    phone_logits = outputs[:, :, self.n_visual_class +
                                           2 * self.input_size:]
                    phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                            .permute(1, 0, 2,),
                                            phoneme_labels,
                                            audio_lens,
                                            sent_lens)
                else:
                    phone_loss = 0
                info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) -
                                    std.pow(2)).sum(-1).mean().div(math.log(2))
                evidence_loss = self.weight_evidence * F.mse_loss(
                    x.permute(0, 2, 1), mu_x)

                loss = self.weight_phone_loss * phone_loss +\
                       self.weight_word_loss * word_loss +\
                       evidence_loss + self.beta * info_loss
                izy_bound = math.log(self.n_visual_class, 2) - word_loss
                izx_bound = info_loss
                total_loss += loss.cpu().detach().numpy()
                total_step += 1.

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                for i in range(audios.size(0)):
                    word_len = word_lens[i]
                    if word_len > 0:
                        gold_word_labels.append(word_labels[
                            i, :word_len].cpu().detach().numpy().tolist())
                        pred_word_label = word_logits[i, :word_len].max(-1)[1]
                        pred_word_labels.append(
                            pred_word_label.cpu().detach().numpy().tolist())

                if self.global_iter % 1000 == 0:
                    temp = np.maximum(temp * np.exp(-anneal_rate * idx),
                                      temp_min)
                    avg_loss = total_loss / total_step
                    print(
                        f'i:{self.global_iter:d} temp:{temp} avg loss (total loss):{avg_loss:.2f} ({total_loss:.2f}) '
                        f'IZY:{izy_bound:.2f} IZX:{izx_bound:.2f} Evidence:{evidence_loss}'
                    )

            acc = compute_accuracy(gold_word_labels, pred_word_labels)
            print(
                f'Epoch {self.global_epoch}\ttraining visual word accuracy: {acc:.3f}'
            )
            if (self.global_epoch % 2) == 0:
                self.scheduler.step()
            if (self.global_epoch % 10) == 0:
                self.cluster()
            self.test(save_embedding=save_embedding)
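
The training loop above anneals the Gumbel-softmax temperature every 1000 iterations. A standalone sketch of that schedule (anneal_rate and the step values are illustrative assumptions):

import numpy as np

temp, temp_min, anneal_rate = 1.0, 0.1, 3e-5
for step in range(0, 10001, 1000):
    temp = np.maximum(temp * np.exp(-anneal_rate * step), temp_min)
print(f'final temp: {temp:.3f}')
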
Example #5
  def test(self, save_embedding=False, out_prefix='predictions'):
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset 
    preprocessor = testset.preprocessor
    
    total_loss = 0.
    total_num = 0.
    gold_labels = []
    pred_labels = []
    if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
      os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

    gold_path = os.path.join(testset.data_path, f'{testset.splits[0]}')
    out_word_file = os.path.join(
                      self.ckpt_dir,
                      f'{out_prefix}_word.{self.global_epoch}.readable'
                    )
    out_phone_readable_file = os.path.join(
                      self.ckpt_dir,
                      f'{out_prefix}_phoneme.{self.global_epoch}.readable'
                    )
    out_phone_file = os.path.join(
                       self.ckpt_dir,
                       f'{out_prefix}_phoneme.{self.global_epoch}.txt'
                     )

    word_f = open(out_word_file, 'w')
    word_f.write('Image ID\tGold label\tPredicted label\n')
    phone_readable_f = open(out_phone_readable_file, 'w')
    phone_f = open(out_phone_file, 'w')
    
    gold_word_labels = []
    gold_phone_labels = []
    pred_word_labels = []
    pred_phone_labels = []
     
    with torch.no_grad():
      B = 0
      for b_idx, (audios, phoneme_labels, word_labels,\
                  audio_masks, phone_masks, word_masks)\
                  in enumerate(self.data_loader['test']):
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0:
          B = audios.size(0)
          
        audios = cuda(audios, self.cuda)
        if self.audio_feature == 'wav2vec2':
          audios = self.audio_feature_net.feature_extractor(audios)
        phoneme_labels = cuda(phoneme_labels, self.cuda)
        word_labels = cuda(word_labels, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        phone_masks = cuda(phone_masks, self.cuda)
        word_masks = cuda(word_masks, self.cuda).sum(-2) # Use averaged frame embedding by default
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]
        
        audio_lens = audio_masks.sum(-1).long()
        sent_lens = phone_masks.sum(-1).long()
        word_lens = (word_labels >= 0).long().sum(-1)

        if self.model_type == 'blstm':
          out_logits, embedding = self.audio_net(audios, return_feat=True)
          word_logits = out_logits[:, :, :self.n_visual_class]
          phone_logits = out_logits[:, :, self.n_visual_class:]
        else:
          gumbel_logits, out_logits, encoding, embedding = self.audio_net(
            audios, masks=audio_masks,
            return_feat=True)
          phone_logits = gumbel_logits
          word_logits = out_logits

        if self.model_type == 'vq-mlp':
          word_logits = out_logits[:, :, :self.n_visual_class]

        word_logits = torch.matmul(word_masks, word_logits)
        word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), 
                                    word_labels,
                                    ignore_index=-100)\
                                    .div(math.log(2)) 
        phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                  .permute(1, 0, 2),
                                phoneme_labels,
                                audio_lens,
                                sent_lens)
        info_loss = (F.softmax(phone_logits, dim=-1)\
                      * F.log_softmax(phone_logits, dim=-1)
                    ).sum().div(sent_lens.sum() * math.log(2))
        total_loss += word_loss + phone_loss + self.beta * info_loss
        total_num += 1. 

        for idx in range(audios.size(0)):
          global_idx = b_idx * B + idx
          audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0]
          if word_lens[idx] > 0:
            gold_words = word_labels[idx, :word_lens[idx]]
            pred_words = word_logits[idx, :word_lens[idx]].max(-1)[1]
            gold_words = gold_words.cpu().detach().numpy().tolist()
            pred_words = pred_words.cpu().detach().numpy().tolist()
            gold_word_labels.append(gold_words)
            pred_word_labels.append(pred_words)

            gold_word_names = ','.join(preprocessor.to_word_text(
                                    gold_words)
                                  )
            pred_word_names = ','.join(preprocessor.to_word_text(
                                    pred_words)
                                  )
            word_f.write(f'{audio_id}\t{gold_word_names}\t{pred_word_names}\n')

          feat_fn = self.ckpt_dir.joinpath(f'outputs/phonetic/dev-clean/{audio_id}.txt')
          if save_embedding:
            np.savetxt(feat_fn, embedding[idx, :audio_lens[idx]][::2].cpu().detach().numpy()) # XXX
           
          gold_phone_label = phoneme_labels[idx, :sent_lens[idx]]
          pred_phone_label = phone_logits[idx, :audio_lens[idx]].max(-1)[1]
          gold_phone_names = ','.join(preprocessor.to_text(gold_phone_label))
          pred_phone_names = ','.join(preprocessor.tokens_to_text(pred_phone_label))
          phone_readable_f.write(f'Utterance id: {audio_id}\n'
                                 f'Gold transcript: {gold_phone_names}\n'
                                 f'Pred transcript: {pred_phone_names}\n\n')

          gold_phone_label = gold_phone_label.cpu().detach().numpy().tolist()
          if int(self.hop_len_ms / 10) * self.audio_net.ds_ratio > 1:
            us_ratio = int(self.hop_len_ms / 10) * self.audio_net.ds_ratio
            pred_phone_label = pred_phone_label.unsqueeze(-1)\
                                 .expand(-1, us_ratio).flatten()
          pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist()
          gold_phone_labels.append(gold_phone_label)
          pred_phone_labels.append(pred_phone_label) 
          
          pred_phone_label = preprocessor.to_index(preprocessor.to_text(pred_phone_label))
          pred_phone_label = pred_phone_label.cpu().detach().numpy().tolist()
          pred_phone_names = ','.join([str(p) for p in pred_phone_label])
          phone_f.write(f'{audio_id} {pred_phone_names}\n')  
   
    word_f.close()
    phone_f.close()              
   
    avg_loss = total_loss / total_num 
    acc = compute_accuracy(gold_word_labels, pred_word_labels)
    dist, n_tokens = compute_edit_distance(pred_phone_labels, gold_phone_labels, preprocessor)
    pter = float(dist) / float(n_tokens)
    print('[TEST RESULT]')
    print('Epoch {}\tLoss: {:.4f}\tWord Acc.: {:.3f}\tPTER: {:.3f}'\
          .format(self.global_epoch, avg_loss, acc, pter))
    token_f1, token_prec, token_recall = compute_token_f1(
                                           out_phone_file,
                                           gold_path,
                                           os.path.join(
                                             self.ckpt_dir,
                                             f'confusion.{self.global_epoch}.png'
                                           )
                                         )
    if self.history['acc'] < acc:
      self.history['acc'] = acc
      self.history['loss'] = avg_loss
      self.history['epoch'] = self.global_epoch
      self.history['iter'] = self.global_iter
      self.history['token_f1'] = token_f1
      self.save_checkpoint('best_acc.tar')
    self.set_mode('train')
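
The phone loss above is a standard CTC loss over frame-level logits. A toy example with made-up shapes (T frames, B utterances, C classes with blank index 0) showing the tensor layout F.ctc_loss expects:

import torch
import torch.nn.functional as F

B, T, C, S = 2, 50, 40, 12                   # batch, frames, phone classes, target length
phone_logits = torch.randn(B, T, C)
phoneme_labels = torch.randint(1, C, (B, S)) # 0 is reserved for the CTC blank
audio_lens = torch.full((B,), T, dtype=torch.long)
sent_lens = torch.full((B,), S, dtype=torch.long)
phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1).permute(1, 0, 2),
                        phoneme_labels, audio_lens, sent_lens)
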
Example #6
  def train(self, save_embedding=False):
    self.set_mode('train')
    preprocessor = self.data_loader['train'].dataset.preprocessor
    temp_min = 0.1
    anneal_rate = self.anneal_rate
    temp = 1.

    total_loss = 0.
    total_step = 0
    criterion = MicroTokenFLoss(beta=0.1) 
    for e in range(self.epoch):
      self.global_epoch += 1
      pred_word_labels = []
      gold_word_labels = []
      pred_phone_labels = []
      gold_phone_labels = []
      for idx, (audios, phoneme_labels, word_labels,\
                audio_masks, phone_masks, word_masks)\
          in enumerate(self.data_loader['train']):
        if idx > 2 and self.debug:
          break
        self.global_iter += 1
         
        x = cuda(audios, self.cuda)
        if self.audio_feature == "wav2vec2":
          x = self.audio_feature_net.feature_extractor(x)
        phoneme_labels = cuda(phoneme_labels, self.cuda)
        word_labels = cuda(word_labels, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        phone_masks = cuda(phone_masks, self.cuda)
        word_masks = cuda(word_masks, self.cuda).sum(-2) # Use averaged frame embedding by default

        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]

        audio_lens = audio_masks.sum(-1).long()
        sent_lens = phone_masks.sum(-1).long()
        word_lens = (word_labels >= 0).long().sum(-1)

        if self.model_type == "blstm":
          out_logits, embedding = self.audio_net(x, return_feat=True)
          word_logits = out_logits[:, :, :self.n_visual_class]
          phone_logits = out_logits[:, :, self.n_visual_class:]
        else:
          gumbel_logits, out_logits, _, embedding = self.audio_net(
            x, masks=audio_masks,
            temp=temp,
            num_sample=self.num_sample,
            return_feat=True)
          phone_logits = gumbel_logits
          word_logits = out_logits
        quantized = None
        if self.model_type == 'vq-mlp':
          word_logits = out_logits[:, :, :self.n_visual_class]
          quantized = out_logits[:, :, self.n_visual_class:]

        word_logits = torch.matmul(word_masks, word_logits)
        
        # XXX word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), word_labels,\
        #                            ignore_index=-100,
        #                            ).div(math.log(2))
        word_num_mask = torch.where(word_masks.sum(-1) > 0,
                                    torch.tensor(1, dtype=torch.long, device=word_masks.device), 
                                    torch.tensor(0, dtype=torch.long, device=word_masks.device))
        word_loss = criterion(F.softmax(word_logits, dim=-1),
                              F.one_hot(word_labels * word_num_mask, self.n_visual_class),
                              word_num_mask) # XXX  

        phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                  .permute(1, 0, 2),
                                phoneme_labels,
                                audio_lens,
                                sent_lens) 
        info_loss = (F.softmax(phone_logits, dim=-1)\
                      * F.log_softmax(phone_logits, dim=-1)
                    ).sum().div(sent_lens.sum()*math.log(2)) 
        loss = self.weight_phone_loss * phone_loss +\
               self.weight_word_loss * word_loss +\
               self.beta * info_loss
        if self.model_type == 'vq-mlp':
          loss += self.audio_net.quantize_loss(embedding, quantized,
                                               masks=audio_masks)

        izy_bound = math.log(self.n_visual_class, 2) - word_loss
        izx_bound = info_loss
        total_loss += loss.cpu().detach().numpy()
        total_step += 1.

        self.optim.zero_grad()
        loss.backward()

        if self.max_grad_norm is not None:
          torch.nn.utils.clip_grad_norm_(
            self.audio_net.parameters(),
            self.max_grad_norm
          )
        self.optim.step()
  
        for i in range(audios.size(0)):
          audio_len = audio_lens[i]
          sent_len = sent_lens[i]
          word_len = word_lens[i]

          gold_phone_label = phoneme_labels[i, :sent_len]
          pred_phone_label = phone_logits[i, :audio_len].max(-1)[1]
          gold_phone_labels.append(gold_phone_label.cpu().detach().numpy().tolist())
          pred_phone_labels.append(pred_phone_label.cpu().detach().numpy().tolist())
          if word_len > 0:
            gold_word_labels.append(word_labels[i, :word_len].cpu().detach().numpy().tolist())
            pred_word_label = word_logits[i, :word_len].max(-1)[1]
            pred_word_labels.append(pred_word_label.cpu().detach().numpy().tolist())
          
        if self.global_iter % 1000 == 0:
          temp = np.maximum(temp * np.exp(-anneal_rate * idx), temp_min)
          avg_loss = total_loss / total_step
          print(f'i:{self.global_iter:d} temp:{temp} avg loss (total loss):{avg_loss:.2f} ({total_loss:.2f}) '
                f'IZY:{izy_bound:.2f} IZX:{izx_bound:.2f}')
          self.history['temp'] = temp

      # Evaluate training visual word classification accuracy and phone token error rate
      acc = compute_accuracy(gold_word_labels, pred_word_labels)
      dist, n_tokens = compute_edit_distance(pred_phone_labels, gold_phone_labels, preprocessor)
      pter = float(dist) / float(n_tokens)
      print(f'Epoch {self.global_epoch}\ttraining visual word accuracy: {acc:.3f}\ttraining phone token error rate: {pter:.3f}')

      if (self.global_epoch % 2) == 0:
        self.scheduler.step()
      self.test(save_embedding=save_embedding)
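
The MicroTokenFLoss call above consumes softmaxed word logits, one-hot targets, and a per-word mask. MicroTokenFLoss is a project class, so only the masked one-hot target construction is sketched here; deriving the mask from the labels (rather than from word_masks) is an assumption made for illustration.

import torch
import torch.nn.functional as F

B, W, V = 2, 3, 5                            # batch, max word num, visual classes
word_labels = torch.tensor([[1, 4, -1], [2, -1, -1]])        # -1 marks padding
word_num_mask = (word_labels >= 0).long()
# Padded positions become class 0 after masking and are ignored via the mask.
targets = F.one_hot(word_labels * word_num_mask, V)
print(targets.shape, word_num_mask.tolist())
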
  def test_quantized(self, out_prefix='predictions'): # Compute quantized word F1 
    self.set_mode('eval')
    testset = self.data_loader['test'].dataset
    preprocessor = testset.preprocessor

    total_loss = 0.
    total_step = 0.
   
    pred_word_labels = []
    gold_word_labels = []
 
    word_readable_f = open(self.ckpt_dir.joinpath(f'{out_prefix}_visual_word_by_{self.clustering_method}.{self.global_epoch}.readable'), 'w')
    with torch.no_grad():
      B = 0
      for b_idx, batch in enumerate(self.data_loader['test']):        
        audios = batch[0]
        word_labels = batch[2].squeeze(-1)
        audio_masks = batch[3]
        word_masks = batch[5]
        if b_idx > 2 and self.debug:
          break
        if b_idx == 0: 
          B = audios.size(0)
 
        # (batch size, max segment num, feat dim) or (batch size, max segment num, max segment len, feat dim)
        x = cuda(audios, self.cuda)

        # (batch size,)
        word_labels = cuda(word_labels, self.cuda)

        # (batch size, max segment num) or (batch size, max segment num, max segment len)
        audio_masks = cuda(audio_masks, self.cuda)

        if self.audio_feature == "wav2vec2":
          B = x.size(0)
          T = x.size(1) 
          x = self.audio_feature_net.feature_extractor(x.view(B*T, -1)).permute(0, 2, 1)
          x = x.view(B, T, x.size(-2), x.size(-1))
          x = (x * audio_masks.unsqueeze(-1)).sum(-2) / (audio_masks.sum(-1, keepdim=True) + torch.tensor(1e-10, device=x.device))
          audio_masks = torch.where(audio_masks.sum(-1) > 0,
                                    torch.tensor(1, device=x.device),
                                    torch.tensor(0, device=x.device))
          
        # (batch size, max segment num)
        if audio_masks.dim() == 3: 
          segment_masks = torch.where(audio_masks.sum(-1) > 0,
                                      torch.tensor(1., device=audio_masks.device),
                                      torch.tensor(0., device=audio_masks.device))
        else:
          segment_masks = audio_masks.clone()
             
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          segment_masks = segment_masks[:, ::self.audio_net.ds_ratio]

        # (batch size, max segment num, n visual class)
        word_logits, _, phone_loss = self.audio_net(x, masks=audio_masks)
        word_probs = F.softmax(word_logits, dim=-1).detach().cpu().numpy().astype(np.float64)
        cluster_labels = [self.clusterer.predict(word_probs[i]) for i in range(B)]
        
        word_probs_quantized = [torch.FloatTensor(self.clusterer.cluster_centers_[labels]) for labels in cluster_labels]
        word_probs_quantized = torch.stack(word_probs_quantized).to(x.device)
        word_logits_quantized = torch.log(word_probs_quantized)
  
        if self.use_logsoftmax:
          segment_word_logits = (F.log_softmax(word_logits_quantized, dim=-1)\
                                * segment_masks.unsqueeze(-1)).sum(-2)
        else:
          segment_word_logits = (word_logits_quantized\
                                * segment_masks.unsqueeze(-1)).sum(-2)

        for idx in range(audios.size(0)):
          global_idx = b_idx * B + idx
          audio_id = os.path.splitext(os.path.split(testset.dataset[global_idx][0])[1])[0]
          
          gold_word_label = word_labels[idx].cpu().detach().numpy().tolist()
          pred_word_label = segment_word_logits[idx].max(-1)[1].cpu().detach().numpy().tolist()
           
          gold_word_labels.append(gold_word_label)
          pred_word_labels.append(pred_word_label)
          pred_word_name = preprocessor.to_word_text([pred_word_label])[0]
          gold_word_name = preprocessor.to_word_text([gold_word_label])[0]
          word_readable_f.write(f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n')
    word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
    word_prec,\
    word_rec,\
    word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                 np.asarray(pred_word_labels),
                                                 average='macro')
    print('[TEST RESULTS AFTER CLUSTERING]')
    info = f'Epoch {self.global_epoch}\tClustering Method: {self.clustering_method}\n'\
           f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'\
           f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}'
    print(info)
    save_path = self.ckpt_dir.joinpath(f'results_file_{self.config.seed}.txt')
    with open(save_path, 'a') as file:
      file.write(info+'\n')
    self.set_mode('train')
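
test_quantized above maps word posteriors onto cluster centroids via self.clusterer. A hedged sketch with scikit-learn KMeans standing in for the project's clusterer (the actual clustering_method may differ; shapes are made up):

import numpy as np
import torch
from sklearn.cluster import KMeans

B, S, V = 4, 10, 6                           # batch, segments, word classes
word_probs = torch.softmax(torch.randn(B, S, V), dim=-1).numpy().astype(np.float64)
clusterer = KMeans(n_clusters=3, n_init=10).fit(word_probs.reshape(-1, V))
cluster_labels = [clusterer.predict(word_probs[i]) for i in range(B)]
word_probs_quantized = torch.stack(
    [torch.FloatTensor(clusterer.cluster_centers_[labels]) for labels in cluster_labels])
word_logits_quantized = torch.log(word_probs_quantized + 1e-10)   # small epsilon guards log(0)
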
  def train(self, save_embedding=False):
    self.set_mode('train')
    preprocessor = self.data_loader['train'].dataset.preprocessor
    temp_min = 0.1
    anneal_rate = self.anneal_rate
    temp = 1.

    total_loss = 0.
    total_step = 0
    for e in range(self.epoch):
      self.global_epoch += 1
      pred_word_labels = []
      gold_word_labels = []
      pred_phone_labels = []
      gold_phone_labels = []
      for b_idx, (audios, phoneme_labels, word_labels,\
                audio_masks, phone_masks, word_masks)\
          in enumerate(self.data_loader['train']):
        if b_idx > 2 and self.debug:
          break
        self.global_iter += 1
         
        x = cuda(audios, self.cuda)
        if self.audio_feature == "wav2vec2":
          x = self.audio_feature_net.feature_extractor(x)
        phoneme_labels = cuda(phoneme_labels, self.cuda)
        word_labels = cuda(word_labels, self.cuda)
        audio_masks = cuda(audio_masks, self.cuda)
        phone_masks = cuda(phone_masks, self.cuda)
        word_masks = cuda(word_masks, self.cuda)
        if self.audio_net.ds_ratio > 1:
          audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
          word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]

        audio_lens = audio_masks.sum(-1).long()
        sent_lens = phone_masks.sum(-1).long()
        word_lens = (word_labels >= 0).long().sum(-1)

        phone_logits, word_logits, _, embedding = self.audio_net(
                               x, masks=audio_masks,
                               temp=temp,
                               num_sample=self.num_sample,
                               return_feat=True)
        
        # Compute phoneme one-hot vector
        phoneme_vectors = F.one_hot(phoneme_labels, self.n_phone_class)
        phone_denoised_logits,\
        phone_word_logits,\
        denoised_encodings,\
        embedding = self.phone_net(phoneme_vectors,
                                   temp=temp,
                                   num_sample=self.num_sample,
                                   return_feat=True)

        quantized = None
        if self.model_type == 'vq-mlp':
          # Assumption: in the vq-mlp case the second audio_net output packs the
          # word logits and the quantized codes side by side, as in the other loops.
          quantized = word_logits[:, :, self.n_visual_class:]
          word_logits = word_logits[:, :, :self.n_visual_class]

        word_logits = torch.matmul(word_masks, word_logits)
        
        word_loss = F.cross_entropy(word_logits.permute(0, 2, 1), word_labels,\
                                    ignore_index=-100,
                                    ).div(math.log(2))
        info_loss = (F.softmax(phone_logits, dim=-1)\
                      * F.log_softmax(phone_logits, dim=-1)
                    ).sum().div(audio_lens.sum()*math.log(2)) 

        # Permutation-invariant CTC loss for multilingual phones
        batch_size = x.size(0)
        phone_word_losses = [] 
        num_words = torch.where(word_masks.sum(-1) > 0,
                                torch.tensor(1, device=x.device),
                                torch.tensor(0, device=x.device)).sum(-1)
        for idx in range(batch_size):
          word_orders = list(itertools.permutations(range(num_words[idx])))
          word_orders = word_orders[:200]  # Limit the number of orderings considered
          # Score every candidate word order with a per-utterance CTC loss and keep
          # the best one; torch.max/torch.sum need tensors, so the per-order losses
          # are stacked first. Per-utterance shapes ((T, 1, C) log-probs, (1, W) targets,
          # length tensors of size 1) are assumptions made to keep the call well-formed.
          order_losses = [F.ctc_loss(F.log_softmax(phone_denoised_logits[idx], dim=-1)\
                                       .unsqueeze(1),
                                     word_labels[idx, list(word_order)].unsqueeze(0),
                                     sent_lens[idx:idx + 1],
                                     num_words[idx:idx + 1])
                          for word_order in word_orders]
          phone_word_losses.append(torch.stack(order_losses).max())
        phone_word_loss = torch.sum(torch.stack(phone_word_losses))
        phone_info_loss = (F.softmax(phone_denoised_logits, dim=-1)\
                      * F.log_softmax(phone_denoised_logits, dim=-1)
                    ).sum().div(sent_lens.sum()*math.log(2)) 
        

        # Use denoised phoneme labels for training the phoneme classifier
        phone_word_encodings = F.gumbel_softmax(phone_word_logits, 
                                                tau=temp,
                                                dim=-1)
        denoising_mask = torch.where(phone_word_encodings.max(-1)[1].detach() > 0,
                                     torch.tensor(1, device=x.device),
                                     torch.tensor(0, device=x.device)).detach()
        phoneme_labels_denoised = denoising_mask * denoised_encodings.max(-1)[1].detach()\
                                    + (1 - denoising_mask) * phoneme_labels 
        phone_loss = F.ctc_loss(F.log_softmax(phone_logits, dim=-1)\
                                  .permute(1, 0, 2),
                                phoneme_labels_denoised,
                                audio_lens,
                                sent_lens)
        audio_ib_loss = self.weight_phone_loss * phone_loss\
                        + self.weight_word_loss * word_loss\
                        + self.beta * info_loss

        phone_ib_loss = self.weight_phone_word_loss * phone_word_loss\
                        + self.beta * phone_info_loss # TODO weight_phone_word

        loss =  audio_ib_loss + phone_ib_loss
        if self.model_type == 'vq-mlp':
          loss += self.audio_net.quantize_loss(embedding, quantized,
                                               masks=audio_masks)

        izy_bound = math.log(self.n_visual_class, 2) - word_loss
        izx_bound = info_loss
        total_loss += loss.cpu().detach().numpy()
        total_step += 1.

        self.optim.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
          torch.nn.utils.clip_grad_norm_(
            self.audio_net.parameters(),
            self.max_grad_norm
          )
        self.optim.step()
  
        for i in range(audios.size(0)):
          audio_len = audio_lens[i]
          sent_len = sent_lens[i]
          word_len = word_lens[i]

          gold_phone_label = phoneme_labels_denoised[i, :sent_len]
          pred_phone_label = phone_logits[i, :audio_len].max(-1)[1]
          gold_phone_labels.append(gold_phone_label.cpu().detach().numpy().tolist())
          pred_phone_labels.append(pred_phone_label.cpu().detach().numpy().tolist())

          if word_len > 0:
            gold_word_labels.append(word_labels[i, :word_len].cpu().detach().numpy().tolist())
            pred_word_label = word_logits[i, :word_len].max(-1)[1]
            pred_word_labels.append(pred_word_label.cpu().detach().numpy().tolist())

        if self.global_iter % 1000 == 0:
          temp = np.maximum(temp * np.exp(-anneal_rate * b_idx), temp_min)
          avg_loss = total_loss / total_step
          print(f'i:{self.global_iter:d} temp:{temp} avg loss (total loss):{avg_loss:.2f} ({total_loss:.2f}) '
                f'IZY:{izy_bound:.2f} IZX:{izx_bound:.2f}'
                f'phone_loss:{phone_loss:.5f} phone_word_loss:{phone_word_loss:.5f}')

      # Evaluate training visual word classification accuracy and phone token error rate
      acc = compute_accuracy(gold_word_labels, pred_word_labels)
      dist, n_tokens = compute_edit_distance(pred_phone_labels, gold_phone_labels, preprocessor)
      pter = float(dist) / float(n_tokens)
      print(f'Epoch {self.global_epoch}\ttraining visual word accuracy: {acc:.3f}\ttraining phone token error rate: {pter:.3f}')

      if (self.global_epoch % 2) == 0:
        self.scheduler.step()
      self.test(save_embedding=save_embedding)
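
The permutation-invariant term above enumerates word orders with itertools.permutations, caps them at 200, and keeps the best-scoring order. A toy sketch of that pattern with a placeholder scoring function (not the project's CTC term):

import itertools
import torch

def score(order):
    # Placeholder per-order score; the training loop above uses a CTC loss here.
    return -torch.tensor(float(sum(order)))

num_words = 3
word_orders = list(itertools.permutations(range(num_words)))[:200]
best_score = torch.stack([score(order) for order in word_orders]).max()
print(len(word_orders), best_score.item())
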
Example #9
    def test(self, save_embedding=False, out_prefix='predictions'):
        self.set_mode('eval')
        testset = self.data_loader['test'].dataset
        preprocessor = testset.preprocessor

        total_loss = 0.
        total_word_loss = 0.
        total_phone_loss = 0.
        total_step = 0.

        pred_word_labels = []
        gold_word_labels = []
        if not self.ckpt_dir.joinpath('outputs/phonetic/dev-clean').is_dir():
            os.makedirs(self.ckpt_dir.joinpath('outputs/phonetic/dev-clean'))

        gold_phone_file = os.path.join(
            testset.data_path,
            f'{testset.splits[0]}/{testset.splits[0]}_nonoverlap.item')
        word_readable_f = open(
            self.ckpt_dir.joinpath(
                f'{out_prefix}_visual_word.{self.global_epoch}.readable'), 'w')
        phone_file = self.ckpt_dir.joinpath(
            f'{out_prefix}_phoneme.{self.global_epoch}.txt')
        phone_f = open(phone_file, 'w')

        with torch.no_grad():
            B = 0
            for b_idx, batch in enumerate(self.data_loader['test']):
                audios = batch[0]
                word_labels = batch[2]
                audio_masks = batch[3]
                word_masks = batch[5]
                if b_idx > 2 and self.debug:
                    break
                if b_idx == 0:
                    B = audios.size(0)

                x = cuda(audios, self.cuda)
                if self.audio_feature == "wav2vec2":
                    x = self.audio_feature_net.feature_extractor(x)
                word_labels = cuda(word_labels, self.cuda)
                audio_masks = cuda(audio_masks, self.cuda)
                # (batch size, max word num, max word len, max audio len)
                word_masks = cuda(word_masks, self.cuda)
                # (batch size x max word num, max word len)
                word_masks_2d = torch.where(
                    word_masks.sum(-1).view(-1, self.max_word_len) > 0,
                    torch.tensor(1, device=x.device),
                    torch.tensor(0, device=x.device))

                # (batch size x max word num)
                word_masks_1d = torch.where(
                    word_masks_2d.sum(-1) > 0,
                    torch.tensor(1, device=x.device),
                    torch.tensor(0, device=x.device))
                # images = cuda(images, self.cuda)

                if self.audio_net.ds_ratio > 1:
                    audio_masks = audio_masks[:, ::self.audio_net.ds_ratio]
                    word_masks = word_masks[:, :, ::self.audio_net.ds_ratio]
                audio_lens = audio_masks.sum(-1).long()
                # (batch size, max word num)
                word_lens = word_masks.sum(dim=(-1, -2)).long()
                # (batch size,)
                word_nums = torch.where(word_lens > 0,
                                        torch.tensor(1, device=x.device),
                                        torch.tensor(0,
                                                     device=x.device)).sum(-1)

                sent_phone_logits,\
                word_logits,\
                _,\
                embedding = self.audio_net(x,
                                           masks=audio_masks,
                                           return_feat=True)

                # (batch size, max word num, max word len, n phone class)
                phone_logits = torch.matmul(word_masks,
                                            sent_phone_logits.unsqueeze(1))
                # (batch size x max word num, max word len, n phone class)
                phone_logits = phone_logits.view(-1, self.max_word_len,
                                                 self.n_phone_class)
                phone_probs = F.softmax(phone_logits, dim=-1)

                # (batch size x max word num, n phone class)
                pm_phone_probs = self.audio_net.reverse_forward(
                    word_labels.flatten(), ignore_index=self.ignore_index)
                # (batch size x max word num, max word len, n phone class)
                pm_phone_probs = pm_phone_probs\
                                 .unsqueeze(1).expand(-1, self.max_word_len, -1)

                # (batch size, max word num, max word len, n word class)
                word_logits = torch.matmul(word_masks,
                                           word_logits.unsqueeze(1))
                # (batch size, max word num, n word class)
                word_logits = (word_logits *
                               word_masks.sum(-1, keepdim=True)).sum(-2)
                # Sum over frames; dividing by (word_lens.unsqueeze(-1) + EPS) would
                # average over frames instead.
                # (batch size x max word num, n word class)
                word_logits_2d = word_logits.view(-1, self.n_visual_class)
                word_probs = F.softmax(word_logits_2d, dim=-1)

                if self.loss_type == 'cross_entropy':
                    # phone_loss = -((pm_phone_probs * word_masks_2d.unsqueeze(-1)) *\
                    #              torch.log_softmax(phone_logits, dim=-1)).mean()
                    phone_loss = self.criterion(phone_probs, pm_phone_probs,
                                                word_masks_2d)
                    word_loss = F.cross_entropy(word_logits_2d,
                                                word_labels.flatten(),
                                                ignore_index=self.ignore_index)
                else:
                    phone_loss = self.criterion(phone_probs, pm_phone_probs,
                                                word_masks_2d)
                    if self.max_normalize:
                        word_probs = word_probs / word_probs.max(
                            -1, keepdim=True)[0]
                    # (batch size x max word num, n word class)
                    word_labels_onehot = F.one_hot(word_labels,
                                                   self.n_visual_class)
                    word_labels_onehot = word_masks_1d.unsqueeze(
                        -1) * word_labels_onehot
                    word_loss = self.criterion(
                        word_probs, word_labels_onehot,
                        torch.ones(audio_lens.size(0), device=x.device))
                loss = word_loss + phone_loss
                total_loss += loss.cpu().detach().numpy()
                total_word_loss += word_loss.cpu().detach().numpy()
                total_phone_loss += phone_loss.cpu().detach().numpy()
                total_step += 1.

                for idx in range(audios.size(0)):
                    global_idx = b_idx * B + idx
                    audio_id = os.path.splitext(
                        os.path.split(testset.dataset[global_idx][0])[1])[0]
                    pred_phone_label = sent_phone_logits[
                        idx, :audio_lens[idx]].max(-1)[1]
                    pred_phone_label_list = pred_phone_label.cpu().detach(
                    ).numpy().tolist()
                    pred_phone_names = ','.join(
                        [str(p) for p in pred_phone_label_list])
                    phone_f.write(f'{audio_id} {pred_phone_names}\n')

                    if word_nums[idx] > 0:
                        gold_word_label = word_labels[
                            idx, :word_nums[idx]].cpu().detach().numpy(
                            ).tolist()
                        pred_word_label = word_logits[
                            idx, :word_nums[idx]].max(
                                -1)[1].cpu().detach().numpy().tolist()

                        gold_word_labels.extend(gold_word_label)
                        pred_word_labels.extend(pred_word_label)
                        pred_word_names = preprocessor.to_word_text(
                            pred_word_label)
                        gold_word_names = preprocessor.to_word_text(
                            gold_word_label)

                        for word_idx in range(word_nums[idx]):
                            pred_word_name = pred_word_names[word_idx]
                            gold_word_name = gold_word_names[word_idx]
                            word_readable_f.write(
                                f'Utterance id: {audio_id}\n'
                                f'Gold word label: {gold_word_name}\n'
                                f'Pred word label: {pred_word_name}\n\n')
            phone_f.close()
            word_readable_f.close()
            avg_loss = total_loss / total_step
            avg_word_loss = total_word_loss / total_step
            avg_phone_loss = total_phone_loss / total_step
            # Compute word accuracy and word token F1
            print('[TEST RESULT]')
            word_acc = compute_accuracy(gold_word_labels, pred_word_labels)
            word_prec,\
            word_rec,\
            word_f1, _ = precision_recall_fscore_support(np.asarray(gold_word_labels),
                                                         np.asarray(pred_word_labels),
                                                         average='macro')
            print(
                f'Epoch {self.global_epoch}\tLoss: {avg_loss:.4f}\tWord Loss: {avg_word_loss}\tPhone Loss: {avg_phone_loss}\n'
                f'WER: {1-word_acc:.3f}\tWord Acc.: {word_acc:.3f}\n'
                f'Word Precision: {word_prec:.3f}\tWord Recall: {word_rec:.3f}\tWord F1: {word_f1:.3f}'
            )
            token_f1,\
            token_prec,\
            token_recall = compute_token_f1(phone_file,
                                            gold_phone_file,
                                            self.ckpt_dir.joinpath(f'confusion.{self.global_epoch}.png'))

            if self.history['word_acc'] < word_acc:
                self.history['token_f1'] = token_f1
                self.history['word_acc'] = word_acc
                self.history['loss'] = avg_loss
                self.history['iter'] = self.global_iter
                self.history['epoch'] = self.global_epoch
                self.save_checkpoint(f'best_acc_{self.config.seed}.tar')
            self.set_mode('train')
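
The mask bookkeeping above reduces a (batch, word num, word len, audio len) alignment mask to per-frame and per-word indicators. An illustrative sketch with made-up sizes (max_word_len plays the role of self.max_word_len):

import torch

B, W, L, T = 2, 3, 4, 6                      # batch, word num, word len, audio len
max_word_len = L
word_masks = (torch.rand(B, W, L, T) > 0.8).float()
word_masks_2d = (word_masks.sum(-1).view(-1, max_word_len) > 0).long()   # (B*W, L)
word_masks_1d = (word_masks_2d.sum(-1) > 0).long()                       # (B*W,)
print(word_masks_2d.shape, word_masks_1d.shape)
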