Example #1
    def _train_batch(self, model, batch_items, dataset, step, total_steps):

        # -- scheduled sampling --
        if not self.scheduled_sampling:
            teacher_forcing_ratio = self.teacher_forcing_ratio
        else:
            progress = 1.0 * step / total_steps
            teacher_forcing_ratio = 1.0 - progress

        # -- LOAD BATCH --
        batch_src_ids = batch_items[0][0]
        batch_src_lengths = batch_items[1]
        batch_acous_feats = batch_items[2][0]
        batch_acous_lengths = batch_items[3]

        # -- CONSTRUCT MINIBATCH --
        batch_size = batch_src_ids.size(0)
        batch_seq_len = int(max(batch_src_lengths))
        batch_acous_len = int(max(batch_acous_lengths))

        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)

        las_resloss = 0

        # loop over minibatches, accumulating gradients
        for bidx in range(n_minibatch):

            # define loss
            las_loss = NLLLoss()
            las_loss.reset()

            # load data
            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_ids = batch_src_ids[i_start:i_end]
            src_lengths = batch_src_lengths[i_start:i_end]
            acous_feats = batch_acous_feats[i_start:i_end]
            acous_lengths = batch_acous_lengths[i_start:i_end]

            seq_len = max(src_lengths)
            acous_len = max(acous_lengths)
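            # pad the max acoustic length up to the next multiple of 8,
            # presumably to match the encoder's temporal downsampling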
            acous_len = acous_len + 8 - acous_len % 8
            src_ids = src_ids[:, :seq_len].to(device=self.device)
            acous_feats = acous_feats[:, :acous_len].to(device=self.device)

            # sanity check src on the first step
            if step == 1:
                check_src_tensor_print(src_ids, dataset.src_id2word)

            # get padding mask
            non_padding_mask_src = src_ids.data.ne(PAD)

            # Forward propagation
            decoder_outputs, decoder_hidden, ret_dict = model(
                acous_feats,
                acous_lens=acous_lengths,
                tgt=src_ids,
                is_training=True,
                teacher_forcing_ratio=teacher_forcing_ratio,
                use_gpu=self.use_gpu)

            logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
            las_loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
                                          src_ids.reshape(-1),
                                          non_padding_mask_src.reshape(-1))
            las_loss.norm_term = 1.0 * torch.sum(non_padding_mask_src)

            # Backward propagation: accumulate gradient
            if self.normalise_loss:
                las_loss.normalise()
            las_loss.acc_loss /= n_minibatch
            las_loss.backward()
            las_resloss += las_loss.get_loss()
            torch.cuda.empty_cache()

        # update weights
        self.optimizer.step()
        model.zero_grad()
        losses = {'las_loss': las_resloss}

        return losses
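The NLLLoss used throughout these examples is a project-local accumulator, not torch.nn.NLLLoss, and its implementation is not shown. Below is a minimal sketch of the interface the trainers rely on (reset, eval_batch_with_mask, norm_term, normalise, backward, get_loss); the method semantics are inferred from the call sites, so treat this as an illustration rather than the actual class.

import torch
import torch.nn.functional as F

class NLLLoss:

    def __init__(self):
        self.acc_loss = 0.0    # running sum of masked per-token NLL (a tensor once evaluated)
        self.norm_term = 1.0   # denominator applied by normalise()

    def reset(self):
        self.acc_loss = 0.0
        self.norm_term = 1.0

    def eval_batch_with_mask(self, logps, targets, mask):
        # logps: (N, vocab) log-probs; targets: (N,) token ids; mask: (N,) bool
        nll = F.nll_loss(logps, targets, reduction='none')
        self.acc_loss = self.acc_loss + torch.sum(nll * mask)

    def normalise(self):
        self.acc_loss = self.acc_loss / self.norm_term

    def backward(self):
        self.acc_loss.backward()

    def get_loss(self):
        return float(self.acc_loss)    # plain Python float for logging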
Example #2
    def _train_batch(self, model, batch_items, dataset, step, total_steps):
        """
			Args:
				src_ids 		=     w1 w2 w3 </s> <pad> <pad> <pad>
				tgt_ids 		= <s> w1 w2 w3 </s> <pad> <pad> <pad>

			Others:
				internal input 	= <s> w1 w2 w3 </s> <pad> <pad>
				decoder_outputs	= 	  w1 w2 w3 </s> <pad> <pad> <pad>
		"""

        # load data
        batch_src_ids = batch_items['srcid'][0]
        batch_src_lengths = batch_items['srclen']
        batch_tgt_ids = batch_items['tgtid'][0]
        batch_tgt_lengths = batch_items['tgtlen']

        # separate into minibatch
        batch_size = batch_src_ids.size(0)
        batch_seq_len = int(max(batch_src_lengths))
        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)
        resloss = 0

        for bidx in range(n_minibatch):

            # define loss
            loss = NLLLoss()
            loss.reset()

            # load data
            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_ids = batch_src_ids[i_start:i_end]
            src_lengths = batch_src_lengths[i_start:i_end]
            tgt_ids = batch_tgt_ids[i_start:i_end]
            tgt_lengths = batch_tgt_lengths[i_start:i_end]

            src_len = max(src_lengths)
            tgt_len = max(tgt_lengths)
            src_ids = src_ids[:, :src_len].to(device=self.device)
            tgt_ids = tgt_ids.to(device=self.device)

            # get padding mask
            non_padding_mask_src = src_ids.data.ne(PAD)
            non_padding_mask_tgt = tgt_ids.data.ne(PAD)

            # Forward propagation
            preds, logps, dec_outputs = model.forward_train(
                src_ids, tgt_ids, use_gpu=self.use_gpu)

            # Get loss
            if not self.eval_with_mask:
                loss.eval_batch(logps[:, :-1, :].reshape(-1, logps.size(-1)),
                                tgt_ids[:, 1:].reshape(-1))
                loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            else:
                loss.eval_batch_with_mask(
                    logps[:, :-1, :].reshape(-1, logps.size(-1)),
                    tgt_ids[:, 1:].reshape(-1),
                    non_padding_mask_tgt[:, 1:].reshape(-1))
                loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])

            # Backward propagation: accumulate gradient
            if self.normalise_loss:
                loss.normalise()
            loss.acc_loss /= n_minibatch
            loss.backward()
            resloss += loss.get_loss()
            torch.cuda.empty_cache()

        # update weights
        self.optimizer.step()
        model.zero_grad()

        return resloss
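For reference, the masked branch above is equivalent to the following standalone computation, a plain-PyTorch sketch of the shifted, masked NLL; the function name and the pad_id argument are illustrative:

import torch
import torch.nn.functional as F

def masked_shifted_nll(logps, tgt_ids, pad_id):
    # score prediction t against target token t+1 (teacher-forced shift)
    flat_logps = logps[:, :-1, :].reshape(-1, logps.size(-1))
    targets = tgt_ids[:, 1:].reshape(-1)
    mask = targets.ne(pad_id)
    nll = F.nll_loss(flat_logps, targets, reduction='none')
    # normalise by the number of non-pad target tokens
    return torch.sum(nll * mask) / torch.sum(mask)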
Example #3
    def _train_batch(self, model, batch_items, dataset, step, total_steps):

        # load data
        batch_src_ids = batch_items['srcid'][0]
        batch_src_lengths = batch_items['srclen']
        batch_tgt_ids = batch_items['tgtid'][0]
        batch_tgt_lengths = batch_items['tgtlen']

        # separate into minibatch
        batch_size = batch_src_ids.size(0)
        batch_seq_len = int(max(batch_src_lengths))
        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)
        resloss_de = 0
        resloss_en = 0

        for bidx in range(n_minibatch):

            # define loss
            loss_de = NLLLoss()
            loss_de.reset()
            loss_en = NLLLoss()
            loss_en.reset()

            # load data
            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_ids = batch_src_ids[i_start:i_end]
            src_lengths = batch_src_lengths[i_start:i_end]
            tgt_ids = batch_tgt_ids[i_start:i_end]
            tgt_lengths = batch_tgt_lengths[i_start:i_end]

            src_len = max(src_lengths)
            tgt_len = max(tgt_lengths)
            src_ids = src_ids.to(device=self.device)
            tgt_ids = tgt_ids.to(device=self.device)

            # get padding mask
            non_padding_mask_src = src_ids.data.ne(PAD)
            non_padding_mask_tgt = tgt_ids.data.ne(PAD)

            # Forward propagation
            out_dict = model.forward_train(src_ids,
                                           tgt=tgt_ids,
                                           mode='AE_MT',
                                           use_gpu=self.use_gpu)

            logps_de = out_dict['logps_mt'][:, :-1, :]
            logps_en = out_dict['logps_ae'][:, 1:, :]

            # Get loss
            if not self.eval_with_mask:
                loss_de.eval_batch(logps_de.reshape(-1, logps_de.size(-1)),
                                   tgt_ids[:, 1:].reshape(-1))
                loss_de.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
                loss_en.eval_batch(logps_en.reshape(-1, logps_en.size(-1)),
                                   src_ids[:, 1:].reshape(-1))
                loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            else:
                loss_de.eval_batch_with_mask(
                    logps_de.reshape(-1, logps_de.size(-1)),
                    tgt_ids[:, 1:].reshape(-1),
                    non_padding_mask_tgt[:, 1:].reshape(-1))
                loss_de.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
                loss_en.eval_batch_with_mask(
                    logps_en.reshape(-1, logps_en.size(-1)),
                    src_ids[:, 1:].reshape(-1),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_en.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

            # Backward propagation: accumulate gradient
            if self.normalise_loss:
                loss_de.normalise()
                loss_en.normalise()

            loss_de.acc_loss /= n_minibatch
            loss_de.acc_loss *= self.loss_coeff['loss_mt']
            resloss_de += loss_de.get_loss()

            loss_en.acc_loss /= n_minibatch
            loss_en.acc_loss *= self.loss_coeff['loss_ae']
            resloss_en += loss_en.get_loss()

            loss_en.add(loss_de)
            loss_en.backward()

        # update weights
        self.optimizer.step()
        model.zero_grad()

        losses = {}
        losses['nll_loss_de'] = resloss_de
        losses['nll_loss_en'] = resloss_en

        return losses
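These trainers all share the same accumulation pattern: split the batch into minibatches, scale each minibatch loss by 1/n_minibatch so the summed gradients average over the full batch, call backward() once per minibatch, and step the optimizer once per batch. A condensed sketch of just that skeleton (compute_loss stands in for the loss bookkeeping above):

def train_batch(model, optimizer, minibatches, compute_loss):
    model.zero_grad()
    total = 0.0
    n = len(minibatches)
    for mb in minibatches:
        loss = compute_loss(model, mb) / n  # scale so gradients average
        loss.backward()                     # gradients accumulate in .grad
        total += loss.item()
    optimizer.step()                        # one update for the whole batch
    return total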
Example #4
    def _train_batch(self, model, batch_items, dataset, step, total_steps):

        # load data
        batch_src_ids = batch_items['srcid'][0]
        batch_src_lengths = batch_items['srclen']
        batch_acous_feats = batch_items['acous_feat'][0]
        batch_acous_lengths = batch_items['acouslen']

        # separate into minibatch
        batch_size = batch_src_ids.size(0)
        batch_seq_len = int(max(batch_src_lengths))
        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)
        resloss_asr = 0
        resloss_ae = 0
        resloss_kl = 0
        resloss_l2 = 0

        for bidx in range(n_minibatch):

            # define loss
            loss_asr = NLLLoss()
            loss_asr.reset()
            loss_ae = NLLLoss()
            loss_ae.reset()
            loss_kl = KLDivLoss()
            loss_kl.reset()
            loss_l2 = MSELoss()
            loss_l2.reset()

            # load data
            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_ids = batch_src_ids[i_start:i_end]
            src_lengths = batch_src_lengths[i_start:i_end]
            acous_feats = batch_acous_feats[i_start:i_end]
            acous_lengths = batch_acous_lengths[i_start:i_end]

            src_len = max(src_lengths)
            acous_len = max(acous_lengths)
            acous_len = acous_len + 8 - acous_len % 8
            src_ids = src_ids.to(device=self.device)
            acous_feats = acous_feats[:, :acous_len].to(device=self.device)

            # get padding mask
            non_padding_mask_src = src_ids.data.ne(PAD)

            # Forward propagation
            out_dict = model.forward_train(src_ids,
                                           acous_feats=acous_feats,
                                           acous_lens=acous_lengths,
                                           mode='AE_ASR',
                                           use_gpu=self.use_gpu)

            logps_asr = out_dict['logps_asr']
            emb_asr = out_dict['emb_asr']
            logps_ae = out_dict['logps_ae']
            emb_ae = out_dict['emb_ae']
            refs_ae = out_dict['refs_ae']

            # Get loss
            if not self.eval_with_mask:
                loss_asr.eval_batch(logps_asr.reshape(-1, logps_asr.size(-1)),
                                    src_ids[:, 1:].reshape(-1))
                loss_asr.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                loss_ae.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                                   refs_ae.reshape(-1))
                loss_ae.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                loss_kl.eval_batch(logps_ae.reshape(-1, logps_ae.size(-1)),
                                   logps_asr.reshape(-1, logps_asr.size(-1)))
                loss_kl.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
                loss_l2.eval_batch(emb_asr.reshape(-1, emb_asr.size(-1)),
                                   emb_ae.reshape(-1, emb_ae.size(-1)))
                loss_l2.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)
            else:
                loss_asr.eval_batch_with_mask(
                    logps_asr.reshape(-1, logps_asr.size(-1)),
                    src_ids[:, 1:].reshape(-1),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_asr.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_ae.eval_batch_with_mask(
                    logps_ae.reshape(-1, logps_ae.size(-1)),
                    refs_ae.reshape(-1),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_ae.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_kl.eval_batch_with_mask(
                    logps_ae.reshape(-1, logps_ae.size(-1)),
                    logps_asr.reshape(-1, logps_asr.size(-1)),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_kl.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])
                loss_l2.eval_batch_with_mask(
                    emb_asr.reshape(-1, emb_asr.size(-1)),
                    emb_ae.reshape(-1, emb_ae.size(-1)),
                    non_padding_mask_src[:, 1:].reshape(-1))
                loss_l2.norm_term = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

            # Backward propagation: accumulate gradient
            if self.normalise_loss:
                loss_asr.normalise()
                loss_ae.normalise()
                loss_kl.normalise()
                loss_l2.normalise()

            loss_asr.acc_loss *= self.loss_coeff['nll_asr']
            loss_asr.acc_loss /= n_minibatch
            resloss_asr += loss_asr.get_loss()

            loss_ae.acc_loss *= self.loss_coeff['nll_ae']
            loss_ae.acc_loss /= n_minibatch
            resloss_ae += loss_ae.get_loss()

            loss_kl.acc_loss *= self.loss_coeff['kl_en']
            loss_kl.acc_loss /= n_minibatch
            resloss_kl += loss_kl.get_loss()

            loss_l2.acc_loss *= self.loss_coeff['l2']
            loss_l2.acc_loss /= n_minibatch
            resloss_l2 += loss_l2.get_loss()

            loss_asr.add(loss_ae)
            loss_asr.add(loss_kl)
            loss_asr.add(loss_l2)
            loss_asr.backward()

        # update weights
        self.optimizer.step()
        model.zero_grad()

        losses = {}
        losses['nll_loss_asr'] = resloss_asr
        losses['nll_loss_ae'] = resloss_ae
        losses['kl_loss'] = resloss_kl
        losses['l2_loss'] = resloss_l2

        return losses
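The KLDivLoss helper above is fed two log-probability tensors (AE and ASR); since its internals are not shown, the direction and reduction below are assumptions. A sketch of a masked per-token KL using torch.nn.functional.kl_div with log_target=True:

import torch
import torch.nn.functional as F

def masked_kl(logps_ae, logps_asr, mask):
    # elementwise KL with both arguments given as log-probabilities
    kl = F.kl_div(logps_ae, logps_asr, reduction='none', log_target=True)
    kl = kl.sum(-1)                       # sum over the vocabulary axis
    # average over non-pad positions only
    return torch.sum(kl * mask) / torch.sum(mask)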
Example #5
    def _train_batch(self,
                     src_ids,
                     tgt_ids,
                     model,
                     step,
                     total_steps,
                     src_probs=None,
                     src_labs=None):
        """
			Args:
				src_ids 		=     w1 w2 w3 </s> <pad> <pad> <pad>
				tgt_ids 		= <s> w1 w2 w3 </s> <pad> <pad> <pad>
			(optional)
				src_probs 		=     p1 p2 p3 0    0     ...
			Others:
				internal input 	= <s> w1 w2 w3 </s> <pad> <pad>
				decoder_outputs	= 	  w1 w2 w3 </s> <pad> <pad> <pad>
		"""

        # define loss
        loss = NLLLoss()

        # scheduled sampling
        if not self.scheduled_sampling:
            teacher_forcing_ratio = self.teacher_forcing_ratio
        else:
            # linear decay from 1.0 to 0.0 over training
            # (self.teacher_forcing_ratio is not used as the starting point)
            progress = 1.0 * step / total_steps
            teacher_forcing_ratio = 1.0 - progress

        # get padding mask
        non_padding_mask_src = src_ids.data.ne(PAD)
        non_padding_mask_tgt = tgt_ids.data.ne(PAD)

        # Forward propagation
        decoder_outputs, decoder_hidden, ret_dict = model(
            src_ids,
            tgt_ids,
            is_training=True,
            teacher_forcing_ratio=teacher_forcing_ratio,
            att_key_feats=src_probs)

        # Get loss
        loss.reset()

        logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
        if not self.eval_with_mask:
            loss.eval_batch(logps.reshape(-1, logps.size(-1)),
                            tgt_ids[:, 1:].reshape(-1))
        else:
            loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
                                      tgt_ids[:, 1:].reshape(-1),
                                      non_padding_mask_tgt[:, 1:].reshape(-1))

        if not self.eval_with_mask:
            loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
        else:
            loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
        loss.normalise()

        # Backward propagation
        model.zero_grad()
        resloss = loss.get_loss()
        att_resloss = 0
        dsfclassify_resloss = 0

        loss.backward()
        self.optimizer.step()

        return resloss, att_resloss, dsfclassify_resloss
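Examples #1 and #5 derive the teacher-forcing ratio identically: fixed when scheduled sampling is off, otherwise a linear decay from 1.0 at step 0 to 0.0 at total_steps. Factored out as a standalone helper (the function name is illustrative):

def get_teacher_forcing_ratio(step, total_steps,
                              scheduled_sampling, fixed_ratio):
    if not scheduled_sampling:
        return fixed_ratio
    # linear decay: 1.0 at step 0, 0.0 at total_steps
    progress = 1.0 * step / total_steps
    return 1.0 - progress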