def _train_batch(self, model, batch_items, dataset, step, total_steps):
    """Run one optimiser step of joint AE/ASR training on a batch.

    The batch is split into minibatches of at most ``self.minibatch_size``
    rows. For every minibatch four losses are computed:

    * ASR NLL loss
    * AE NLL loss
    * a KL term between the AE and ASR log-probabilities
    * an L2 term between the AE and ASR embeddings

    Each loss is optionally normalised, weighted by ``self.loss_coeff`` and
    divided by the number of minibatches. The sum is then back-propagated.
    Gradients accumulate across minibatches, and ``self.optimizer.step()``
    is taken once at the end, followed by ``model.zero_grad()``.

    Args:
        model: object whose ``forward_train(..., mode='AE_ASR')`` returns a
            dict with keys ``'logps_asr'``, ``'emb_asr'``, ``'logps_ae'``,
            ``'emb_ae'`` and ``'refs_ae'``.
        batch_items: dict providing ``'srcid'`` and ``'acous_feat'`` (both
            indexed with ``[0]``) and ``'acouslen'``.
        dataset: unused here; kept for interface compatibility.
        step: unused here; kept for interface compatibility.
        total_steps: unused here; kept for interface compatibility.

    Returns:
        dict with keys ``'nll_loss_asr'``, ``'nll_loss_ae'``, ``'kl_loss'``
        and ``'l2_loss'``. Each value is the sum over minibatches of that
        term's ``get_loss()`` after weighting and division by the number of
        minibatches.
    """
    batch_src_ids = batch_items['srcid'][0]
    batch_acous_feats = batch_items['acous_feat'][0]
    batch_acous_lengths = batch_items['acouslen']

    # Ceiling division: the last minibatch may hold fewer rows.
    batch_size = batch_src_ids.size(0)
    n_minibatch = (batch_size + self.minibatch_size - 1) // self.minibatch_size

    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0

    for bidx in range(n_minibatch):
        # Fresh loss accumulators for every minibatch.
        loss_asr = NLLLoss()
        loss_asr.reset()
        loss_ae = NLLLoss()
        loss_ae.reset()
        loss_kl = KLDivLoss()
        loss_kl.reset()
        loss_l2 = MSELoss()
        loss_l2.reset()

        # Slice out this minibatch.
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end].to(device=self.device)
        acous_feats = batch_acous_feats[i_start:i_end]
        acous_lengths = batch_acous_lengths[i_start:i_end]

        # Trim acoustic frames to the minibatch max length, padded towards the
        # next multiple of 8.
        # NOTE(review): when acous_len is already a multiple of 8 this keeps
        # 8 extra frames, and the slice cannot extend past the tensor's
        # existing time dimension, so the result is not guaranteed to be a
        # multiple of 8. Confirm this is what the model expects.
        acous_len = max(acous_lengths)
        acous_len = acous_len + 8 - acous_len % 8
        acous_feats = acous_feats[:, :acous_len].to(device=self.device)

        # Positions that are not PAD.
        non_padding_mask_src = src_ids.ne(PAD)

        # Forward propagation.
        out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                                       acous_lens=acous_lengths, mode='AE_ASR',
                                       use_gpu=self.use_gpu)
        logps_asr = out_dict['logps_asr']
        emb_asr = out_dict['emb_asr']
        logps_ae = out_dict['logps_ae']
        emb_ae = out_dict['emb_ae']
        refs_ae = out_dict['refs_ae']

        # Collapse all leading dims so each row is one output position.
        # Targets skip the first column of src_ids (presumably BOS; confirm).
        flat_logps_asr = logps_asr.reshape(-1, logps_asr.size(-1))
        flat_logps_ae = logps_ae.reshape(-1, logps_ae.size(-1))
        flat_emb_asr = emb_asr.reshape(-1, emb_asr.size(-1))
        flat_emb_ae = emb_ae.reshape(-1, emb_ae.size(-1))
        flat_tgt_asr = src_ids[:, 1:].reshape(-1)
        flat_refs_ae = refs_ae.reshape(-1)

        if not self.eval_with_mask:
            # Normalise by every target position, padding included.
            n_positions = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_asr.eval_batch(flat_logps_asr, flat_tgt_asr)
            loss_asr.norm_term = n_positions
            loss_ae.eval_batch(flat_logps_ae, flat_refs_ae)
            loss_ae.norm_term = n_positions
            loss_kl.eval_batch(flat_logps_ae, flat_logps_asr)
            loss_kl.norm_term = n_positions
            loss_l2.eval_batch(flat_emb_asr, flat_emb_ae)
            loss_l2.norm_term = n_positions
        else:
            # Only non-PAD target positions contribute and are counted.
            flat_mask = non_padding_mask_src[:, 1:].reshape(-1)
            n_tokens = torch.sum(non_padding_mask_src[:, 1:])
            # `1.0 * n_tokens` yields a new tensor each time, so no two loss
            # objects share a norm_term object.
            loss_asr.eval_batch_with_mask(flat_logps_asr, flat_tgt_asr, flat_mask)
            loss_asr.norm_term = 1.0 * n_tokens
            loss_ae.eval_batch_with_mask(flat_logps_ae, flat_refs_ae, flat_mask)
            loss_ae.norm_term = 1.0 * n_tokens
            loss_kl.eval_batch_with_mask(flat_logps_ae, flat_logps_asr, flat_mask)
            loss_kl.norm_term = 1.0 * n_tokens
            loss_l2.eval_batch_with_mask(flat_emb_asr, flat_emb_ae, flat_mask)
            loss_l2.norm_term = 1.0 * n_tokens

        # Backward propagation: accumulate gradients over minibatches.
        if self.normalise_loss:
            loss_asr.normalise()
            loss_ae.normalise()
            loss_kl.normalise()
            loss_l2.normalise()

        # Weight each term, then divide by n_minibatch so the accumulated
        # gradient is a mean over minibatches.
        loss_asr.acc_loss *= self.loss_coeff['nll_asr']
        loss_asr.acc_loss /= n_minibatch
        resloss_asr += loss_asr.get_loss()

        loss_ae.acc_loss *= self.loss_coeff['nll_ae']
        loss_ae.acc_loss /= n_minibatch
        resloss_ae += loss_ae.get_loss()

        loss_kl.acc_loss *= self.loss_coeff['kl_en']
        loss_kl.acc_loss /= n_minibatch
        resloss_kl += loss_kl.get_loss()

        loss_l2.acc_loss *= self.loss_coeff['l2']
        loss_l2.acc_loss /= n_minibatch
        resloss_l2 += loss_l2.get_loss()

        # Fold every term into loss_asr and back-propagate the total once.
        loss_asr.add(loss_ae)
        loss_asr.add(loss_kl)
        loss_asr.add(loss_l2)
        loss_asr.backward()

    # Apply the accumulated gradients, then clear them for the next batch.
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_asr'] = resloss_asr
    losses['nll_loss_ae'] = resloss_ae
    losses['kl_loss'] = resloss_kl
    losses['l2_loss'] = resloss_l2
    return losses
def _train_batch(self, model, batch_items, dataset, step, total_steps):
    """Run one optimiser step of joint AE/MT training on a batch.

    The batch is split into minibatches of at most ``self.minibatch_size``
    rows. For every minibatch two NLL losses are computed:

    * the MT loss (``'logps_mt'`` scored against ``tgt_ids``)
    * the AE loss (``'logps_ae'`` scored against ``src_ids``)

    Each loss is optionally normalised, divided by the number of minibatches
    and weighted by ``self.loss_coeff``. Their sum is then back-propagated.
    Gradients accumulate across minibatches, and ``self.optimizer.step()``
    is taken once at the end, followed by ``model.zero_grad()``.

    Args:
        model: object whose ``forward_train(..., mode='AE_MT')`` returns a
            dict with 3-D tensors under ``'logps_mt'`` and ``'logps_ae'``.
        batch_items: dict providing ``'srcid'`` and ``'tgtid'`` (both
            indexed with ``[0]``).
        dataset: unused here; kept for interface compatibility.
        step: unused here; kept for interface compatibility.
        total_steps: unused here; kept for interface compatibility.

    Returns:
        dict with keys ``'nll_loss_de'`` (weighted MT loss) and
        ``'nll_loss_en'`` (weighted AE loss). Each value is summed over
        minibatches.
    """
    batch_src_ids = batch_items['srcid'][0]
    batch_tgt_ids = batch_items['tgtid'][0]

    # Ceiling division: the last minibatch may hold fewer rows.
    batch_size = batch_src_ids.size(0)
    n_minibatch = (batch_size + self.minibatch_size - 1) // self.minibatch_size

    resloss_de = 0
    resloss_en = 0

    for bidx in range(n_minibatch):
        # Fresh loss accumulators for every minibatch.
        loss_de = NLLLoss()
        loss_de.reset()
        loss_en = NLLLoss()
        loss_en.reset()

        # Slice out this minibatch and move it to the training device.
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end].to(device=self.device)
        tgt_ids = batch_tgt_ids[i_start:i_end].to(device=self.device)

        # Positions that are not PAD.
        non_padding_mask_src = src_ids.ne(PAD)
        non_padding_mask_tgt = tgt_ids.ne(PAD)

        # Forward propagation. The time-axis slices presumably align the
        # outputs with the shifted targets tgt_ids[:, 1:] / src_ids[:, 1:]
        # (TODO confirm against the model's decoding offsets).
        out_dict = model.forward_train(src_ids, tgt=tgt_ids, mode='AE_MT',
                                       use_gpu=self.use_gpu)
        logps_de = out_dict['logps_mt'][:, :-1, :]
        logps_en = out_dict['logps_ae'][:, 1:, :]

        # Collapse (batch, time) so each row is one output position.
        flat_logps_de = logps_de.reshape(-1, logps_de.size(-1))
        flat_logps_en = logps_en.reshape(-1, logps_en.size(-1))
        flat_tgt_de = tgt_ids[:, 1:].reshape(-1)
        flat_tgt_en = src_ids[:, 1:].reshape(-1)

        if not self.eval_with_mask:
            # Normalise by every target position, padding included.
            loss_de.eval_batch(flat_logps_de, flat_tgt_de)
            loss_de.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            loss_en.eval_batch(flat_logps_en, flat_tgt_en)
            loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
        else:
            # Only non-PAD target positions contribute and are counted.
            loss_de.eval_batch_with_mask(
                flat_logps_de, flat_tgt_de,
                non_padding_mask_tgt[:, 1:].reshape(-1))
            loss_de.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
            loss_en.eval_batch_with_mask(
                flat_logps_en, flat_tgt_en,
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

        # Backward propagation: accumulate gradients over minibatches.
        if self.normalise_loss:
            loss_de.normalise()
            loss_en.normalise()

        # Average over minibatches, then apply each term's weight.
        loss_de.acc_loss /= n_minibatch
        loss_de.acc_loss *= self.loss_coeff['loss_mt']
        resloss_de += loss_de.get_loss()

        loss_en.acc_loss /= n_minibatch
        loss_en.acc_loss *= self.loss_coeff['loss_ae']
        resloss_en += loss_en.get_loss()

        # Fold the MT term into loss_en and back-propagate the total once.
        loss_en.add(loss_de)
        loss_en.backward()

    # Apply the accumulated gradients, then clear them for the next batch.
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_de'] = resloss_de
    losses['nll_loss_en'] = resloss_en
    return losses