def _evaluate_batches(self, model, dataset):
    """Free-running evaluation of the joint AE + ASR model on the given dataset.

    Returns a dict of coefficient-weighted average losses (NLL for the ASR and AE
    branches, KL between their output distributions, L2 between their embedding
    sequences) and a dict of metrics (token accuracy and corpus BLEU per branch).
    """

    model.eval()

    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0
    resloss_norm = 0

    # accuracy
    match_asr = 0
    total_asr = 0
    match_ae = 0
    total_ae = 0

    # bleu
    hyp_corpus_asr = []
    ref_corpus_asr = []
    hyp_corpus_ae = []
    ref_corpus_ae = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):
            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_acous_feats = batch_items['acous_feat'][0]
            batch_acous_lengths = batch_items['acouslen']

            # split the loaded batch into minibatches
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss_asr = NLLLoss()
                loss_asr.reset()
                loss_ae = NLLLoss()
                loss_ae.reset()
                loss_kl = KLDivLoss()
                loss_kl.reset()
                loss_l2 = MSELoss()
                loss_l2.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                acous_feats = batch_acous_feats[i_start:i_end]
                acous_lengths = batch_acous_lengths[i_start:i_end]

                src_len = max(src_lengths)
                acous_len = max(acous_lengths)
                # pad the acoustic length up to a multiple of 8
                acous_len = acous_len + 8 - acous_len % 8
                src_ids = src_ids.to(device=self.device)
                acous_feats = acous_feats[:, :acous_len].to(device=self.device)

                # debug oom
                # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

                non_padding_mask_src = src_ids.data.ne(PAD)

                # [run-TF] to save time
                # out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                #     acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)

                # [run-FR] to get true stats
                out_dict = model.forward_eval(src=src_ids, acous_feats=acous_feats,
                    acous_lens=acous_lengths, mode='AE_ASR', use_gpu=self.use_gpu)
                preds_asr = out_dict['preds_asr']
                logps_asr = out_dict['logps_asr']
                emb_asr = out_dict['emb_asr']
                preds_ae = out_dict['preds_ae']
                logps_ae = out_dict['logps_ae']
                emb_ae = out_dict['emb_ae']
                refs_ae = out_dict['refs_ae']

                logps_hyp_asr = logps_asr
                preds_hyp_asr = preds_asr
                emb_hyp_asr = emb_asr
                logps_hyp_ae = logps_ae
                preds_hyp_ae = preds_ae
                emb_hyp_ae = emb_ae

                # evaluation
                if not self.eval_with_mask:
                    loss_asr.eval_batch(
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        src_ids[:, 1:].reshape(-1))
                    loss_asr.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_ae.eval_batch(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        refs_ae.reshape(-1))
                    loss_ae.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_kl.eval_batch(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)))
                    loss_kl.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_l2.eval_batch(
                        emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                        emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)))
                    loss_l2.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                else:
                    loss_asr.eval_batch_with_mask(
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        src_ids[:, 1:].reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_ae.eval_batch_with_mask(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        refs_ae.reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_kl.eval_batch_with_mask(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_l2.eval_batch_with_mask(
                        emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                        emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

                if self.normalise_loss:
                    loss_asr.normalise()
                    loss_ae.normalise()
                    loss_kl.normalise()
                    loss_l2.normalise()

                resloss_asr += loss_asr.get_loss()
                resloss_ae += loss_ae.get_loss()
                resloss_kl += loss_kl.get_loss()
                resloss_l2 += loss_l2.get_loss()
                resloss_norm += 1

                # accuracy
                seqres_asr = preds_hyp_asr
                correct_asr = seqres_asr.reshape(-1).eq(src_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_asr += correct_asr
                total_asr += non_padding_mask_src[:, 1:].sum().item()

                seqres_ae = preds_hyp_ae
                correct_ae = seqres_ae.reshape(-1).eq(refs_ae.reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_ae += correct_ae
                total_ae += non_padding_mask_src[:, 1:].sum().item()

                # prepend a dummy column to refs_ae so it aligns with src_ids
                dummy = torch.zeros(refs_ae.size(0), 1).to(device=self.device).long()
                refs_ae_add = torch.cat((dummy, refs_ae), dim=1)

                # print
                out_count_dummy = self._print(
                    out_count, src_ids, dataset.src_id2word, seqres_asr, tail='-asr')
                out_count = self._print(
                    out_count, refs_ae_add, dataset.src_id2word, seqres_ae, tail='-ae ')

                # accumulate corpus
                hyp_corpus_asr, ref_corpus_asr = add2corpus(
                    seqres_asr, src_ids, dataset.src_id2word,
                    hyp_corpus_asr, ref_corpus_asr, type='word')
                hyp_corpus_ae, ref_corpus_ae = add2corpus(
                    seqres_ae, refs_ae_add, dataset.src_id2word,
                    hyp_corpus_ae, ref_corpus_ae, type='word')

    bleu_asr = torchtext.data.metrics.bleu_score(hyp_corpus_asr, ref_corpus_asr)
    bleu_ae = torchtext.data.metrics.bleu_score(hyp_corpus_ae, ref_corpus_ae)
    # torch.cuda.empty_cache()

    if total_asr == 0:
        accuracy_asr = float('nan')
    else:
        accuracy_asr = match_asr / total_asr
    if total_ae == 0:
        accuracy_ae = float('nan')
    else:
        accuracy_ae = match_ae / total_ae

    resloss_asr *= self.loss_coeff['nll_asr']
    resloss_asr /= (1.0 * resloss_norm)
    resloss_ae *= self.loss_coeff['nll_ae']
    resloss_ae /= (1.0 * resloss_norm)
    resloss_kl *= self.loss_coeff['kl_en']
    resloss_kl /= (1.0 * resloss_norm)
    resloss_l2 *= self.loss_coeff['l2']
    resloss_l2 /= (1.0 * resloss_norm)

    losses = {}
    losses['l2_loss'] = resloss_l2
    losses['kl_loss'] = resloss_kl
    losses['nll_loss_asr'] = resloss_asr
    losses['nll_loss_ae'] = resloss_ae

    metrics = {}
    metrics['accuracy_asr'] = accuracy_asr
    metrics['bleu_asr'] = bleu_asr
    metrics['accuracy_ae'] = accuracy_ae
    metrics['bleu_ae'] = bleu_ae

    return losses, metrics
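
# Hedged illustration (not from the original code): the masked token-accuracy
# pattern used above, reduced to plain torch ops on toy tensors. PAD is assumed
# to be id 0 here; the repo's PAD constant may differ.
def _masked_accuracy_sketch():
    import torch
    preds = torch.tensor([[3, 5, 7], [2, 2, 6]])   # hypothesis ids, [B, T]
    refs = torch.tensor([[3, 5, 9], [2, 4, 0]])    # reference ids, 0 = PAD
    mask = refs.ne(0)                              # non-padding positions
    correct = preds.reshape(-1).eq(refs.reshape(-1))\
        .masked_select(mask.reshape(-1)).sum().item()
    total = mask.sum().item()
    return correct / total                         # 3 correct of 5 scored tokens = 0.6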

def _evaluate_batches(self, model, dataset):
    """Evaluate the seq2seq translation model on the given dataset.

    Runs teacher-forced ('tf') or free-running ('fr') decoding depending on
    self.eval_mode, and returns the averaged NLL loss plus token accuracy and BLEU.
    """

    model.eval()

    resloss = 0
    resloss_norm = 0

    # accuracy
    match = 0
    total = 0

    # bleu
    hyp_corpus = []
    ref_corpus = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):
            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_tgt_ids = batch_items['tgtid'][0]
            batch_tgt_lengths = batch_items['tgtlen']

            # split the loaded batch into minibatches
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss = NLLLoss()
                loss.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                tgt_ids = batch_tgt_ids[i_start:i_end]
                tgt_lengths = batch_tgt_lengths[i_start:i_end]

                src_len = max(src_lengths)
                tgt_len = max(tgt_lengths)
                src_ids = src_ids[:, :src_len].to(device=self.device)
                tgt_ids = tgt_ids.to(device=self.device)

                non_padding_mask_tgt = tgt_ids.data.ne(PAD)
                non_padding_mask_src = src_ids.data.ne(PAD)

                if self.eval_mode == 'tf':
                    # [run-TF] to save time
                    preds, logps, dec_outputs = model.forward_train(
                        src_ids, tgt_ids, use_gpu=self.use_gpu)
                    logps_hyp = logps[:, :-1, :]
                    preds_hyp = preds[:, :-1]
                elif self.eval_mode == 'fr':
                    # [run-FR] to get true stats
                    preds, logps, dec_outputs = model.forward_eval(
                        src_ids, use_gpu=self.use_gpu)
                    logps_hyp = logps[:, 1:, :]
                    preds_hyp = preds[:, 1:]
                else:
                    raise ValueError('Unsupported eval_mode: {}'.format(self.eval_mode))

                # evaluation
                if not self.eval_with_mask:
                    loss.eval_batch(
                        logps_hyp.reshape(-1, logps_hyp.size(-1)),
                        tgt_ids[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
                else:
                    loss.eval_batch_with_mask(
                        logps_hyp.reshape(-1, logps_hyp.size(-1)),
                        tgt_ids[:, 1:].reshape(-1),
                        non_padding_mask_tgt[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])

                if self.normalise_loss:
                    loss.normalise()

                resloss += loss.get_loss()
                resloss_norm += 1

                # accuracy
                seqres = preds_hyp
                correct = seqres.reshape(-1).eq(tgt_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_tgt[:, 1:].reshape(-1)).sum().item()
                match += correct
                total += non_padding_mask_tgt[:, 1:].sum().item()

                # print
                out_count = self._print_hyp(out_count, src_ids, tgt_ids,
                    dataset.src_id2word, dataset.tgt_id2word, seqres)

                # accumulate corpus
                hyp_corpus, ref_corpus = add2corpus(seqres, tgt_ids,
                    dataset.tgt_id2word, hyp_corpus, ref_corpus, type=dataset.use_type)

    bleu = torchtext.data.metrics.bleu_score(hyp_corpus, ref_corpus)

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total

    resloss /= (1.0 * resloss_norm)

    losses = {}
    losses['nll_loss'] = resloss

    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['bleu'] = bleu

    return losses, metrics
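
# Hedged illustration (not from the original code): what the masked NLL above
# amounts to with plain torch ops. The repo's NLLLoss wrapper (eval_batch_with_mask,
# norm_term, normalise) is assumed to reduce to this sum-over-non-pad divided by
# count-of-non-pad-tokens.
def _masked_nll_sketch():
    import torch
    import torch.nn.functional as F
    torch.manual_seed(0)
    logps = torch.log_softmax(torch.randn(2, 4, 10), dim=-1)   # [B, T, V] log-probs
    tgt = torch.tensor([[4, 2, 7, 0], [1, 3, 0, 0]])           # target ids, 0 = PAD
    mask = tgt.ne(0)                                           # non-padding positions
    nll = F.nll_loss(logps.reshape(-1, 10), tgt.reshape(-1), reduction='none')
    return (nll.masked_select(mask.reshape(-1)).sum() / mask.sum()).item()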

def _evaluate_batches(self, model, dataset):
    """Free-running evaluation of the ASR branch on the given dataset.

    Only the English (ASR) side is scored; the De-side entries in the returned
    dicts are set to zero placeholders.
    """

    model.eval()

    resloss_en = 0
    resloss_norm = 0

    # accuracy
    match_en = 0
    total_en = 0

    # bleu
    hyp_corpus_en = []
    ref_corpus_en = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):
            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_acous_feats = batch_items['acous_feat'][0]
            batch_acous_lengths = batch_items['acouslen']

            # split the loaded batch into minibatches
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss_en = NLLLoss()
                loss_en.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                acous_feats = batch_acous_feats[i_start:i_end]
                acous_lengths = batch_acous_lengths[i_start:i_end]

                src_len = max(src_lengths)
                acous_len = max(acous_lengths)
                # pad the acoustic length up to a multiple of 8
                acous_len = acous_len + 8 - acous_len % 8
                src_ids = src_ids.to(device=self.device)
                acous_feats = acous_feats[:, :acous_len].to(device=self.device)

                # debug oom
                # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

                non_padding_mask_src = src_ids.data.ne(PAD)

                # [run-TF] to save time
                # out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                #     acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)

                # [run-FR] to get true stats
                out_dict = model.forward_eval(acous_feats=acous_feats,
                    acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)
                preds_en = out_dict['preds_asr']
                logps_en = out_dict['logps_asr']
                logps_hyp_en = logps_en
                preds_hyp_en = preds_en

                # evaluation
                if not self.eval_with_mask:
                    loss_en.eval_batch(
                        logps_hyp_en.reshape(-1, logps_hyp_en.size(-1)),
                        src_ids[:, 1:].reshape(-1))
                    loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                else:
                    loss_en.eval_batch_with_mask(
                        logps_hyp_en.reshape(-1, logps_hyp_en.size(-1)),
                        src_ids[:, 1:].reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

                if self.normalise_loss:
                    loss_en.normalise()

                resloss_en += loss_en.get_loss()
                resloss_norm += 1

                # accuracy
                seqres_en = preds_hyp_en
                correct_en = seqres_en.reshape(-1).eq(src_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_en += correct_en
                total_en += non_padding_mask_src[:, 1:].sum().item()

                # print
                out_count = self._print(out_count, src_ids,
                    dataset.src_id2word, seqres_en, tail='-asr')

                # accumulate corpus
                hyp_corpus_en, ref_corpus_en = add2corpus(
                    seqres_en, src_ids, dataset.src_id2word,
                    hyp_corpus_en, ref_corpus_en, type='word')

    bleu_en = torchtext.data.metrics.bleu_score(hyp_corpus_en, ref_corpus_en)
    # torch.cuda.empty_cache()

    if total_en == 0:
        accuracy_en = float('nan')
    else:
        accuracy_en = match_en / total_en

    resloss_en /= (1.0 * resloss_norm)

    losses = {}
    losses['nll_loss_de'] = 0
    losses['nll_loss_en'] = resloss_en

    metrics = {}
    metrics['accuracy_de'] = 0
    metrics['bleu_de'] = 0
    metrics['accuracy_en'] = accuracy_en
    metrics['bleu_en'] = bleu_en

    return losses, metrics
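
# Hedged illustration (not from the original code): the corpus layout expected by
# torchtext.data.metrics.bleu_score as called above. Each hypothesis is a token list
# and each reference entry is a list of one or more tokenised references; add2corpus
# is assumed to build exactly this shape. The token strings below are invented.
def _bleu_corpus_sketch():
    from torchtext.data.metrics import bleu_score
    hyp_corpus = [['the', 'cat', 'sat', 'on', 'the', 'mat'],
                  ['hello', 'world', 'again', 'today']]
    ref_corpus = [[['the', 'cat', 'sat', 'on', 'the', 'mat']],
                  [['hello', 'world', 'again', 'today']]]
    return bleu_score(hyp_corpus, ref_corpus)      # 1.0 for an exact match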