def compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len, device):

    """ compute KL divergence between ASR and AE output """

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    print('num batches: {}'.format(len(evaliter)))

    path_out = os.path.join(test_path_out, 'kl.stats')

    # init losses
    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0
    h_asr = 0
    h_ae = 0
    resloss_norm = 0

    model.eval()
    with torch.no_grad():
        for idx in range(len(evaliter)):
            # for idx in range(2):
            print(idx + 1, len(evaliter))
            batch_items = next(evaliter)

            # load data
            src_ids = batch_items['srcid'][0]
            src_lengths = batch_items['srclen']
            tgt_ids = batch_items['tgtid'][0]
            tgt_lengths = batch_items['tgtlen']
            acous_feats = batch_items['acous_feat'][0]
            acous_lengths = batch_items['acouslen']

            src_len = max(src_lengths)
            tgt_len = max(tgt_lengths)
            acous_len = max(acous_lengths)
            src_ids = src_ids[:, :src_len].to(device=device)
            tgt_ids = tgt_ids.to(device=device)
            acous_feats = acous_feats.to(device=device)

            # scale the number of minibatches with the target length
            # (fixed from the original expression, whose precedence made
            # n_minibatch always evaluate to 1); guard against a zero
            # minibatch size for very long, small batches
            n_minibatch = int(tgt_len / 100) + int(tgt_len % 100 > 0)
            minibatch_size = max(1, int(src_ids.size(0) / n_minibatch))
            n_minibatch = int(src_ids.size(0) / minibatch_size) + \
                int(src_ids.size(0) % minibatch_size > 0)

            for j in range(n_minibatch):

                st = j * minibatch_size
                ed = min((j + 1) * minibatch_size, src_ids.size(0))
                src_ids_sub = src_ids[st:ed, :]
                tgt_ids_sub = tgt_ids[st:ed, :]
                acous_feats_sub = acous_feats[st:ed, :]
                acous_lengths_sub = acous_lengths[st:ed]
                print('minibatch: ', st, ed, src_ids_sub.size(0))

                # generate logp
                out_dict = model.forward_eval(src=src_ids_sub,
                    acous_feats=acous_feats_sub, acous_lens=acous_lengths_sub,
                    mode='AE_ASR', use_gpu=use_gpu)

                max_len = min(src_ids_sub.size(1), out_dict['preds_asr'].size(1))
                preds_hyp_asr = out_dict['preds_asr'][:, :max_len - 1]
                preds_hyp_ae = out_dict['preds_ae'][:, :max_len - 1]
                emb_hyp_asr = out_dict['emb_asr'][:, :max_len - 1]
                emb_hyp_ae = out_dict['emb_ae'][:, :max_len - 1]
                logps_hyp_asr = out_dict['logps_asr'][:, :max_len - 1]
                logps_hyp_ae = out_dict['logps_ae'][:, :max_len - 1]
                refs_ae = out_dict['refs_ae'][:, :max_len - 1]

                src_ids_sub = src_ids_sub[:, :max_len]
                non_padding_mask_src = src_ids_sub.data.ne(PAD)
                # import pdb; pdb.set_trace()

                # various losses
                loss_asr = NLLLoss()
                loss_asr.reset()
                loss_ae = NLLLoss()
                loss_ae.reset()
                loss_kl = KLDivLoss()
                loss_kl.reset()
                loss_l2 = MSELoss()
                loss_l2.reset()

                loss_asr.eval_batch_with_mask(
                    logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                    src_ids_sub[:, 1:].reshape(-1),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_ae.eval_batch_with_mask(
                    logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                    refs_ae.reshape(-1),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_kl.eval_batch_with_mask(
                    logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                    logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_l2.eval_batch_with_mask(
                    emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                    emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

                loss_asr.normalise()
                loss_ae.normalise()
                loss_kl.normalise()
                loss_l2.normalise()

                resloss_asr += loss_asr.get_loss()
                resloss_ae += loss_ae.get_loss()
                resloss_kl += loss_kl.get_loss()
                resloss_l2 += loss_l2.get_loss()
                resloss_norm += 1

                # compute per-token entropy
                entropy_asr = torch.mean(
                    Categorical(probs=torch.exp(logps_hyp_asr)).entropy())
                entropy_ae = torch.mean(
                    Categorical(probs=torch.exp(logps_hyp_ae)).entropy())
                h_asr += entropy_asr.item()
                h_ae += entropy_ae.item()

    # import pdb; pdb.set_trace()
    fout = open(path_out, 'w')
    fout.write('Various stats (averaged over tokens)')
    fout.write('\n{}\n'.format('-' * 50))
    fout.write('NLL ASR: {:0.2f}\n'.format(1. * resloss_asr / resloss_norm))
    fout.write('NLL AE: {:0.2f}\n'.format(1. * resloss_ae / resloss_norm))
    fout.write('KL between ASR, AE: {:0.2f}\n'.format(1. * resloss_kl / resloss_norm))
    fout.write('L2 between embeddings: {:0.2f}\n'.format(1. * resloss_l2 / resloss_norm))
    fout.write('Entropy ASR: {:0.2f}\n'.format(1. * h_asr / resloss_norm))
    fout.write('Entropy AE: {:0.2f}\n'.format(1. * h_ae / resloss_norm))
    fout.close()
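
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): how the per-token
# entropy in compute_kl is obtained. Given log-probabilities of shape
# [batch, seq_len, vocab], Categorical(probs=exp(logps)).entropy() yields one
# entropy value per token position, which is then averaged. The function name
# and the dummy tensors are hypothetical, shown only to document the intent.
def _entropy_sketch():
    import torch
    from torch.distributions import Categorical

    # dummy log-probabilities: batch of 2, 5 tokens, vocabulary of 10
    logps = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)
    # one entropy value per token, averaged over batch and positions
    per_token_entropy = Categorical(probs=torch.exp(logps)).entropy()
    return torch.mean(per_token_entropy).item()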
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # -- scheduled sampling --
    if not self.scheduled_sampling:
        teacher_forcing_ratio = self.teacher_forcing_ratio
    else:
        progress = 1.0 * step / total_steps
        teacher_forcing_ratio = 1.0 - progress

    # -- LOAD BATCH --
    batch_src_ids = batch_items[0][0]
    batch_src_lengths = batch_items[1]
    batch_acous_feats = batch_items[2][0]
    batch_acous_lengths = batch_items[3]

    # -- CONSTRUCT MINIBATCH --
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    batch_acous_len = int(max(batch_acous_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    las_resloss = 0

    # minibatch
    for bidx in range(n_minibatch):

        # debug
        # import pdb; pdb.set_trace()

        # define loss
        las_loss = NLLLoss()
        las_loss.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        acous_feats = batch_acous_feats[i_start:i_end]
        acous_lengths = batch_acous_lengths[i_start:i_end]

        seq_len = max(src_lengths)
        acous_len = max(acous_lengths)
        acous_len = acous_len + 8 - acous_len % 8
        src_ids = src_ids[:, :seq_len].to(device=self.device)
        acous_feats = acous_feats[:, :acous_len].to(device=self.device)

        # sanity check src
        if step == 1:
            check_src_tensor_print(src_ids, dataset.src_id2word)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)

        # Forward propagation
        decoder_outputs, decoder_hidden, ret_dict = model(
            acous_feats, acous_lens=acous_lengths, tgt=src_ids,
            is_training=True, teacher_forcing_ratio=teacher_forcing_ratio,
            use_gpu=self.use_gpu)

        logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
        las_loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
            src_ids.reshape(-1), non_padding_mask_src.reshape(-1))
        las_loss.norm_term = 1.0 * torch.sum(non_padding_mask_src)

        # import pdb; pdb.set_trace()

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            las_loss.normalise()
        las_loss.acc_loss /= n_minibatch
        las_loss.backward()
        las_resloss += las_loss.get_loss()
        torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()
    losses = {'las_loss': las_resloss}

    return losses
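
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the minibatch count
# used above is a ceiling division of batch_size by minibatch_size, and the
# acoustic length is rounded up to the next multiple of 8 (an exact multiple
# is still bumped up by 8, matching the expression in the trainer), presumably
# so the encoder's downsampling divides the frame count evenly. The helper
# name is hypothetical.
def _minibatch_split_sketch(batch_size, minibatch_size, acous_len):
    # equivalent to ceil(batch_size / minibatch_size)
    n_minibatch = int(batch_size / minibatch_size)
    n_minibatch += int(batch_size % minibatch_size > 0)
    assert n_minibatch == -(-batch_size // minibatch_size)  # ceiling division

    # round the acoustic length up past the next multiple of 8
    padded_acous_len = acous_len + 8 - acous_len % 8
    return n_minibatch, padded_acous_len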
def _evaluate_batches(self, model, dataset):

    model.eval()

    las_match = 0
    las_total = 0
    las_resloss = 0
    las_resloss_norm = 0

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items[0][0]
            batch_src_lengths = batch_items[1]
            batch_acous_feats = batch_items[2][0]
            batch_acous_lengths = batch_items[3]

            # separate into minibatch
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            batch_acous_len = int(max(batch_acous_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            # minibatch
            for bidx in range(n_minibatch):

                las_loss = NLLLoss()
                las_loss.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                acous_feats = batch_acous_feats[i_start:i_end]
                acous_lengths = batch_acous_lengths[i_start:i_end]

                seq_len = max(src_lengths)
                acous_len = max(acous_lengths)
                acous_len = acous_len + 8 - acous_len % 8
                src_ids = src_ids[:, :seq_len].to(device=self.device)
                acous_feats = acous_feats[:, :acous_len].to(device=self.device)

                non_padding_mask_src = src_ids.data.ne(PAD)

                # eval using ref
                decoder_outputs, decoder_hidden, ret_dict = model(
                    acous_feats, acous_lens=acous_lengths,
                    teacher_forcing_ratio=1.0,
                    tgt=src_ids, is_training=False, use_gpu=self.use_gpu)
                # eval under hyp
                # decoder_outputs, decoder_hidden, ret_dict = model(
                #     acous_feats, acous_lens=acous_lengths,
                #     teacher_forcing_ratio=0.0,
                #     tgt=src_ids, is_training=False, use_gpu=self.use_gpu)

                # Evaluation
                logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
                las_loss.eval_batch_with_mask(
                    logps.reshape(-1, logps.size(-1)),
                    src_ids.reshape(-1),
                    non_padding_mask_src.reshape(-1))
                las_loss.norm_term = torch.sum(non_padding_mask_src)
                if self.normalise_loss:
                    las_loss.normalise()
                las_resloss += las_loss.get_loss()
                las_resloss_norm += 1

                # las accuracy
                seqlist = ret_dict['sequence']
                seqres = torch.stack(seqlist, dim=1).to(device=self.device)
                correct = seqres.view(-1).eq(src_ids.reshape(-1))\
                    .masked_select(non_padding_mask_src.reshape(-1)).sum().item()
                las_match += correct
                las_total += non_padding_mask_src.sum().item()

                out_count = self._print_hyp(out_count, src_ids,
                    dataset.src_id2word, seqlist)

    if las_total == 0:
        las_acc = float('nan')
    else:
        las_acc = las_match / las_total
    las_resloss /= (1.0 * las_resloss_norm)

    accs = {'las_acc': las_acc}
    losses = {'las_loss': las_resloss}

    return accs, losses
def _evaluate_batches(self, model, dataset):

    # todo: return BLEU score (use BLEU to determine roll back etc)
    # import pdb; pdb.set_trace()

    model.eval()

    resloss = 0
    resloss_norm = 0

    match = 0
    total = 0

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_tgt_ids = batch_items['tgtid'][0]
            batch_tgt_lengths = batch_items['tgtlen']

            # separate into minibatch
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss = NLLLoss()
                loss.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                tgt_ids = batch_tgt_ids[i_start:i_end]
                tgt_lengths = batch_tgt_lengths[i_start:i_end]

                src_len = max(src_lengths)
                tgt_len = max(tgt_lengths)
                src_ids = src_ids[:, :src_len].to(device=self.device)
                tgt_ids = tgt_ids.to(device=self.device)

                non_padding_mask_tgt = tgt_ids.data.ne(PAD)
                non_padding_mask_src = src_ids.data.ne(PAD)

                # run model
                preds, logps, dec_outputs = model.forward_eval(
                    src_ids, use_gpu=self.use_gpu)

                # evaluation
                if not self.eval_with_mask:
                    loss.eval_batch(
                        logps[:, 1:, :].reshape(-1, logps.size(-1)),
                        tgt_ids[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
                else:
                    loss.eval_batch_with_mask(
                        logps[:, 1:, :].reshape(-1, logps.size(-1)),
                        tgt_ids[:, 1:].reshape(-1),
                        non_padding_mask_tgt[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
                if self.normalise_loss:
                    loss.normalise()

                resloss += loss.get_loss()
                resloss_norm += 1

                seqres = preds[:, 1:]
                correct = seqres.reshape(-1).eq(tgt_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_tgt[:, 1:].reshape(-1)).sum().item()
                match += correct
                total += non_padding_mask_tgt[:, 1:].sum().item()

                out_count = self._print_hyp(out_count, src_ids, tgt_ids,
                    dataset.src_id2word, dataset.tgt_id2word, seqres)

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total
    resloss /= (1.0 * resloss_norm)

    torch.cuda.empty_cache()

    losses = {}
    losses['nll_loss'] = resloss

    return resloss, accuracy, losses
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    """
    Args:
        src_ids         = w1 w2 w3 </s> <pad> <pad> <pad>
        tgt_ids         = <s> w1 w2 w3 </s> <pad> <pad> <pad>
    Others:
        internal input  = <s> w1 w2 w3 </s> <pad> <pad>
        decoder_outputs = w1 w2 w3 </s> <pad> <pad> <pad>
    """

    # load data
    batch_src_ids = batch_items['srcid'][0]
    batch_src_lengths = batch_items['srclen']
    batch_tgt_ids = batch_items['tgtid'][0]
    batch_tgt_lengths = batch_items['tgtlen']

    # separate into minibatch
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    resloss = 0
    for bidx in range(n_minibatch):

        # debug
        # import pdb; pdb.set_trace()

        # define loss
        loss = NLLLoss()
        loss.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        tgt_ids = batch_tgt_ids[i_start:i_end]
        tgt_lengths = batch_tgt_lengths[i_start:i_end]

        src_len = max(src_lengths)
        tgt_len = max(tgt_lengths)
        src_ids = src_ids[:, :src_len].to(device=self.device)
        tgt_ids = tgt_ids.to(device=self.device)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)
        non_padding_mask_tgt = tgt_ids.data.ne(PAD)

        # Forward propagation
        preds, logps, dec_outputs = model.forward_train(
            src_ids, tgt_ids, use_gpu=self.use_gpu)

        # Get loss
        if not self.eval_with_mask:
            loss.eval_batch(logps[:, :-1, :].reshape(-1, logps.size(-1)),
                tgt_ids[:, 1:].reshape(-1))
            loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
        else:
            loss.eval_batch_with_mask(
                logps[:, :-1, :].reshape(-1, logps.size(-1)),
                tgt_ids[:, 1:].reshape(-1),
                non_padding_mask_tgt[:, 1:].reshape(-1))
            loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])

        # import pdb; pdb.set_trace()

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            loss.normalise()
        loss.acc_loss /= n_minibatch
        loss.backward()
        resloss += loss.get_loss()
        torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    return resloss
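
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): what the masked NLL
# computed via NLLLoss.eval_batch_with_mask plus norm_term amounts to, written
# with plain torch ops. This is an assumption about the project-specific loss
# wrapper, not its actual implementation; the function name is hypothetical.
def _masked_nll_sketch(logps, targets, non_padding_mask):
    # logps: [N, vocab] log-probabilities, targets: [N] (long),
    # non_padding_mask: [N] (bool, True on real tokens)
    per_token_nll = -logps.gather(1, targets.unsqueeze(1)).squeeze(1)
    masked_nll = per_token_nll.masked_select(non_padding_mask)
    # normalise by the number of non-padding tokens (the norm_term above)
    return masked_nll.sum() / non_padding_mask.sum()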
def _evaluate_batches(self, model, dataset):

    # import pdb; pdb.set_trace()

    model.eval()

    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0
    resloss_norm = 0

    # accuracy
    match_asr = 0
    total_asr = 0
    match_ae = 0
    total_ae = 0

    # bleu
    hyp_corpus_asr = []
    ref_corpus_asr = []
    hyp_corpus_ae = []
    ref_corpus_ae = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)
            # import pdb; pdb.set_trace()

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_acous_feats = batch_items['acous_feat'][0]
            batch_acous_lengths = batch_items['acouslen']

            # separate into minibatch
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss_asr = NLLLoss()
                loss_asr.reset()
                loss_ae = NLLLoss()
                loss_ae.reset()
                loss_kl = KLDivLoss()
                loss_kl.reset()
                loss_l2 = MSELoss()
                loss_l2.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                acous_feats = batch_acous_feats[i_start:i_end]
                acous_lengths = batch_acous_lengths[i_start:i_end]

                src_len = max(src_lengths)
                acous_len = max(acous_lengths)
                acous_len = acous_len + 8 - acous_len % 8
                src_ids = src_ids.to(device=self.device)
                acous_feats = acous_feats[:, :acous_len].to(device=self.device)

                # debug oom
                # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

                non_padding_mask_src = src_ids.data.ne(PAD)

                # [run-TF] to save time
                # out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                #     acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)
                # [run-FR] to get true stats
                out_dict = model.forward_eval(src=src_ids,
                    acous_feats=acous_feats, acous_lens=acous_lengths,
                    mode='AE_ASR', use_gpu=self.use_gpu)

                preds_asr = out_dict['preds_asr']
                logps_asr = out_dict['logps_asr']
                emb_asr = out_dict['emb_asr']
                preds_ae = out_dict['preds_ae']
                logps_ae = out_dict['logps_ae']
                emb_ae = out_dict['emb_ae']
                refs_ae = out_dict['refs_ae']

                logps_hyp_asr = logps_asr
                preds_hyp_asr = preds_asr
                emb_hyp_asr = emb_asr
                logps_hyp_ae = logps_ae
                preds_hyp_ae = preds_ae
                emb_hyp_ae = emb_ae

                # evaluation
                if not self.eval_with_mask:
                    loss_asr.eval_batch(
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        src_ids[:, 1:].reshape(-1))
                    loss_asr.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_ae.eval_batch(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        refs_ae.reshape(-1))
                    loss_ae.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_kl.eval_batch(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)))
                    loss_kl.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                    loss_l2.eval_batch(
                        emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                        emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)))
                    loss_l2.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                else:
                    loss_asr.eval_batch_with_mask(
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        src_ids[:, 1:].reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_ae.eval_batch_with_mask(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        refs_ae.reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_kl.eval_batch_with_mask(
                        logps_hyp_ae.reshape(-1, logps_hyp_ae.size(-1)),
                        logps_hyp_asr.reshape(-1, logps_hyp_asr.size(-1)),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                    loss_l2.eval_batch_with_mask(
                        emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                        emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

                if self.normalise_loss:
                    loss_asr.normalise()
                    loss_ae.normalise()
                    loss_kl.normalise()
                    loss_l2.normalise()

                resloss_asr += loss_asr.get_loss()
                resloss_ae += loss_ae.get_loss()
                resloss_kl += loss_kl.get_loss()
                resloss_l2 += loss_l2.get_loss()
                resloss_norm += 1

                # ----- debug -----
                # print('{}/{}, {}/{}'.format(bidx, n_minibatch, idx, len(evaliter)))
                # if loss_kl.get_loss() > 10 or (bidx == 4 and idx == 1):
                #     import pdb; pdb.set_trace()
                # -----------------

                # accuracy
                seqres_asr = preds_hyp_asr
                correct_asr = seqres_asr.reshape(-1).eq(src_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_asr += correct_asr
                total_asr += non_padding_mask_src[:, 1:].sum().item()

                seqres_ae = preds_hyp_ae
                correct_ae = seqres_ae.reshape(-1).eq(refs_ae.reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_ae += correct_ae
                total_ae += non_padding_mask_src[:, 1:].sum().item()

                # prepend a dummy first column to refs_ae so it matches the
                # src_ids layout expected by _print / add2corpus
                dummy = torch.zeros(refs_ae.size(0), 1).to(device=self.device).long()
                refs_ae_add = torch.cat((dummy, refs_ae), dim=1)

                # print
                out_count_dummy = self._print(out_count, src_ids,
                    dataset.src_id2word, seqres_asr, tail='-asr')
                out_count = self._print(out_count, refs_ae_add,
                    dataset.src_id2word, seqres_ae, tail='-ae ')

                # accumulate corpus
                hyp_corpus_asr, ref_corpus_asr = add2corpus(seqres_asr, src_ids,
                    dataset.src_id2word, hyp_corpus_asr, ref_corpus_asr, type='word')
                hyp_corpus_ae, ref_corpus_ae = add2corpus(seqres_ae, refs_ae_add,
                    dataset.src_id2word, hyp_corpus_ae, ref_corpus_ae, type='word')

    # import pdb; pdb.set_trace()
    bleu_asr = torchtext.data.metrics.bleu_score(hyp_corpus_asr, ref_corpus_asr)
    bleu_ae = torchtext.data.metrics.bleu_score(hyp_corpus_ae, ref_corpus_ae)
    # torch.cuda.empty_cache()

    if total_asr == 0:
        accuracy_asr = float('nan')
    else:
        accuracy_asr = match_asr / total_asr
    if total_ae == 0:
        accuracy_ae = float('nan')
    else:
        accuracy_ae = match_ae / total_ae

    resloss_asr *= self.loss_coeff['nll_asr']
    resloss_asr /= (1.0 * resloss_norm)
    resloss_ae *= self.loss_coeff['nll_ae']
    resloss_ae /= (1.0 * resloss_norm)
    resloss_kl *= self.loss_coeff['kl_en']
    resloss_kl /= (1.0 * resloss_norm)
    resloss_l2 *= self.loss_coeff['l2']
    resloss_l2 /= (1.0 * resloss_norm)

    losses = {}
    losses['l2_loss'] = resloss_l2
    losses['kl_loss'] = resloss_kl
    losses['nll_loss_asr'] = resloss_asr
    losses['nll_loss_ae'] = resloss_ae

    metrics = {}
    metrics['accuracy_asr'] = accuracy_asr
    metrics['bleu_asr'] = bleu_asr
    metrics['accuracy_ae'] = accuracy_ae
    metrics['bleu_ae'] = bleu_ae

    return losses, metrics
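
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the KL term above
# compares the AE and ASR output distributions given their log-probabilities.
# Written here with torch.nn.functional.kl_div over non-padding positions;
# this is an assumption about what KLDivLoss.eval_batch_with_mask computes,
# not its actual implementation, and the argument order of p and q is generic.
def _masked_kl_sketch(logps_p, logps_q, non_padding_mask):
    import torch.nn.functional as F

    # logps_p, logps_q: [N, vocab] log-probabilities; mask: [N] (bool)
    # KL(p || q) summed over the vocabulary, giving one value per token
    per_token_kl = F.kl_div(logps_q, logps_p, reduction='none',
                            log_target=True).sum(dim=-1)
    masked_kl = per_token_kl.masked_select(non_padding_mask)
    return masked_kl.sum() / non_padding_mask.sum()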
def _evaluate_batches(self, model, dataset):

    # import pdb; pdb.set_trace()

    model.eval()

    resloss = 0
    resloss_norm = 0

    # accuracy
    match = 0
    total = 0

    # bleu
    hyp_corpus = []
    ref_corpus = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_tgt_ids = batch_items['tgtid'][0]
            batch_tgt_lengths = batch_items['tgtlen']

            # separate into minibatch
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss = NLLLoss()
                loss.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                tgt_ids = batch_tgt_ids[i_start:i_end]
                tgt_lengths = batch_tgt_lengths[i_start:i_end]

                src_len = max(src_lengths)
                tgt_len = max(tgt_lengths)
                src_ids = src_ids[:, :src_len].to(device=self.device)
                tgt_ids = tgt_ids.to(device=self.device)

                non_padding_mask_tgt = tgt_ids.data.ne(PAD)
                non_padding_mask_src = src_ids.data.ne(PAD)

                # import pdb; pdb.set_trace()
                if self.eval_mode == 'tf':
                    # [run-TF] to save time
                    preds, logps, dec_outputs = model.forward_train(
                        src_ids, tgt_ids, use_gpu=self.use_gpu)
                    logps_hyp = logps[:, :-1, :]
                    preds_hyp = preds[:, :-1]
                elif self.eval_mode == 'fr':
                    # [run-FR] to get true stats
                    preds, logps, dec_outputs = model.forward_eval(
                        src_ids, use_gpu=self.use_gpu)
                    logps_hyp = logps[:, 1:, :]
                    preds_hyp = preds[:, 1:]

                # evaluation
                if not self.eval_with_mask:
                    loss.eval_batch(
                        logps_hyp.reshape(-1, logps_hyp.size(-1)),
                        tgt_ids[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
                else:
                    loss.eval_batch_with_mask(
                        logps_hyp.reshape(-1, logps_hyp.size(-1)),
                        tgt_ids[:, 1:].reshape(-1),
                        non_padding_mask_tgt[:, 1:].reshape(-1))
                    loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
                if self.normalise_loss:
                    loss.normalise()

                resloss += loss.get_loss()
                resloss_norm += 1

                # accuracy
                seqres = preds_hyp
                correct = seqres.reshape(-1).eq(tgt_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_tgt[:, 1:].reshape(-1)).sum().item()
                match += correct
                total += non_padding_mask_tgt[:, 1:].sum().item()

                # print
                out_count = self._print_hyp(out_count, src_ids, tgt_ids,
                    dataset.src_id2word, dataset.tgt_id2word, seqres)

                # accumulate corpus
                hyp_corpus, ref_corpus = add2corpus(seqres, tgt_ids,
                    dataset.tgt_id2word, hyp_corpus, ref_corpus,
                    type=dataset.use_type)

    # import pdb; pdb.set_trace()
    bleu = torchtext.data.metrics.bleu_score(hyp_corpus, ref_corpus)

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total
    resloss /= (1.0 * resloss_norm)

    losses = {}
    losses['nll_loss'] = resloss
    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['bleu'] = bleu

    return losses, metrics
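
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the corpus layout
# expected by torchtext.data.metrics.bleu_score, which add2corpus is assumed
# to produce: a list of tokenised hypotheses and, for each hypothesis, a list
# of one or more tokenised references. The toy sentences are hypothetical.
def _bleu_sketch():
    from torchtext.data.metrics import bleu_score

    hyp_corpus = [['the', 'cat', 'sat', 'on', 'the', 'mat'],
                  ['a', 'dog', 'barked', 'at', 'the', 'cat']]
    ref_corpus = [[['the', 'cat', 'sat', 'on', 'the', 'mat']],
                  [['the', 'dog', 'barked', 'at', 'the', 'cat']]]
    # corpus-level BLEU in [0, 1]
    return bleu_score(hyp_corpus, ref_corpus)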
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # load data
    batch_src_ids = batch_items['srcid'][0]
    batch_src_lengths = batch_items['srclen']
    batch_acous_feats = batch_items['acous_feat'][0]
    batch_acous_lengths = batch_items['acouslen']

    # separate into minibatch
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0

    for bidx in range(n_minibatch):

        # debug
        # print(bidx, n_minibatch)
        # import pdb; pdb.set_trace()

        # define loss
        loss_asr = NLLLoss()
        loss_asr.reset()
        loss_ae = NLLLoss()
        loss_ae.reset()
        loss_kl = KLDivLoss()
        loss_kl.reset()
        loss_l2 = MSELoss()
        loss_l2.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        acous_feats = batch_acous_feats[i_start:i_end]
        acous_lengths = batch_acous_lengths[i_start:i_end]

        src_len = max(src_lengths)
        acous_len = max(acous_lengths)
        acous_len = acous_len + 8 - acous_len % 8
        src_ids = src_ids.to(device=self.device)
        acous_feats = acous_feats[:, :acous_len].to(device=self.device)

        # debug oom
        # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)

        # Forward propagation
        out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
            acous_lens=acous_lengths, mode='AE_ASR', use_gpu=self.use_gpu)
        logps_asr = out_dict['logps_asr']
        emb_asr = out_dict['emb_asr']
        logps_ae = out_dict['logps_ae']
        emb_ae = out_dict['emb_ae']
        refs_ae = out_dict['refs_ae']

        # import pdb; pdb.set_trace()

        # Get loss
        if not self.eval_with_mask:
            loss_asr.eval_batch(logps_asr.reshape(-1, logps_asr.size(-1)),
                src_ids[:, 1:].reshape(-1))
            loss_asr.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_ae.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                refs_ae.reshape(-1))
            loss_ae.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_kl.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                logps_asr.reshape(-1, logps_asr.size(-1)))
            loss_kl.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_l2.eval_batch(emb_asr.reshape(-1, emb_asr.size(-1)),
                emb_ae.reshape(-1, emb_ae.size(-1)))
            loss_l2.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
        else:
            loss_asr.eval_batch_with_mask(
                logps_asr.reshape(-1, logps_asr.size(-1)),
                src_ids[:, 1:].reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_ae.eval_batch_with_mask(
                logps_ae.reshape(-1, logps_ae.size(-1)),
                refs_ae.reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_kl.eval_batch_with_mask(
                logps_ae.reshape(-1, logps_ae.size(-1)),
                logps_asr.reshape(-1, logps_asr.size(-1)),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_l2.eval_batch_with_mask(
                emb_asr.reshape(-1, emb_asr.size(-1)),
                emb_ae.reshape(-1, emb_ae.size(-1)),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            loss_asr.normalise()
            loss_ae.normalise()
            loss_kl.normalise()
            loss_l2.normalise()

        loss_asr.acc_loss *= self.loss_coeff['nll_asr']
        loss_asr.acc_loss /= n_minibatch
        resloss_asr += loss_asr.get_loss()

        loss_ae.acc_loss *= self.loss_coeff['nll_ae']
        loss_ae.acc_loss /= n_minibatch
        resloss_ae += loss_ae.get_loss()

        loss_kl.acc_loss *= self.loss_coeff['kl_en']
        loss_kl.acc_loss /= n_minibatch
        resloss_kl += loss_kl.get_loss()

        loss_l2.acc_loss *= self.loss_coeff['l2']
        loss_l2.acc_loss /= n_minibatch
        resloss_l2 += loss_l2.get_loss()

        loss_asr.add(loss_ae)
        loss_asr.add(loss_kl)
        loss_asr.add(loss_l2)
        loss_asr.backward()
        # torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_asr'] = resloss_asr
    losses['nll_loss_ae'] = resloss_ae
    losses['kl_loss'] = resloss_kl
    losses['l2_loss'] = resloss_l2

    return losses
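
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the weighting and
# accumulation scheme used above. Each loss is scaled by its coefficient,
# divided by the number of minibatches so the gradients accumulated over the
# whole batch keep the same scale, summed, and backpropagated once per
# minibatch; the optimizer steps only after all minibatches. The function and
# argument names are hypothetical.
def _accumulation_sketch(model, optimizer, minibatch_losses, loss_coeffs):
    # minibatch_losses: list of dicts of scalar loss tensors, one per minibatch
    n_minibatch = len(minibatch_losses)
    for losses in minibatch_losses:
        total = sum(loss_coeffs[name] * loss / n_minibatch
                    for name, loss in losses.items())
        total.backward()          # gradients accumulate across minibatches
    optimizer.step()              # single parameter update for the full batch
    model.zero_grad()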
def _evaluate_batches(self, model, dataset):

    # import pdb; pdb.set_trace()

    model.eval()

    resloss_en = 0
    resloss_norm = 0

    # accuracy
    match_en = 0
    total_en = 0

    # bleu
    hyp_corpus_en = []
    ref_corpus_en = []

    evaliter = iter(dataset.iter_loader)
    out_count = 0

    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)
            # import pdb; pdb.set_trace()

            # load data
            batch_src_ids = batch_items['srcid'][0]
            batch_src_lengths = batch_items['srclen']
            batch_acous_feats = batch_items['acous_feat'][0]
            batch_acous_lengths = batch_items['acouslen']

            # separate into minibatch
            batch_size = batch_src_ids.size(0)
            batch_seq_len = int(max(batch_src_lengths))
            n_minibatch = int(batch_size / self.minibatch_size)
            n_minibatch += int(batch_size % self.minibatch_size > 0)

            for bidx in range(n_minibatch):

                loss_en = NLLLoss()
                loss_en.reset()

                i_start = bidx * self.minibatch_size
                i_end = min(i_start + self.minibatch_size, batch_size)
                src_ids = batch_src_ids[i_start:i_end]
                src_lengths = batch_src_lengths[i_start:i_end]
                acous_feats = batch_acous_feats[i_start:i_end]
                acous_lengths = batch_acous_lengths[i_start:i_end]

                src_len = max(src_lengths)
                acous_len = max(acous_lengths)
                acous_len = acous_len + 8 - acous_len % 8
                src_ids = src_ids.to(device=self.device)
                acous_feats = acous_feats[:, :acous_len].to(device=self.device)

                # debug oom
                # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

                non_padding_mask_src = src_ids.data.ne(PAD)

                # [run-TF] to save time
                # out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                #     acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)
                # [run-FR] to get true stats
                out_dict = model.forward_eval(acous_feats=acous_feats,
                    acous_lens=acous_lengths, mode='ASR', use_gpu=self.use_gpu)
                preds_en = out_dict['preds_asr']
                logps_en = out_dict['logps_asr']

                logps_hyp_en = logps_en
                preds_hyp_en = preds_en

                # evaluation
                if not self.eval_with_mask:
                    loss_en.eval_batch(
                        logps_hyp_en.reshape(-1, logps_hyp_en.size(-1)),
                        src_ids[:, 1:].reshape(-1))
                    loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                else:
                    loss_en.eval_batch_with_mask(
                        logps_hyp_en.reshape(-1, logps_hyp_en.size(-1)),
                        src_ids[:, 1:].reshape(-1),
                        non_padding_mask_src[:, 1:].reshape(-1))
                    loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                if self.normalise_loss:
                    loss_en.normalise()

                resloss_en += loss_en.get_loss()
                resloss_norm += 1

                # accuracy
                seqres_en = preds_hyp_en
                correct_en = seqres_en.reshape(-1).eq(src_ids[:, 1:].reshape(-1))\
                    .masked_select(non_padding_mask_src[:, 1:].reshape(-1)).sum().item()
                match_en += correct_en
                total_en += non_padding_mask_src[:, 1:].sum().item()

                # print
                out_count = self._print(out_count, src_ids,
                    dataset.src_id2word, seqres_en, tail='-asr')

                # accumulate corpus
                hyp_corpus_en, ref_corpus_en = add2corpus(seqres_en, src_ids,
                    dataset.src_id2word, hyp_corpus_en, ref_corpus_en, type='word')

    # import pdb; pdb.set_trace()
    bleu_en = torchtext.data.metrics.bleu_score(hyp_corpus_en, ref_corpus_en)
    # torch.cuda.empty_cache()

    if total_en == 0:
        accuracy_en = float('nan')
    else:
        accuracy_en = match_en / total_en
    resloss_en /= (1.0 * resloss_norm)

    losses = {}
    losses['nll_loss_de'] = 0
    losses['nll_loss_en'] = resloss_en
    metrics = {}
    metrics['accuracy_de'] = 0
    metrics['bleu_de'] = 0
    metrics['accuracy_en'] = accuracy_en
    metrics['bleu_en'] = bleu_en

    return losses, metrics
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # load data
    batch_src_ids = batch_items['srcid'][0]
    batch_src_lengths = batch_items['srclen']
    batch_tgt_ids = batch_items['tgtid'][0]
    batch_tgt_lengths = batch_items['tgtlen']

    # separate into minibatch
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    resloss_de = 0
    resloss_en = 0

    for bidx in range(n_minibatch):

        # debug
        # print(bidx, n_minibatch)
        # import pdb; pdb.set_trace()

        # define loss
        loss_de = NLLLoss()
        loss_de.reset()
        loss_en = NLLLoss()
        loss_en.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        tgt_ids = batch_tgt_ids[i_start:i_end]
        tgt_lengths = batch_tgt_lengths[i_start:i_end]

        src_len = max(src_lengths)
        tgt_len = max(tgt_lengths)
        src_ids = src_ids.to(device=self.device)
        tgt_ids = tgt_ids.to(device=self.device)

        # debug oom
        # acous_feats, acous_lengths, src_ids, tgt_ids = self._debug_oom(acous_len, acous_feats)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)
        non_padding_mask_tgt = tgt_ids.data.ne(PAD)

        # Forward propagation
        out_dict = model.forward_train(src_ids, tgt=tgt_ids, mode='AE_MT',
            use_gpu=self.use_gpu)
        logps_de = out_dict['logps_mt'][:, :-1, :]
        logps_en = out_dict['logps_ae'][:, 1:, :]

        # Get loss
        if not self.eval_with_mask:
            loss_de.eval_batch(logps_de.reshape(-1, logps_de.size(-1)),
                tgt_ids[:, 1:].reshape(-1))
            loss_de.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            loss_en.eval_batch(logps_en.reshape(-1, logps_en.size(-1)),
                src_ids[:, 1:].reshape(-1))
            loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
        else:
            loss_de.eval_batch_with_mask(
                logps_de.reshape(-1, logps_de.size(-1)),
                tgt_ids[:, 1:].reshape(-1),
                non_padding_mask_tgt[:, 1:].reshape(-1))
            loss_de.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
            loss_en.eval_batch_with_mask(
                logps_en.reshape(-1, logps_en.size(-1)),
                src_ids[:, 1:].reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

        # import pdb; pdb.set_trace()

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            loss_de.normalise()
            loss_en.normalise()

        loss_de.acc_loss /= n_minibatch
        loss_de.acc_loss *= self.loss_coeff['loss_mt']
        resloss_de += loss_de.get_loss()

        loss_en.acc_loss /= n_minibatch
        loss_en.acc_loss *= self.loss_coeff['loss_ae']
        resloss_en += loss_en.get_loss()

        loss_en.add(loss_de)
        loss_en.backward()
        # torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_de'] = resloss_de
    losses['nll_loss_en'] = resloss_en

    return losses
def _train_batch(self, src_ids, tgt_ids, model, step, total_steps,
    src_probs=None, src_labs=None):

    """
    Args:
        src_ids              = w1 w2 w3 </s> <pad> <pad> <pad>
        tgt_ids              = <s> w1 w2 w3 </s> <pad> <pad> <pad>
        (optional) src_probs = p1 p2 p3 0 0 ...
    Others:
        internal input       = <s> w1 w2 w3 </s> <pad> <pad>
        decoder_outputs      = w1 w2 w3 </s> <pad> <pad> <pad>
    """

    # define loss
    loss = NLLLoss()

    # scheduled sampling
    if not self.scheduled_sampling:
        teacher_forcing_ratio = self.teacher_forcing_ratio
    else:
        # use self.teacher_forcing_ratio as the starting point
        progress = 1.0 * step / total_steps
        teacher_forcing_ratio = 1.0 - progress

    # get padding mask
    non_padding_mask_src = src_ids.data.ne(PAD)
    non_padding_mask_tgt = tgt_ids.data.ne(PAD)

    # Forward propagation
    decoder_outputs, decoder_hidden, ret_dict = model(
        src_ids, tgt_ids, is_training=True,
        teacher_forcing_ratio=teacher_forcing_ratio,
        att_key_feats=src_probs)

    # Get loss
    loss.reset()
    # import pdb; pdb.set_trace()
    logps = torch.stack(decoder_outputs, dim=1).to(device=device)
    if not self.eval_with_mask:
        loss.eval_batch(logps.reshape(-1, logps.size(-1)),
            tgt_ids[:, 1:].reshape(-1))
        loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
    else:
        loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
            tgt_ids[:, 1:].reshape(-1),
            non_padding_mask_tgt[:, 1:].reshape(-1))
        loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
    loss.normalise()

    # Backward propagation
    model.zero_grad()
    resloss = loss.get_loss()
    att_resloss = 0
    dsfclassify_resloss = 0
    loss.backward()
    self.optimizer.step()

    return resloss, att_resloss, dsfclassify_resloss
def _evaluate_batches(self, model, batches, dataset):

    model.eval()

    loss = NLLLoss()
    loss.reset()

    match = 0
    total = 0
    out_count = 0

    with torch.no_grad():
        for batch in batches:

            src_ids = batch['src_word_ids']
            src_lengths = batch['src_sentence_lengths']
            tgt_ids = batch['tgt_word_ids']
            tgt_lengths = batch['tgt_sentence_lengths']

            src_probs = None
            if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                src_probs = batch['src_ddfd_probs']
                src_probs = _convert_to_tensor(src_probs, self.use_gpu).unsqueeze(2)
            src_labs = None
            if 'src_ddfd_labs' in batch:
                src_labs = batch['src_ddfd_labs']
                src_labs = _convert_to_tensor(src_labs, self.use_gpu).unsqueeze(2)

            src_ids = _convert_to_tensor(src_ids, self.use_gpu)
            tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

            non_padding_mask_tgt = tgt_ids.data.ne(PAD)
            non_padding_mask_src = src_ids.data.ne(PAD)

            decoder_outputs, decoder_hidden, other = model(
                src_ids, tgt_ids, is_training=False, att_key_feats=src_probs)

            # Evaluation
            logps = torch.stack(decoder_outputs, dim=1).to(device=device)
            if not self.eval_with_mask:
                loss.eval_batch(logps.reshape(-1, logps.size(-1)),
                    tgt_ids[:, 1:].reshape(-1))
                loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            else:
                loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
                    tgt_ids[:, 1:].reshape(-1),
                    non_padding_mask_tgt[:, 1:].reshape(-1))
                loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
            loss.normalise()

            seqlist = other['sequence']
            seqres = torch.stack(seqlist, dim=1).to(device=device)
            correct = seqres.view(-1).eq(tgt_ids[:, 1:].reshape(-1))\
                .masked_select(non_padding_mask_tgt[:, 1:].reshape(-1)).sum().item()
            match += correct
            total += non_padding_mask_tgt[:, 1:].sum().item()

            if out_count < 3:
                srcwords = _convert_to_words_batchfirst(src_ids, dataset.tgt_id2word)
                refwords = _convert_to_words_batchfirst(tgt_ids[:, 1:], dataset.tgt_id2word)
                seqwords = _convert_to_words(seqlist, dataset.tgt_id2word)
                outsrc = 'SRC: {}\n'.format(' '.join(srcwords[0])).encode('utf-8')
                outref = 'REF: {}\n'.format(' '.join(refwords[0])).encode('utf-8')
                outline = 'GEN: {}\n'.format(' '.join(seqwords[0])).encode('utf-8')
                sys.stdout.buffer.write(outsrc)
                sys.stdout.buffer.write(outref)
                sys.stdout.buffer.write(outline)
                out_count += 1

    att_resloss = 0
    attcls_resloss = 0
    resloss = loss.get_loss()

    if total == 0:
        accuracy = float('nan')
    else:
        accuracy = match / total

    torch.cuda.empty_cache()

    losses = {}
    losses['att_loss'] = att_resloss
    losses['attcls_loss'] = attcls_resloss

    return resloss, accuracy, losses