def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # -- scheduled sampling --
    if not self.scheduled_sampling:
        teacher_forcing_ratio = self.teacher_forcing_ratio
    else:
        # linearly anneal the teacher forcing ratio towards 0 over training
        progress = 1.0 * step / total_steps
        teacher_forcing_ratio = 1.0 - progress

    # -- LOAD BATCH --
    batch_src_ids = batch_items[0][0]
    batch_src_lengths = batch_items[1]
    batch_acous_feats = batch_items[2][0]
    batch_acous_lengths = batch_items[3]

    # -- CONSTRUCT MINIBATCH --
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    batch_acous_len = int(max(batch_acous_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    las_resloss = 0

    # minibatch loop
    for bidx in range(n_minibatch):

        # define loss
        las_loss = NLLLoss()
        las_loss.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        acous_feats = batch_acous_feats[i_start:i_end]
        acous_lengths = batch_acous_lengths[i_start:i_end]

        seq_len = max(src_lengths)
        acous_len = max(acous_lengths)
        # pad the acoustic length up to the next multiple of 8
        acous_len = acous_len + 8 - acous_len % 8
        src_ids = src_ids[:, :seq_len].to(device=self.device)
        acous_feats = acous_feats[:, :acous_len].to(device=self.device)

        # sanity check src
        if step == 1:
            check_src_tensor_print(src_ids, dataset.src_id2word)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)

        # Forward propagation
        decoder_outputs, decoder_hidden, ret_dict = model(
            acous_feats, acous_lens=acous_lengths, tgt=src_ids,
            is_training=True, teacher_forcing_ratio=teacher_forcing_ratio,
            use_gpu=self.use_gpu)

        logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
        las_loss.eval_batch_with_mask(
            logps.reshape(-1, logps.size(-1)),
            src_ids.reshape(-1),
            non_padding_mask_src.reshape(-1))
        las_loss.norm_term = 1.0 * torch.sum(non_padding_mask_src)

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            las_loss.normalise()
        # scale by 1/n_minibatch so the accumulated gradient matches
        # a single step on the full batch
        las_loss.acc_loss /= n_minibatch
        las_loss.backward()
        las_resloss += las_loss.get_loss()
        torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    losses = {'las_loss': las_resloss}

    return losses
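# A minimal, self-contained sketch (not the trainer's own code) of the two
# scaling tricks used above: linear decay of the teacher forcing ratio with
# training progress, and division of each minibatch loss by n_minibatch so the
# accumulated gradient matches a single optimizer step on the full batch.
# The helper names below are illustrative, not part of the repo.
def linear_teacher_forcing_ratio(step, total_steps):
    # anneal from 1.0 at step 0 towards 0.0 at total_steps
    progress = min(1.0, 1.0 * step / max(1, total_steps))
    return 1.0 - progress

def accumulate_and_step(model, criterion, optimizer, minibatches):
    # gradient accumulation: one optimizer step per full batch,
    # with each minibatch loss scaled by 1/n_minibatch
    n = len(minibatches)
    model.zero_grad()
    total = 0.0
    for inputs, targets in minibatches:
        loss = criterion(model(inputs), targets) / n
        loss.backward()
        total += loss.item()
    optimizer.step()
    return total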
def _train_batch(self, model, batch_items, dataset, step, total_steps): """ Args: src_ids = w1 w2 w3 </s> <pad> <pad> <pad> tgt_ids = <s> w1 w2 w3 </s> <pad> <pad> <pad> Others: internal input = <s> w1 w2 w3 </s> <pad> <pad> decoder_outputs = w1 w2 w3 </s> <pad> <pad> <pad> """ # load data batch_src_ids = batch_items['srcid'][0] batch_src_lengths = batch_items['srclen'] batch_tgt_ids = batch_items['tgtid'][0] batch_tgt_lengths = batch_items['tgtlen'] # separate into minibatch batch_size = batch_src_ids.size(0) batch_seq_len = int(max(batch_src_lengths)) n_minibatch = int(batch_size / self.minibatch_size) n_minibatch += int(batch_size % self.minibatch_size > 0) resloss = 0 for bidx in range(n_minibatch): # debug # import pdb; pdb.set_trace() # define loss loss = NLLLoss() loss.reset() # load data i_start = bidx * self.minibatch_size i_end = min(i_start + self.minibatch_size, batch_size) src_ids = batch_src_ids[i_start:i_end] src_lengths = batch_src_lengths[i_start:i_end] tgt_ids = batch_tgt_ids[i_start:i_end] tgt_lengths = batch_tgt_lengths[i_start:i_end] src_len = max(src_lengths) tgt_len = max(tgt_lengths) src_ids = src_ids[:, :src_len].to(device=self.device) tgt_ids = tgt_ids.to(device=self.device) # get padding mask non_padding_mask_src = src_ids.data.ne(PAD) non_padding_mask_tgt = tgt_ids.data.ne(PAD) # Forward propagation preds, logps, dec_outputs = model.forward_train( src_ids, tgt_ids, use_gpu=self.use_gpu) # Get loss if not self.eval_with_mask: loss.eval_batch(logps[:, :-1, :].reshape(-1, logps.size(-1)), tgt_ids[:, 1:].reshape(-1)) loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1) else: loss.eval_batch_with_mask( logps[:, :-1, :].reshape(-1, logps.size(-1)), tgt_ids[:, 1:].reshape(-1), non_padding_mask_tgt[:, 1:].reshape(-1)) loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:]) # import pdb; pdb.set_trace() # Backward propagation: accumulate gradient if self.normalise_loss: loss.normalise() loss.acc_loss /= n_minibatch loss.backward() resloss += loss.get_loss() torch.cuda.empty_cache() # update weights self.optimizer.step() model.zero_grad() return resloss
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # load data
    batch_src_ids = batch_items['srcid'][0]
    batch_src_lengths = batch_items['srclen']
    batch_tgt_ids = batch_items['tgtid'][0]
    batch_tgt_lengths = batch_items['tgtlen']

    # separate into minibatches
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    resloss_de = 0
    resloss_en = 0
    for bidx in range(n_minibatch):

        # define losses
        loss_de = NLLLoss()
        loss_de.reset()
        loss_en = NLLLoss()
        loss_en.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        tgt_ids = batch_tgt_ids[i_start:i_end]
        tgt_lengths = batch_tgt_lengths[i_start:i_end]

        src_len = max(src_lengths)
        tgt_len = max(tgt_lengths)
        src_ids = src_ids.to(device=self.device)
        tgt_ids = tgt_ids.to(device=self.device)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)
        non_padding_mask_tgt = tgt_ids.data.ne(PAD)

        # Forward propagation
        out_dict = model.forward_train(src_ids, tgt=tgt_ids,
                                       mode='AE_MT', use_gpu=self.use_gpu)
        logps_de = out_dict['logps_mt'][:, :-1, :]
        logps_en = out_dict['logps_ae'][:, 1:, :]

        # Get loss
        if not self.eval_with_mask:
            loss_de.eval_batch(logps_de.reshape(-1, logps_de.size(-1)),
                               tgt_ids[:, 1:].reshape(-1))
            loss_de.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            loss_en.eval_batch(logps_en.reshape(-1, logps_en.size(-1)),
                               src_ids[:, 1:].reshape(-1))
            loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
        else:
            loss_de.eval_batch_with_mask(
                logps_de.reshape(-1, logps_de.size(-1)),
                tgt_ids[:, 1:].reshape(-1),
                non_padding_mask_tgt[:, 1:].reshape(-1))
            loss_de.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
            loss_en.eval_batch_with_mask(
                logps_en.reshape(-1, logps_en.size(-1)),
                src_ids[:, 1:].reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            loss_de.normalise()
            loss_en.normalise()

        loss_de.acc_loss /= n_minibatch
        loss_de.acc_loss *= self.loss_coeff['loss_mt']
        resloss_de += loss_de.get_loss()

        loss_en.acc_loss /= n_minibatch
        loss_en.acc_loss *= self.loss_coeff['loss_ae']
        resloss_en += loss_en.get_loss()

        loss_en.add(loss_de)
        loss_en.backward()
        # torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_de'] = resloss_de
    losses['nll_loss_en'] = resloss_en

    return losses
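# A minimal sketch of the joint objective in the AE_MT branch above: two
# masked NLL terms, each weighted by its coefficient and scaled for gradient
# accumulation, combined so that a single backward() covers both decoders.
# The repo does this through NLLLoss.add(); the function below only assumes
# the two terms are ordinary scalar tensors and its name is illustrative.
def joint_ae_mt_loss(nll_mt, nll_ae, coeff_mt, coeff_ae, n_minibatch):
    # equivalent to scaling each term separately and summing before backward()
    return (coeff_mt * nll_mt + coeff_ae * nll_ae) / n_minibatch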
def _train_batch(self, model, batch_items, dataset, step, total_steps):

    # load data
    batch_src_ids = batch_items['srcid'][0]
    batch_src_lengths = batch_items['srclen']
    batch_acous_feats = batch_items['acous_feat'][0]
    batch_acous_lengths = batch_items['acouslen']

    # separate into minibatches
    batch_size = batch_src_ids.size(0)
    batch_seq_len = int(max(batch_src_lengths))
    n_minibatch = int(batch_size / self.minibatch_size)
    n_minibatch += int(batch_size % self.minibatch_size > 0)

    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0
    for bidx in range(n_minibatch):

        # define losses
        loss_asr = NLLLoss()
        loss_asr.reset()
        loss_ae = NLLLoss()
        loss_ae.reset()
        loss_kl = KLDivLoss()
        loss_kl.reset()
        loss_l2 = MSELoss()
        loss_l2.reset()

        # load data
        i_start = bidx * self.minibatch_size
        i_end = min(i_start + self.minibatch_size, batch_size)
        src_ids = batch_src_ids[i_start:i_end]
        src_lengths = batch_src_lengths[i_start:i_end]
        acous_feats = batch_acous_feats[i_start:i_end]
        acous_lengths = batch_acous_lengths[i_start:i_end]

        src_len = max(src_lengths)
        acous_len = max(acous_lengths)
        # pad the acoustic length up to the next multiple of 8
        acous_len = acous_len + 8 - acous_len % 8
        src_ids = src_ids.to(device=self.device)
        acous_feats = acous_feats[:, :acous_len].to(device=self.device)

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)

        # Forward propagation
        out_dict = model.forward_train(src_ids, acous_feats=acous_feats,
                                       acous_lens=acous_lengths,
                                       mode='AE_ASR', use_gpu=self.use_gpu)
        logps_asr = out_dict['logps_asr']
        emb_asr = out_dict['emb_asr']
        logps_ae = out_dict['logps_ae']
        emb_ae = out_dict['emb_ae']
        refs_ae = out_dict['refs_ae']

        # Get loss
        if not self.eval_with_mask:
            loss_asr.eval_batch(logps_asr.reshape(-1, logps_asr.size(-1)),
                                src_ids[:, 1:].reshape(-1))
            loss_asr.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_ae.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                               refs_ae.reshape(-1))
            loss_ae.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_kl.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                               logps_asr.reshape(-1, logps_asr.size(-1)))
            loss_kl.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            loss_l2.eval_batch(emb_asr.reshape(-1, emb_asr.size(-1)),
                               emb_ae.reshape(-1, emb_ae.size(-1)))
            loss_l2.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
        else:
            loss_asr.eval_batch_with_mask(
                logps_asr.reshape(-1, logps_asr.size(-1)),
                src_ids[:, 1:].reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_ae.eval_batch_with_mask(
                logps_ae.reshape(-1, logps_ae.size(-1)),
                refs_ae.reshape(-1),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_kl.eval_batch_with_mask(
                logps_ae.reshape(-1, logps_ae.size(-1)),
                logps_asr.reshape(-1, logps_asr.size(-1)),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
            loss_l2.eval_batch_with_mask(
                emb_asr.reshape(-1, emb_asr.size(-1)),
                emb_ae.reshape(-1, emb_ae.size(-1)),
                non_padding_mask_src[:, 1:].reshape(-1))
            loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

        # Backward propagation: accumulate gradient
        if self.normalise_loss:
            loss_asr.normalise()
            loss_ae.normalise()
            loss_kl.normalise()
            loss_l2.normalise()

        # weight each term, scale for gradient accumulation, and record it
        loss_asr.acc_loss *= self.loss_coeff['nll_asr']
        loss_asr.acc_loss /= n_minibatch
        resloss_asr += loss_asr.get_loss()

        loss_ae.acc_loss *= self.loss_coeff['nll_ae']
        loss_ae.acc_loss /= n_minibatch
        resloss_ae += loss_ae.get_loss()

        loss_kl.acc_loss *= self.loss_coeff['kl_en']
        loss_kl.acc_loss /= n_minibatch
        resloss_kl += loss_kl.get_loss()

        loss_l2.acc_loss *= self.loss_coeff['l2']
        loss_l2.acc_loss /= n_minibatch
        resloss_l2 += loss_l2.get_loss()

        # combine all terms and backpropagate once
        loss_asr.add(loss_ae)
        loss_asr.add(loss_kl)
        loss_asr.add(loss_l2)
        loss_asr.backward()
        # torch.cuda.empty_cache()

    # update weights
    self.optimizer.step()
    model.zero_grad()

    losses = {}
    losses['nll_loss_asr'] = resloss_asr
    losses['nll_loss_ae'] = resloss_ae
    losses['kl_loss'] = resloss_kl
    losses['l2_loss'] = resloss_l2

    return losses
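# A minimal sketch of one plausible reading of the KL regulariser above,
# written with plain PyTorch. The direction and reduction used by the repo's
# KLDivLoss wrapper are assumptions here: this computes KL(p_asr || p_ae) per
# token from the two sets of log-probabilities, masks out padding, and
# normalises by the number of real tokens.
import torch.nn.functional as F

def masked_kl(logps_ae, logps_asr, non_padding_mask):
    # logps_*: (batch, len, vocab) log-probabilities from the two decoders
    # non_padding_mask: (batch, len) booleans, True on real tokens
    per_elem = F.kl_div(logps_ae, logps_asr, reduction='none', log_target=True)
    per_token = per_elem.sum(dim=-1)                  # sum over the vocabulary
    per_token = per_token * non_padding_mask.float()  # drop padded positions
    return per_token.sum() / non_padding_mask.float().sum()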
def _train_batch(self, src_ids, tgt_ids, model, step, total_steps,
                 src_probs=None, src_labs=None):
    """
    Args:
        src_ids = w1 w2 w3 </s> <pad> <pad> <pad>
        tgt_ids = <s> w1 w2 w3 </s> <pad> <pad> <pad>
        (optional) src_probs = p1 p2 p3 0 0 ...
    Others:
        internal input  = <s> w1 w2 w3 </s> <pad> <pad>
        decoder_outputs = w1 w2 w3 </s> <pad> <pad> <pad>
    """

    # define loss
    loss = NLLLoss()

    # scheduled sampling
    if not self.scheduled_sampling:
        teacher_forcing_ratio = self.teacher_forcing_ratio
    else:
        # use self.teacher_forcing_ratio as the starting point
        progress = 1.0 * step / total_steps
        teacher_forcing_ratio = 1.0 - progress

    # get padding mask
    non_padding_mask_src = src_ids.data.ne(PAD)
    non_padding_mask_tgt = tgt_ids.data.ne(PAD)

    # Forward propagation
    decoder_outputs, decoder_hidden, ret_dict = model(
        src_ids, tgt_ids, is_training=True,
        teacher_forcing_ratio=teacher_forcing_ratio,
        att_key_feats=src_probs)

    # Get loss
    loss.reset()
    logps = torch.stack(decoder_outputs, dim=1).to(device=device)
    if not self.eval_with_mask:
        loss.eval_batch(logps.reshape(-1, logps.size(-1)),
                        tgt_ids[:, 1:].reshape(-1))
        loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
    else:
        loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
                                  tgt_ids[:, 1:].reshape(-1),
                                  non_padding_mask_tgt[:, 1:].reshape(-1))
        loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
    loss.normalise()

    # Backward propagation
    model.zero_grad()
    resloss = loss.get_loss()
    att_resloss = 0
    dsfclassify_resloss = 0
    loss.backward()
    self.optimizer.step()

    return resloss, att_resloss, dsfclassify_resloss
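# A minimal sketch of how scheduled sampling is typically realised inside the
# decoder that the model(...) call above wraps; the decoder internals are not
# shown in this file, so the step function below is an assumption for
# illustration. With probability teacher_forcing_ratio the next decoder input
# is the reference token, otherwise the decoder's own greedy prediction.
import random

def next_decoder_input(ref_token, step_logps, teacher_forcing_ratio):
    # ref_token:  (batch,) gold token ids for the next position
    # step_logps: (batch, vocab) log-probabilities at the current step
    if random.random() < teacher_forcing_ratio:
        return ref_token                      # teacher forcing
    return step_logps.argmax(dim=-1)          # free running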