예제 #1
0
def compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len, device):
    """Compute ASR-vs-AE divergence statistics and write them to ``kl.stats``.

    Runs ``model.forward_eval(..., mode='AE_ASR')`` over every batch of
    ``test_set`` (each batch split into sub-batches for long targets) and,
    per sub-batch, accumulates:

    * the masked NLL of the ASR branch against ``src_ids[:, 1:]``,
    * the masked NLL of the AE branch against ``out_dict['refs_ae']``,
    * a masked ``KLDivLoss`` between the AE and ASR log-probabilities,
    * a masked ``MSELoss`` between the ASR and AE embeddings,
    * the mean per-position entropy of each branch's output distribution.

    Each statistic is averaged over the evaluated sub-batches and written to
    ``os.path.join(test_path_out, 'kl.stats')``. The averages are ``nan``
    when no sub-batch was evaluated.

    Args:
        test_set: dataset exposing ``construct_batches(is_train=False)`` and
            ``iter_loader``; each batch is a dict with keys ``'srcid'``,
            ``'srclen'``, ``'tgtlen'``, ``'acous_feat'`` and ``'acouslen'``.
        model: model exposing ``eval()`` and ``forward_eval``.
        test_path_out: directory in which ``kl.stats`` is written.
        use_gpu: forwarded unchanged to ``model.forward_eval``.
        max_seq_len: unused; kept for call-site compatibility.
        device: device the source ids and acoustic features are moved to.
    """
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    n_batches = len(evaliter)
    print('num batches: {}'.format(n_batches))
    path_out = os.path.join(test_path_out, 'kl.stats')

    # Running sums; each evaluated sub-batch contributes one term.
    resloss_asr = 0
    resloss_ae = 0
    resloss_kl = 0
    resloss_l2 = 0
    h_asr = 0
    h_ae = 0
    resloss_norm = 0

    model.eval()
    with torch.no_grad():
        for idx in range(n_batches):
            print(idx + 1, n_batches)
            # Builtin next(): recent PyTorch DataLoader iterators no longer
            # provide a ``.next()`` method.
            batch_items = next(evaliter)

            src_ids = batch_items['srcid'][0]
            src_lengths = batch_items['srclen']
            tgt_lengths = batch_items['tgtlen']
            acous_feats = batch_items['acous_feat'][0]
            acous_lengths = batch_items['acouslen']

            src_len = max(src_lengths)
            tgt_len = int(max(tgt_lengths))
            src_ids = src_ids[:, :src_len].to(device=device)
            acous_feats = acous_feats.to(device=device)

            # Split the batch into ceil(tgt_len / 100) roughly equal
            # sub-batches to bound memory on long targets. Both quantities
            # are clamped to >= 1, so a short target or a small batch never
            # produces a zero split count or a zero sub-batch size.
            batch_size = src_ids.size(0)
            n_split = max(1, (tgt_len + 99) // 100)
            minibatch_size = max(1, batch_size // n_split)
            n_minibatch = (batch_size + minibatch_size - 1) // minibatch_size

            for j in range(n_minibatch):
                st = j * minibatch_size
                ed = min((j + 1) * minibatch_size, batch_size)
                src_ids_sub = src_ids[st:ed, :]
                acous_feats_sub = acous_feats[st:ed, :]
                acous_lengths_sub = acous_lengths[st:ed]
                print('minibatch: ', st, ed, src_ids_sub.size(0))

                out_dict = model.forward_eval(src=src_ids_sub,
                                              acous_feats=acous_feats_sub,
                                              acous_lens=acous_lengths_sub,
                                              mode='AE_ASR',
                                              use_gpu=use_gpu)

                # The log-probs are scored against src_ids[:, 1:], so keep
                # max_len - 1 decoder steps and max_len source positions.
                max_len = min(src_ids_sub.size(1),
                              out_dict['preds_asr'].size(1))
                emb_hyp_asr = out_dict['emb_asr'][:, :max_len - 1]
                emb_hyp_ae = out_dict['emb_ae'][:, :max_len - 1]
                logps_hyp_asr = out_dict['logps_asr'][:, :max_len - 1]
                logps_hyp_ae = out_dict['logps_ae'][:, :max_len - 1]
                refs_ae = out_dict['refs_ae'][:, :max_len - 1]
                src_ids_sub = src_ids_sub[:, :max_len]

                # Every statistic is restricted to the non-PAD positions of
                # src_ids_sub[:, 1:] and normalised by their count.
                non_padding_mask_src = src_ids_sub.data.ne(PAD)
                mask = non_padding_mask_src[:, 1:].reshape(-1)
                n_tokens = 1.0 * torch.sum(non_padding_mask_src[:, 1:])

                flat_logps_asr = logps_hyp_asr.reshape(-1,
                                                       logps_hyp_asr.size(-1))
                flat_logps_ae = logps_hyp_ae.reshape(-1,
                                                     logps_hyp_ae.size(-1))

                loss_asr = NLLLoss()
                loss_asr.reset()
                loss_ae = NLLLoss()
                loss_ae.reset()
                loss_kl = KLDivLoss()
                loss_kl.reset()
                loss_l2 = MSELoss()
                loss_l2.reset()

                loss_asr.eval_batch_with_mask(
                    flat_logps_asr, src_ids_sub[:, 1:].reshape(-1), mask)
                loss_ae.eval_batch_with_mask(
                    flat_logps_ae, refs_ae.reshape(-1), mask)
                loss_kl.eval_batch_with_mask(flat_logps_ae, flat_logps_asr,
                                             mask)
                loss_l2.eval_batch_with_mask(
                    emb_hyp_asr.reshape(-1, emb_hyp_asr.size(-1)),
                    emb_hyp_ae.reshape(-1, emb_hyp_ae.size(-1)), mask)

                for loss in (loss_asr, loss_ae, loss_kl, loss_l2):
                    loss.norm_term = n_tokens
                    loss.normalise()

                resloss_asr += loss_asr.get_loss()
                resloss_ae += loss_ae.get_loss()
                resloss_kl += loss_kl.get_loss()
                resloss_l2 += loss_l2.get_loss()
                resloss_norm += 1

                # Per-position entropy averaged over ALL positions (padding
                # positions are not masked out here).
                entropy_asr = torch.mean(
                    Categorical(probs=torch.exp(logps_hyp_asr)).entropy())
                entropy_ae = torch.mean(
                    Categorical(probs=torch.exp(logps_hyp_ae)).entropy())
                h_asr += entropy_asr.item()
                h_ae += entropy_ae.item()

    def _avg(total):
        # Mean over evaluated sub-batches; nan when the test set was empty.
        return 1. * total / resloss_norm if resloss_norm else float('nan')

    with open(path_out, 'w') as fout:
        fout.write('Various stats (averaged over tokens)')
        fout.write('\n{}\n'.format('-' * 50))
        fout.write('NLL ASR: {:0.2f}\n'.format(_avg(resloss_asr)))
        fout.write('NLL AE: {:0.2f}\n'.format(_avg(resloss_ae)))
        fout.write('KL between ASR, AE: {:0.2f}\n'.format(_avg(resloss_kl)))
        fout.write('L2 between embeddings: {:0.2f}\n'.format(
            _avg(resloss_l2)))
        fout.write('Entropy ASR: {:0.2f}\n'.format(_avg(h_asr)))
        fout.write('Entropy AE: {:0.2f}\n'.format(_avg(h_ae)))
예제 #2
0
    def _train_batch(self, model, batch_items, dataset, step, total_steps):
        """Run one LAS training step, accumulating gradients over minibatches.

        The batch is cut into chunks of at most ``self.minibatch_size`` rows.
        Each chunk's masked NLL (optionally normalised by its non-PAD token
        count) is divided by the number of chunks before ``backward()``. A
        single ``self.optimizer.step()`` is taken once every chunk has been
        processed, followed by ``model.zero_grad()``.

        Args:
            model: LAS model, called with the acoustic features and the
                source ids as ``tgt``; returns
                ``(decoder_outputs, decoder_hidden, ret_dict)``.
            batch_items: sequence whose ``[0][0]`` is the padded source ids,
                ``[1]`` the source lengths, ``[2][0]`` the padded acoustic
                features and ``[3]`` the acoustic lengths.
            dataset: provides ``src_id2word`` for the sanity print at
                ``step == 1``.
            step: current training step; drives scheduled sampling.
            total_steps: total number of training steps.

        Returns:
            dict: ``{'las_loss': <sum of las_loss.get_loss() over chunks>}``.
        """
        # Scheduled sampling decays teacher forcing linearly with training
        # progress; otherwise the configured fixed ratio is used.
        if self.scheduled_sampling:
            teacher_forcing_ratio = 1.0 - 1.0 * step / total_steps
        else:
            teacher_forcing_ratio = self.teacher_forcing_ratio

        batch_src_ids = batch_items[0][0]
        batch_src_lengths = batch_items[1]
        batch_acous_feats = batch_items[2][0]
        batch_acous_lengths = batch_items[3]

        # ceil(batch_size / minibatch_size) chunks
        batch_size = batch_src_ids.size(0)
        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)

        las_resloss = 0
        for bidx in range(n_minibatch):
            las_loss = NLLLoss()
            las_loss.reset()

            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_ids = batch_src_ids[i_start:i_end]
            src_lengths = batch_src_lengths[i_start:i_end]
            acous_feats = batch_acous_feats[i_start:i_end]
            acous_lengths = batch_acous_lengths[i_start:i_end]

            seq_len = max(src_lengths)
            acous_len = max(acous_lengths)
            # Pad the frame count up past the next multiple of 8. An
            # already-aligned length still gains 8 frames, and the slice
            # below is clamped to the tensor's real width. Presumably the
            # encoder needs lengths divisible by 8 -- verify against the model.
            acous_len = acous_len + 8 - acous_len % 8
            src_ids = src_ids[:, :seq_len].to(device=self.device)
            acous_feats = acous_feats[:, :acous_len].to(device=self.device)

            if step == 1:
                # sanity check: print the source tokens of the first step
                check_src_tensor_print(src_ids, dataset.src_id2word)

            non_padding_mask_src = src_ids.data.ne(PAD)

            decoder_outputs, decoder_hidden, ret_dict = model(
                acous_feats,
                acous_lens=acous_lengths,
                tgt=src_ids,
                is_training=True,
                teacher_forcing_ratio=teacher_forcing_ratio,
                use_gpu=self.use_gpu)

            # decoder_outputs is a per-step list; it is stacked along dim 1,
            # and the last dim is treated as the class axis of the NLL.
            logps = torch.stack(decoder_outputs, dim=1).to(device=self.device)
            las_loss.eval_batch_with_mask(logps.reshape(-1, logps.size(-1)),
                                          src_ids.reshape(-1),
                                          non_padding_mask_src.reshape(-1))
            las_loss.norm_term = 1.0 * torch.sum(non_padding_mask_src)

            # Gradient accumulation: scale by 1/n_minibatch so the summed
            # gradients average over chunks.
            if self.normalise_loss:
                las_loss.normalise()
            las_loss.acc_loss /= n_minibatch
            las_loss.backward()
            las_resloss += las_loss.get_loss()
            torch.cuda.empty_cache()

        self.optimizer.step()
        model.zero_grad()
        return {'las_loss': las_resloss}
예제 #3
0
    def _evaluate_batches(self, model, dataset):
        """Evaluate the LAS model on ``dataset`` with teacher forcing.

        Each batch is split into chunks of at most ``self.minibatch_size``
        rows. For each chunk, the following are accumulated:

        * the masked NLL of the teacher-forced outputs against the source ids;
        * the token accuracy of ``ret_dict['sequence']`` over non-PAD
          positions.

        Some hypotheses are printed via ``self._print_hyp``.

        Returns:
            tuple: ``(accs, losses)``.

            * ``accs = {'las_acc': ...}``: NaN when no non-PAD token was seen.
            * ``losses = {'las_loss': ...}``: the mean per-chunk loss, NaN
              when no chunk was evaluated.
        """
        model.eval()

        las_match = 0
        las_total = 0
        las_resloss = 0
        las_resloss_norm = 0

        evaliter = iter(dataset.iter_loader)
        out_count = 0

        with torch.no_grad():
            for _ in range(len(evaliter)):
                # Builtin next(): recent PyTorch DataLoader iterators no
                # longer provide a ``.next()`` method.
                batch_items = next(evaliter)

                batch_src_ids = batch_items[0][0]
                batch_src_lengths = batch_items[1]
                batch_acous_feats = batch_items[2][0]
                batch_acous_lengths = batch_items[3]

                # ceil(batch_size / minibatch_size) chunks
                batch_size = batch_src_ids.size(0)
                n_minibatch = int(batch_size / self.minibatch_size)
                n_minibatch += int(batch_size % self.minibatch_size > 0)

                for bidx in range(n_minibatch):
                    las_loss = NLLLoss()
                    las_loss.reset()

                    i_start = bidx * self.minibatch_size
                    i_end = min(i_start + self.minibatch_size, batch_size)
                    src_ids = batch_src_ids[i_start:i_end]
                    src_lengths = batch_src_lengths[i_start:i_end]
                    acous_feats = batch_acous_feats[i_start:i_end]
                    acous_lengths = batch_acous_lengths[i_start:i_end]

                    seq_len = max(src_lengths)
                    acous_len = max(acous_lengths)
                    # Pad the frame count up past the next multiple of 8; the
                    # slice is clamped to the tensor's real width.
                    acous_len = acous_len + 8 - acous_len % 8
                    src_ids = src_ids[:, :seq_len].to(device=self.device)
                    acous_feats = acous_feats[:, :acous_len].to(
                        device=self.device)

                    non_padding_mask_src = src_ids.data.ne(PAD)

                    # Teacher-forced evaluation against the reference; pass
                    # teacher_forcing_ratio=0.0 instead to score hypotheses.
                    decoder_outputs, decoder_hidden, ret_dict = model(
                        acous_feats,
                        acous_lens=acous_lengths,
                        teacher_forcing_ratio=1.0,
                        tgt=src_ids,
                        is_training=False,
                        use_gpu=self.use_gpu)

                    logps = torch.stack(decoder_outputs,
                                        dim=1).to(device=self.device)
                    mask_flat = non_padding_mask_src.reshape(-1)
                    las_loss.eval_batch_with_mask(
                        logps.reshape(-1, logps.size(-1)), src_ids.reshape(-1),
                        mask_flat)
                    las_loss.norm_term = torch.sum(non_padding_mask_src)
                    if self.normalise_loss:
                        las_loss.normalise()
                    las_resloss += las_loss.get_loss()
                    las_resloss_norm += 1

                    # token accuracy over non-PAD positions
                    seqlist = ret_dict['sequence']
                    seqres = torch.stack(seqlist, dim=1).to(device=self.device)
                    las_match += seqres.view(-1).eq(src_ids.reshape(-1)) \
                        .masked_select(mask_flat).sum().item()
                    las_total += mask_flat.sum().item()

                    out_count = self._print_hyp(out_count, src_ids,
                                                dataset.src_id2word, seqlist)

        las_acc = las_match / las_total if las_total else float('nan')
        # Guard the empty case the same way as the accuracy above.
        if las_resloss_norm:
            las_resloss /= (1.0 * las_resloss_norm)
        else:
            las_resloss = float('nan')

        accs = {'las_acc': las_acc}
        losses = {'las_loss': las_resloss}
        return accs, losses
예제 #4
0
    def _evaluate_batches(self, model, dataset):
        """Evaluate the seq2seq model on ``dataset`` via ``forward_eval``.

        Each batch is split into chunks of at most ``self.minibatch_size``
        rows. Each chunk is decoded with ``model.forward_eval(src_ids, ...)``;
        no target ids are passed to the model. Outputs from position 1 on
        are scored against ``tgt_ids[:, 1:]`` in two ways:

        * NLL, masked to non-PAD targets when ``self.eval_with_mask``;
        * token accuracy over non-PAD targets.

        Some hypotheses are printed via ``self._print_hyp``.

        Returns:
            tuple: ``(resloss, accuracy, losses)``.

            * ``resloss``: the mean per-chunk NLL, NaN when no chunk was
              evaluated.
            * ``accuracy``: the token accuracy, NaN when there is no non-PAD
              target.
            * ``losses``: ``{'nll_loss': resloss}``.
        """
        # TODO: return BLEU score (use BLEU to determine roll back etc)
        model.eval()

        resloss = 0
        resloss_norm = 0
        match = 0
        total = 0

        evaliter = iter(dataset.iter_loader)
        out_count = 0

        with torch.no_grad():
            for _ in range(len(evaliter)):
                # Builtin next(): recent PyTorch DataLoader iterators no
                # longer provide a ``.next()`` method.
                batch_items = next(evaliter)

                batch_src_ids = batch_items['srcid'][0]
                batch_src_lengths = batch_items['srclen']
                batch_tgt_ids = batch_items['tgtid'][0]

                # ceil(batch_size / minibatch_size) chunks
                batch_size = batch_src_ids.size(0)
                n_minibatch = int(batch_size / self.minibatch_size)
                n_minibatch += int(batch_size % self.minibatch_size > 0)

                for bidx in range(n_minibatch):
                    loss = NLLLoss()
                    loss.reset()

                    i_start = bidx * self.minibatch_size
                    i_end = min(i_start + self.minibatch_size, batch_size)
                    src_len = max(batch_src_lengths[i_start:i_end])
                    src_ids = batch_src_ids[i_start:i_end][:, :src_len].to(
                        device=self.device)
                    tgt_ids = batch_tgt_ids[i_start:i_end].to(
                        device=self.device)
                    non_padding_mask_tgt = tgt_ids.data.ne(PAD)

                    preds, logps, dec_outputs = model.forward_eval(
                        src_ids, use_gpu=self.use_gpu)

                    # Position 0 is not scored (presumably the start token).
                    flat_logps = logps[:, 1:, :].reshape(-1, logps.size(-1))
                    flat_tgt = tgt_ids[:, 1:].reshape(-1)
                    flat_mask = non_padding_mask_tgt[:, 1:].reshape(-1)
                    if not self.eval_with_mask:
                        loss.eval_batch(flat_logps, flat_tgt)
                        loss.norm_term = 1.0 * tgt_ids.size(
                            0) * tgt_ids[:, 1:].size(1)
                    else:
                        loss.eval_batch_with_mask(flat_logps, flat_tgt,
                                                  flat_mask)
                        loss.norm_term = 1.0 * torch.sum(
                            non_padding_mask_tgt[:, 1:])
                    if self.normalise_loss:
                        loss.normalise()
                    resloss += loss.get_loss()
                    resloss_norm += 1

                    seqres = preds[:, 1:]
                    match += seqres.reshape(-1).eq(flat_tgt) \
                        .masked_select(flat_mask).sum().item()
                    total += flat_mask.sum().item()

                    out_count = self._print_hyp(out_count, src_ids, tgt_ids,
                                                dataset.src_id2word,
                                                dataset.tgt_id2word, seqres)

        accuracy = match / total if total else float('nan')
        # Guard the empty case the same way as the accuracy above.
        if resloss_norm:
            resloss /= (1.0 * resloss_norm)
        else:
            resloss = float('nan')
        torch.cuda.empty_cache()

        losses = {'nll_loss': resloss}
        return resloss, accuracy, losses
예제 #5
0
    def _train_batch(self, model, batch_items, dataset, step, total_steps):
        """Run one training step, accumulating gradients over minibatches.

        Sequence layout::

            src_ids         =     w1 w2 w3 </s> <pad> <pad> <pad>
            tgt_ids         = <s> w1 w2 w3 </s> <pad> <pad> <pad>

            internal input  = <s> w1 w2 w3 </s> <pad> <pad>
            decoder_outputs =     w1 w2 w3 </s> <pad> <pad> <pad>

        Hence the loss compares ``logps[:, :-1]`` with ``tgt_ids[:, 1:]``.

        The batch is cut into chunks of at most ``self.minibatch_size`` rows.
        Each chunk's loss is divided by the number of chunks before
        ``backward()``. A single ``self.optimizer.step()`` is taken at the
        end, followed by ``model.zero_grad()``.

        Args:
            model: model exposing
                ``forward_train(src_ids, tgt_ids, use_gpu=...)``, which
                returns ``(preds, logps, dec_outputs)``.
            batch_items: dict with keys ``'srcid'``, ``'srclen'`` and
                ``'tgtid'``.
            dataset: unused here.
            step: unused here.
            total_steps: unused here.

        Returns:
            The sum of ``loss.get_loss()`` over all chunks.
        """
        batch_src_ids = batch_items['srcid'][0]
        batch_src_lengths = batch_items['srclen']
        batch_tgt_ids = batch_items['tgtid'][0]

        # ceil(batch_size / minibatch_size) chunks
        batch_size = batch_src_ids.size(0)
        n_minibatch = int(batch_size / self.minibatch_size)
        n_minibatch += int(batch_size % self.minibatch_size > 0)

        resloss = 0
        for bidx in range(n_minibatch):
            loss = NLLLoss()
            loss.reset()

            i_start = bidx * self.minibatch_size
            i_end = min(i_start + self.minibatch_size, batch_size)
            src_len = max(batch_src_lengths[i_start:i_end])
            src_ids = batch_src_ids[i_start:i_end][:, :src_len].to(
                device=self.device)
            tgt_ids = batch_tgt_ids[i_start:i_end].to(device=self.device)
            non_padding_mask_tgt = tgt_ids.data.ne(PAD)

            preds, logps, dec_outputs = model.forward_train(
                src_ids, tgt_ids, use_gpu=self.use_gpu)

            # Output step t is scored against target token t + 1 (see the
            # layout in the docstring).
            flat_logps = logps[:, :-1, :].reshape(-1, logps.size(-1))
            flat_tgt = tgt_ids[:, 1:].reshape(-1)
            if not self.eval_with_mask:
                loss.eval_batch(flat_logps, flat_tgt)
                loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
            else:
                loss.eval_batch_with_mask(
                    flat_logps, flat_tgt,
                    non_padding_mask_tgt[:, 1:].reshape(-1))
                loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])

            # Gradient accumulation: scale by 1/n_minibatch so the summed
            # gradients average over chunks.
            if self.normalise_loss:
                loss.normalise()
            loss.acc_loss /= n_minibatch
            loss.backward()
            resloss += loss.get_loss()
            torch.cuda.empty_cache()

        self.optimizer.step()
        model.zero_grad()
        return resloss
예제 #6
0
    def _evaluate_batches(self, model, dataset):
        """Evaluate the joint ASR/AE model on ``dataset``.

        Each batch is split into chunks of at most ``self.minibatch_size``
        rows and run through ``model.forward_eval(..., mode='AE_ASR')``. Per
        chunk it accumulates:

        * the ASR NLL against ``src_ids[:, 1:]``;
        * the AE NLL against ``out_dict['refs_ae']``;
        * a ``KLDivLoss`` between the AE and ASR log-probabilities;
        * an ``MSELoss`` between their embeddings.

        The losses are restricted to non-PAD positions of ``src_ids[:, 1:]``
        when ``self.eval_with_mask`` is set. Token accuracy (always masked to
        non-PAD positions) and corpus BLEU are computed for both branches.

        Returns:
            tuple: ``(losses, metrics)``.

            * ``losses`` has keys ``'l2_loss'``, ``'kl_loss'``,
              ``'nll_loss_asr'`` and ``'nll_loss_ae'``. Each value is the
              mean per-chunk loss scaled by the matching ``self.loss_coeff``
              entry, or NaN when no chunk was evaluated.
            * ``metrics`` has keys ``'accuracy_asr'``, ``'bleu_asr'``,
              ``'accuracy_ae'`` and ``'bleu_ae'``. Accuracies are NaN when
              no non-PAD token was seen.
        """
        model.eval()

        resloss_asr = 0
        resloss_ae = 0
        resloss_kl = 0
        resloss_l2 = 0
        resloss_norm = 0

        match_asr = 0
        total_asr = 0
        match_ae = 0
        total_ae = 0

        hyp_corpus_asr = []
        ref_corpus_asr = []
        hyp_corpus_ae = []
        ref_corpus_ae = []

        evaliter = iter(dataset.iter_loader)
        out_count = 0

        with torch.no_grad():
            for _ in range(len(evaliter)):
                # Builtin next(): recent PyTorch DataLoader iterators no
                # longer provide a ``.next()`` method.
                batch_items = next(evaliter)

                batch_src_ids = batch_items['srcid'][0]
                batch_acous_feats = batch_items['acous_feat'][0]
                batch_acous_lengths = batch_items['acouslen']

                # ceil(batch_size / minibatch_size) chunks
                batch_size = batch_src_ids.size(0)
                n_minibatch = int(batch_size / self.minibatch_size)
                n_minibatch += int(batch_size % self.minibatch_size > 0)

                for bidx in range(n_minibatch):
                    loss_asr = NLLLoss()
                    loss_asr.reset()
                    loss_ae = NLLLoss()
                    loss_ae.reset()
                    loss_kl = KLDivLoss()
                    loss_kl.reset()
                    loss_l2 = MSELoss()
                    loss_l2.reset()

                    i_start = bidx * self.minibatch_size
                    i_end = min(i_start + self.minibatch_size, batch_size)
                    acous_lengths = batch_acous_lengths[i_start:i_end]
                    acous_len = max(acous_lengths)
                    # Pad the frame count up past the next multiple of 8; the
                    # slice is clamped to the tensor's real width.
                    acous_len = acous_len + 8 - acous_len % 8
                    # Source ids keep their full padded width here (they are
                    # not trimmed to the chunk's max length).
                    src_ids = batch_src_ids[i_start:i_end].to(
                        device=self.device)
                    acous_feats = batch_acous_feats[i_start:i_end][
                        :, :acous_len].to(device=self.device)
                    non_padding_mask_src = src_ids.data.ne(PAD)

                    out_dict = model.forward_eval(src=src_ids,
                                                  acous_feats=acous_feats,
                                                  acous_lens=acous_lengths,
                                                  mode='AE_ASR',
                                                  use_gpu=self.use_gpu)
                    preds_asr = out_dict['preds_asr']
                    logps_asr = out_dict['logps_asr']
                    emb_asr = out_dict['emb_asr']
                    preds_ae = out_dict['preds_ae']
                    logps_ae = out_dict['logps_ae']
                    emb_ae = out_dict['emb_ae']
                    refs_ae = out_dict['refs_ae']

                    # All terms are scored on positions 1.. of src_ids.
                    flat_logps_asr = logps_asr.reshape(-1, logps_asr.size(-1))
                    flat_logps_ae = logps_ae.reshape(-1, logps_ae.size(-1))
                    flat_emb_asr = emb_asr.reshape(-1, emb_asr.size(-1))
                    flat_emb_ae = emb_ae.reshape(-1, emb_ae.size(-1))
                    flat_src = src_ids[:, 1:].reshape(-1)
                    flat_refs_ae = refs_ae.reshape(-1)
                    flat_mask = non_padding_mask_src[:, 1:].reshape(-1)

                    if not self.eval_with_mask:
                        loss_asr.eval_batch(flat_logps_asr, flat_src)
                        loss_ae.eval_batch(flat_logps_ae, flat_refs_ae)
                        loss_kl.eval_batch(flat_logps_ae, flat_logps_asr)
                        loss_l2.eval_batch(flat_emb_asr, flat_emb_ae)
                        norm_term = 1.0 * src_ids.size(
                            0) * src_ids[:, 1:].size(1)
                    else:
                        loss_asr.eval_batch_with_mask(flat_logps_asr,
                                                      flat_src, flat_mask)
                        loss_ae.eval_batch_with_mask(flat_logps_ae,
                                                     flat_refs_ae, flat_mask)
                        loss_kl.eval_batch_with_mask(flat_logps_ae,
                                                     flat_logps_asr,
                                                     flat_mask)
                        loss_l2.eval_batch_with_mask(flat_emb_asr,
                                                     flat_emb_ae, flat_mask)
                        norm_term = 1.0 * torch.sum(
                            non_padding_mask_src[:, 1:])

                    for loss in (loss_asr, loss_ae, loss_kl, loss_l2):
                        loss.norm_term = norm_term
                        if self.normalise_loss:
                            loss.normalise()

                    resloss_asr += loss_asr.get_loss()
                    resloss_ae += loss_ae.get_loss()
                    resloss_kl += loss_kl.get_loss()
                    resloss_l2 += loss_l2.get_loss()
                    resloss_norm += 1

                    # token accuracy over non-PAD positions
                    match_asr += preds_asr.reshape(-1).eq(flat_src) \
                        .masked_select(flat_mask).sum().item()
                    total_asr += flat_mask.sum().item()
                    match_ae += preds_ae.reshape(-1).eq(flat_refs_ae) \
                        .masked_select(flat_mask).sum().item()
                    total_ae += flat_mask.sum().item()

                    # Prepend a dummy column so the AE references share
                    # src_ids' column layout (position 0 is never scored)
                    # for printing and BLEU.
                    dummy = torch.zeros(refs_ae.size(0), 1,
                                        dtype=torch.long, device=self.device)
                    refs_ae_add = torch.cat((dummy, refs_ae), dim=1)

                    # Both prints start from the same counter; only the AE
                    # call's updated count is kept.
                    self._print(out_count,
                                src_ids,
                                dataset.src_id2word,
                                preds_asr,
                                tail='-asr')
                    out_count = self._print(out_count,
                                            refs_ae_add,
                                            dataset.src_id2word,
                                            preds_ae,
                                            tail='-ae ')

                    hyp_corpus_asr, ref_corpus_asr = add2corpus(
                        preds_asr,
                        src_ids,
                        dataset.src_id2word,
                        hyp_corpus_asr,
                        ref_corpus_asr,
                        type='word')
                    hyp_corpus_ae, ref_corpus_ae = add2corpus(
                        preds_ae,
                        refs_ae_add,
                        dataset.src_id2word,
                        hyp_corpus_ae,
                        ref_corpus_ae,
                        type='word')

        bleu_asr = torchtext.data.metrics.bleu_score(hyp_corpus_asr,
                                                     ref_corpus_asr)
        bleu_ae = torchtext.data.metrics.bleu_score(hyp_corpus_ae,
                                                    ref_corpus_ae)

        accuracy_asr = match_asr / total_asr if total_asr else float('nan')
        accuracy_ae = match_ae / total_ae if total_ae else float('nan')

        def _mean_scaled(total, coeff_key):
            # Coefficient-weighted mean over chunks; NaN if none were
            # evaluated (mirrors the accuracy guard above).
            if not resloss_norm:
                return float('nan')
            return total * self.loss_coeff[coeff_key] / (1.0 * resloss_norm)

        losses = {
            'l2_loss': _mean_scaled(resloss_l2, 'l2'),
            'kl_loss': _mean_scaled(resloss_kl, 'kl_en'),
            'nll_loss_asr': _mean_scaled(resloss_asr, 'nll_asr'),
            'nll_loss_ae': _mean_scaled(resloss_ae, 'nll_ae'),
        }
        metrics = {
            'accuracy_asr': accuracy_asr,
            'bleu_asr': bleu_asr,
            'accuracy_ae': accuracy_ae,
            'bleu_ae': bleu_ae,
        }
        return losses, metrics
예제 #7
0
    def _evaluate_batches(self, model, dataset):
        """
        Evaluate a sequence-to-sequence model on every batch of ``dataset``.

        Each batch yielded by ``dataset.iter_loader`` is split into chunks of
        at most ``self.minibatch_size`` rows. Every chunk is decoded either
        with teacher forcing (``self.eval_mode == 'tf'``, via
        ``model.forward_train``) or free-running (``self.eval_mode == 'fr'``,
        via ``model.forward_eval``), scored against ``tgt_ids[:, 1:]`` with a
        fresh ``NLLLoss``, and accumulated into token accuracy and a BLEU
        corpus.

        Args:
            model: model exposing ``forward_train`` / ``forward_eval``.
            dataset: provides ``iter_loader``, ``src_id2word``,
                ``tgt_id2word`` and ``use_type``.

        Returns:
            (losses, metrics) where ``losses['nll_loss']`` is the mean of the
            per-chunk losses (NaN if no chunk was evaluated) and ``metrics``
            holds ``accuracy`` (NaN if there were no non-PAD target tokens)
            and ``bleu``.

        Raises:
            ValueError: if ``self.eval_mode`` is neither ``'tf'`` nor ``'fr'``.
        """
        # Fail fast: any other mode would leave logps_hyp / preds_hyp unbound
        # and surface later as a confusing UnboundLocalError.
        if self.eval_mode not in ('tf', 'fr'):
            raise ValueError("eval_mode must be 'tf' or 'fr', got {!r}".format(
                self.eval_mode))

        model.eval()

        resloss = 0
        resloss_norm = 0

        # accuracy
        match = 0
        total = 0

        # bleu
        hyp_corpus = []
        ref_corpus = []

        evaliter = iter(dataset.iter_loader)
        out_count = 0

        with torch.no_grad():
            for _ in range(len(evaliter)):
                # Built-in next() works on any Python 3 iterator; the Python 2
                # style ``.next()`` method is not guaranteed to exist.
                batch_items = next(evaliter)

                # load data
                batch_src_ids = batch_items['srcid'][0]
                batch_src_lengths = batch_items['srclen']
                batch_tgt_ids = batch_items['tgtid'][0]

                # separate into minibatches (ceiling division)
                batch_size = batch_src_ids.size(0)
                n_minibatch = int(batch_size / self.minibatch_size)
                n_minibatch += int(batch_size % self.minibatch_size > 0)

                for bidx in range(n_minibatch):

                    loss = NLLLoss()
                    loss.reset()

                    i_start = bidx * self.minibatch_size
                    i_end = min(i_start + self.minibatch_size, batch_size)
                    src_lengths = batch_src_lengths[i_start:i_end]

                    # drop source columns beyond the longest sequence in the chunk
                    src_len = max(src_lengths)
                    src_ids = batch_src_ids[i_start:i_end]
                    src_ids = src_ids[:, :src_len].to(device=self.device)
                    tgt_ids = batch_tgt_ids[i_start:i_end].to(device=self.device)

                    non_padding_mask_tgt = tgt_ids.data.ne(PAD)
                    flat_tgt = tgt_ids[:, 1:].reshape(-1)
                    flat_mask_tgt = non_padding_mask_tgt[:, 1:].reshape(-1)

                    if self.eval_mode == 'tf':
                        # [run-TF] to save time; drop the final decoder step
                        preds, logps, dec_outputs = model.forward_train(
                            src_ids, tgt_ids, use_gpu=self.use_gpu)
                        logps_hyp = logps[:, :-1, :]
                        preds_hyp = preds[:, :-1]
                    else:
                        # [run-FR] to get true stats; drop the first decoder step
                        preds, logps, dec_outputs = model.forward_eval(
                            src_ids, use_gpu=self.use_gpu)
                        logps_hyp = logps[:, 1:, :]
                        preds_hyp = preds[:, 1:]

                    # evaluation
                    flat_logps = logps_hyp.reshape(-1, logps_hyp.size(-1))
                    if not self.eval_with_mask:
                        loss.eval_batch(flat_logps, flat_tgt)
                        loss.norm_term = 1.0 * tgt_ids.size(
                            0) * tgt_ids[:, 1:].size(1)
                    else:
                        loss.eval_batch_with_mask(flat_logps, flat_tgt,
                                                  flat_mask_tgt)
                        loss.norm_term = 1.0 * torch.sum(
                            non_padding_mask_tgt[:, 1:])
                    if self.normalise_loss:
                        loss.normalise()
                    resloss += loss.get_loss()
                    resloss_norm += 1

                    # token accuracy over non-PAD target positions
                    seqres = preds_hyp
                    correct = seqres.reshape(-1).eq(flat_tgt) \
                        .masked_select(flat_mask_tgt).sum().item()
                    match += correct
                    total += non_padding_mask_tgt[:, 1:].sum().item()

                    # print
                    out_count = self._print_hyp(out_count, src_ids, tgt_ids,
                                                dataset.src_id2word,
                                                dataset.tgt_id2word, seqres)

                    # accumulate corpus
                    hyp_corpus, ref_corpus = add2corpus(seqres,
                                                        tgt_ids,
                                                        dataset.tgt_id2word,
                                                        hyp_corpus,
                                                        ref_corpus,
                                                        type=dataset.use_type)

        bleu = torchtext.data.metrics.bleu_score(hyp_corpus, ref_corpus)

        accuracy = match / total if total else float('nan')
        # Mean of the per-chunk losses; NaN rather than ZeroDivisionError when
        # nothing was evaluated, mirroring the accuracy convention above.
        if resloss_norm == 0:
            resloss = float('nan')
        else:
            resloss /= (1.0 * resloss_norm)

        losses = {}
        losses['nll_loss'] = resloss
        metrics = {}
        metrics['accuracy'] = accuracy
        metrics['bleu'] = bleu

        return losses, metrics
예제 #8
0
    def _train_batch(self, model, batch_items, dataset, step, total_steps):
        """
        Perform one optimiser update for joint ASR / auto-encoder training.

        The batch is cut into chunks of at most ``self.minibatch_size`` rows.
        A single ``model.forward_train(..., mode='AE_ASR')`` call per chunk
        yields the outputs for four criteria:

        * ``nll_loss_asr``: NLLLoss of ``logps_asr`` against ``src_ids[:, 1:]``
        * ``nll_loss_ae``:  NLLLoss of ``logps_ae`` against ``refs_ae``
        * ``kl_loss``:      KLDivLoss between ``logps_ae`` and ``logps_asr``
        * ``l2_loss``:      MSELoss between ``emb_asr`` and ``emb_ae``

        Each criterion is optionally normalised, weighted by its
        ``self.loss_coeff`` entry and divided by the chunk count. The four are
        summed and back-propagated so gradients accumulate over chunks. The
        optimiser steps once afterwards and the model's gradients are cleared.

        ``dataset``, ``step`` and ``total_steps`` are accepted but unused.

        Returns:
            dict mapping ``nll_loss_asr``, ``nll_loss_ae``, ``kl_loss`` and
            ``l2_loss`` to the weighted losses summed over all chunks.
        """
        all_src_ids = batch_items['srcid'][0]
        all_src_lengths = batch_items['srclen']
        all_acous_feats = batch_items['acous_feat'][0]
        all_acous_lengths = batch_items['acouslen']

        n_rows = all_src_ids.size(0)
        # Not used below; evaluated so an empty length list fails as before.
        longest_src = int(max(all_src_lengths))
        chunk = self.minibatch_size
        n_minibatch = int(n_rows / chunk) + int(n_rows % chunk > 0)

        # (output key, loss_coeff key), in the order the criteria are weighted
        schedule = (('nll_loss_asr', 'nll_asr'), ('nll_loss_ae', 'nll_ae'),
                    ('kl_loss', 'kl_en'), ('l2_loss', 'l2'))
        totals = {out_key: 0 for out_key, _ in schedule}

        for bidx in range(n_minibatch):
            # fresh accumulators for this chunk
            criteria = (NLLLoss(), NLLLoss(), KLDivLoss(), MSELoss())
            for crit in criteria:
                crit.reset()
            loss_asr, loss_ae, loss_kl, loss_l2 = criteria

            lo = bidx * chunk
            hi = min(lo + chunk, n_rows)
            acous_lengths = all_acous_lengths[lo:hi]
            # Frame budget: longest utterance rounded up past the next
            # multiple of 8 (an already-aligned length still gains 8 frames).
            n_frames = max(acous_lengths)
            n_frames = n_frames + 8 - n_frames % 8
            src_ids = all_src_ids[lo:hi].to(device=self.device)
            acous_feats = all_acous_feats[lo:hi][:, :n_frames].to(
                device=self.device)

            pad_mask = src_ids.data.ne(PAD)

            out_dict = model.forward_train(src_ids,
                                           acous_feats=acous_feats,
                                           acous_lens=acous_lengths,
                                           mode='AE_ASR',
                                           use_gpu=self.use_gpu)

            logps_asr = out_dict['logps_asr']
            logps_ae = out_dict['logps_ae']
            emb_asr = out_dict['emb_asr']
            emb_ae = out_dict['emb_ae']

            flat_logps_asr = logps_asr.reshape(-1, logps_asr.size(-1))
            flat_logps_ae = logps_ae.reshape(-1, logps_ae.size(-1))
            flat_emb_asr = emb_asr.reshape(-1, emb_asr.size(-1))
            flat_emb_ae = emb_ae.reshape(-1, emb_ae.size(-1))
            flat_src = src_ids[:, 1:].reshape(-1)
            flat_refs_ae = out_dict['refs_ae'].reshape(-1)

            # (criterion, hypothesis, reference) for each of the four losses
            pairs = ((loss_asr, flat_logps_asr, flat_src),
                     (loss_ae, flat_logps_ae, flat_refs_ae),
                     (loss_kl, flat_logps_ae, flat_logps_asr),
                     (loss_l2, flat_emb_asr, flat_emb_ae))
            if self.eval_with_mask:
                flat_mask = pad_mask[:, 1:].reshape(-1)
                for crit, hyp, ref in pairs:
                    crit.eval_batch_with_mask(hyp, ref, flat_mask)
                    crit.norm_term = 1.0 * torch.sum(pad_mask[:, 1:])
            else:
                for crit, hyp, ref in pairs:
                    crit.eval_batch(hyp, ref)
                    crit.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)

            if self.normalise_loss:
                for crit in criteria:
                    crit.normalise()

            for crit, (out_key, coeff_key) in zip(criteria, schedule):
                crit.acc_loss *= self.loss_coeff[coeff_key]
                crit.acc_loss /= n_minibatch
                totals[out_key] += crit.get_loss()

            # sum every criterion into loss_asr, then accumulate gradients
            for crit in (loss_ae, loss_kl, loss_l2):
                loss_asr.add(crit)
            loss_asr.backward()

        # single parameter update for the whole batch
        self.optimizer.step()
        model.zero_grad()

        return totals
예제 #9
0
    def _evaluate_batches(self, model, dataset):
        """
        Evaluate the ASR branch of ``model`` on every batch of ``dataset``.

        Batches from ``dataset.iter_loader`` are split into chunks of at most
        ``self.minibatch_size`` rows. Each chunk is decoded with
        ``model.forward_eval(..., mode='ASR')`` and scored against
        ``src_ids[:, 1:]`` with a fresh ``NLLLoss``. Token accuracy and a
        word-level BLEU corpus are accumulated alongside.

        Args:
            model: model exposing ``forward_eval``.
            dataset: provides ``iter_loader`` and ``src_id2word``.

        Returns:
            (losses, metrics).

            ``losses['nll_loss_en']`` is the mean per-chunk NLL (NaN when
            nothing was evaluated); ``losses['nll_loss_de']`` is a 0
            placeholder.

            ``metrics`` holds ``accuracy_en`` (NaN if there were no non-PAD
            tokens), ``bleu_en``, and the 0 placeholders ``accuracy_de`` /
            ``bleu_de``.
        """
        model.eval()

        resloss_en = 0
        resloss_norm = 0

        # accuracy
        match_en = 0
        total_en = 0

        # bleu
        hyp_corpus_en = []
        ref_corpus_en = []

        evaliter = iter(dataset.iter_loader)
        out_count = 0

        with torch.no_grad():
            for _ in range(len(evaliter)):
                # Built-in next() works on any Python 3 iterator; the Python 2
                # style ``.next()`` method is not guaranteed to exist.
                batch_items = next(evaliter)

                # load data
                batch_src_ids = batch_items['srcid'][0]
                batch_acous_feats = batch_items['acous_feat'][0]
                batch_acous_lengths = batch_items['acouslen']

                # separate into minibatches (ceiling division)
                batch_size = batch_src_ids.size(0)
                n_minibatch = int(batch_size / self.minibatch_size)
                n_minibatch += int(batch_size % self.minibatch_size > 0)

                for bidx in range(n_minibatch):

                    loss_en = NLLLoss()
                    loss_en.reset()

                    i_start = bidx * self.minibatch_size
                    i_end = min(i_start + self.minibatch_size, batch_size)
                    acous_lengths = batch_acous_lengths[i_start:i_end]

                    # Frame budget: longest utterance rounded up past the next
                    # multiple of 8. An already-aligned length still gains 8
                    # frames; the slice never exceeds the tensor's real length.
                    acous_len = max(acous_lengths)
                    acous_len = acous_len + 8 - acous_len % 8
                    src_ids = batch_src_ids[i_start:i_end].to(
                        device=self.device)
                    acous_feats = batch_acous_feats[i_start:i_end]
                    acous_feats = acous_feats[:, :acous_len].to(
                        device=self.device)

                    non_padding_mask_src = src_ids.data.ne(PAD)
                    flat_src = src_ids[:, 1:].reshape(-1)
                    flat_mask_src = non_padding_mask_src[:, 1:].reshape(-1)

                    # [run-FR] to get true stats
                    out_dict = model.forward_eval(acous_feats=acous_feats,
                                                  acous_lens=acous_lengths,
                                                  mode='ASR',
                                                  use_gpu=self.use_gpu)

                    preds_hyp_en = out_dict['preds_asr']
                    logps_hyp_en = out_dict['logps_asr']

                    # evaluation
                    flat_logps = logps_hyp_en.reshape(-1, logps_hyp_en.size(-1))
                    if not self.eval_with_mask:
                        loss_en.eval_batch(flat_logps, flat_src)
                        loss_en.norm_term = 1.0 * src_ids.size(
                            0) * src_ids[:, 1:].size(1)
                    else:
                        loss_en.eval_batch_with_mask(flat_logps, flat_src,
                                                     flat_mask_src)
                        loss_en.norm_term = 1.0 * torch.sum(
                            non_padding_mask_src[:, 1:])

                    if self.normalise_loss:
                        loss_en.normalise()

                    resloss_en += loss_en.get_loss()
                    resloss_norm += 1

                    # token accuracy over non-PAD source positions
                    seqres_en = preds_hyp_en
                    correct_en = seqres_en.reshape(-1).eq(flat_src) \
                        .masked_select(flat_mask_src).sum().item()
                    match_en += correct_en
                    total_en += non_padding_mask_src[:, 1:].sum().item()

                    # print
                    out_count = self._print(out_count,
                                            src_ids,
                                            dataset.src_id2word,
                                            seqres_en,
                                            tail='-asr')

                    # accumulate corpus
                    hyp_corpus_en, ref_corpus_en = add2corpus(
                        seqres_en,
                        src_ids,
                        dataset.src_id2word,
                        hyp_corpus_en,
                        ref_corpus_en,
                        type='word')

        bleu_en = torchtext.data.metrics.bleu_score(hyp_corpus_en,
                                                    ref_corpus_en)

        accuracy_en = match_en / total_en if total_en else float('nan')
        # Mean of the per-chunk losses; NaN rather than ZeroDivisionError when
        # nothing was evaluated, mirroring the accuracy convention above.
        if resloss_norm == 0:
            resloss_en = float('nan')
        else:
            resloss_en /= (1.0 * resloss_norm)

        losses = {}
        losses['nll_loss_de'] = 0
        losses['nll_loss_en'] = resloss_en
        metrics = {}
        metrics['accuracy_de'] = 0
        metrics['bleu_de'] = 0
        metrics['accuracy_en'] = accuracy_en
        metrics['bleu_en'] = bleu_en

        return losses, metrics
    def _train_batch(self, model, batch_items, dataset, step, total_steps):
        """
        Perform one optimiser update for joint MT / auto-encoder training.

        The batch is cut into chunks of at most ``self.minibatch_size`` rows.
        Per chunk, ``model.forward_train(..., mode='AE_MT')`` produces:

        * ``logps_mt`` (last step dropped), scored against ``tgt_ids[:, 1:]``
          and reported as ``nll_loss_de``;
        * ``logps_ae`` (first step dropped), scored against ``src_ids[:, 1:]``
          and reported as ``nll_loss_en``.

        Both NLL losses are optionally normalised, divided by the chunk count
        and weighted by ``self.loss_coeff['loss_mt']`` /
        ``self.loss_coeff['loss_ae']``. They are then summed and
        back-propagated so gradients accumulate across chunks. The optimiser
        steps once at the end and the model's gradients are cleared.

        ``dataset``, ``step`` and ``total_steps`` are accepted but unused.

        Returns:
            dict with keys ``nll_loss_de`` and ``nll_loss_en``: the weighted
            losses summed over all chunks.
        """
        all_src_ids = batch_items['srcid'][0]
        all_src_lengths = batch_items['srclen']
        all_tgt_ids = batch_items['tgtid'][0]
        all_tgt_lengths = batch_items['tgtlen']

        n_rows = all_src_ids.size(0)
        # Not used below; evaluated so an empty length list fails as before.
        longest_src = int(max(all_src_lengths))
        chunk = self.minibatch_size
        n_minibatch = int(n_rows / chunk) + int(n_rows % chunk > 0)

        totals = {'nll_loss_de': 0, 'nll_loss_en': 0}

        for bidx in range(n_minibatch):
            # fresh accumulators for this chunk
            loss_de, loss_en = NLLLoss(), NLLLoss()
            loss_de.reset()
            loss_en.reset()

            lo = bidx * chunk
            hi = min(lo + chunk, n_rows)
            src_ids = all_src_ids[lo:hi].to(device=self.device)
            tgt_ids = all_tgt_ids[lo:hi].to(device=self.device)

            src_mask = src_ids.data.ne(PAD)
            tgt_mask = tgt_ids.data.ne(PAD)

            out_dict = model.forward_train(src_ids,
                                           tgt=tgt_ids,
                                           mode='AE_MT',
                                           use_gpu=self.use_gpu)

            logps_de = out_dict['logps_mt'][:, :-1, :]
            logps_en = out_dict['logps_ae'][:, 1:, :]
            flat_logps_de = logps_de.reshape(-1, logps_de.size(-1))
            flat_logps_en = logps_en.reshape(-1, logps_en.size(-1))
            flat_tgt = tgt_ids[:, 1:].reshape(-1)
            flat_src = src_ids[:, 1:].reshape(-1)

            if self.eval_with_mask:
                loss_de.eval_batch_with_mask(flat_logps_de, flat_tgt,
                                             tgt_mask[:, 1:].reshape(-1))
                loss_de.norm_term = 1.0 * torch.sum(tgt_mask[:, 1:])
                loss_en.eval_batch_with_mask(flat_logps_en, flat_src,
                                             src_mask[:, 1:].reshape(-1))
                loss_en.norm_term = 1.0 * torch.sum(src_mask[:, 1:])
            else:
                loss_de.eval_batch(flat_logps_de, flat_tgt)
                loss_de.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
                loss_en.eval_batch(flat_logps_en, flat_src)
                loss_en.norm_term = 1.0 * src_ids.size(0) * src_ids[:, 1:].size(1)

            if self.normalise_loss:
                loss_de.normalise()
                loss_en.normalise()

            weighting = (('nll_loss_de', loss_de, 'loss_mt'),
                         ('nll_loss_en', loss_en, 'loss_ae'))
            for out_key, crit, coeff_key in weighting:
                crit.acc_loss /= n_minibatch
                crit.acc_loss *= self.loss_coeff[coeff_key]
                totals[out_key] += crit.get_loss()

            # sum both losses into loss_en, then accumulate gradients
            loss_en.add(loss_de)
            loss_en.backward()

        # single parameter update for the whole batch
        self.optimizer.step()
        model.zero_grad()

        return totals
예제 #11
0
    def _train_batch(self,
                     src_ids,
                     tgt_ids,
                     model,
                     step,
                     total_steps,
                     src_probs=None,
                     src_labs=None):
        """
        Run one training step (forward, NLL loss, backward, optimiser step).

        Args:
            src_ids:         =     w1 w2 w3 </s> <pad> <pad> <pad>
            tgt_ids:         = <s> w1 w2 w3 </s> <pad> <pad> <pad>
            model: called as ``model(src, tgt, is_training=True,
                teacher_forcing_ratio=..., att_key_feats=...)`` and expected
                to return ``(decoder_outputs, decoder_hidden, ret_dict)``.
            step, total_steps: training progress; only read when
                ``self.scheduled_sampling`` is set.
            src_probs (optional): = p1 p2 p3 0 0 ...; forwarded to the model
                as ``att_key_feats``.
            src_labs (optional): accepted for interface compatibility; unused.

        Others:
            internal input   = <s> w1 w2 w3 </s> <pad> <pad>
            decoder_outputs  =     w1 w2 w3 </s> <pad> <pad> <pad>

        Returns:
            (resloss, att_resloss, dsfclassify_resloss): the normalised NLL
            from ``NLLLoss.get_loss()``. The other two are always 0 here.
        """
        loss = NLLLoss()

        if self.scheduled_sampling:
            # Linear decay of the teacher-forcing ratio from 1.0 at step 0 to
            # 0.0 at total_steps.
            # NOTE(review): self.teacher_forcing_ratio is *not* used as the
            # starting point here. Confirm whether the decay should start from
            # it instead.
            progress = 1.0 * step / total_steps
            teacher_forcing_ratio = 1.0 - progress
        else:
            teacher_forcing_ratio = self.teacher_forcing_ratio

        # target positions that are not padding
        non_padding_mask_tgt = tgt_ids.data.ne(PAD)

        # Forward propagation
        decoder_outputs, decoder_hidden, ret_dict = model(
            src_ids,
            tgt_ids,
            is_training=True,
            teacher_forcing_ratio=teacher_forcing_ratio,
            att_key_feats=src_probs)

        # Get loss
        loss.reset()

        # Stack the per-step outputs and place them on the targets' device.
        # This avoids relying on a module-level ``device`` global.
        logps = torch.stack(decoder_outputs, dim=1).to(device=tgt_ids.device)
        flat_logps = logps.reshape(-1, logps.size(-1))
        flat_tgt = tgt_ids[:, 1:].reshape(-1)
        if self.eval_with_mask:
            loss.eval_batch_with_mask(flat_logps, flat_tgt,
                                      non_padding_mask_tgt[:, 1:].reshape(-1))
            loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:, 1:])
        else:
            loss.eval_batch(flat_logps, flat_tgt)
            loss.norm_term = 1.0 * tgt_ids.size(0) * tgt_ids[:, 1:].size(1)
        loss.normalise()

        # Backward propagation
        model.zero_grad()
        resloss = loss.get_loss()
        att_resloss = 0
        dsfclassify_resloss = 0

        loss.backward()
        self.optimizer.step()

        return resloss, att_resloss, dsfclassify_resloss
예제 #12
0
    def _evaluate_batches(self, model, batches, dataset):
        """
        Evaluate ``model`` (called with ``is_training=False``) on ``batches``.

        For each batch, the stacked decoder log-probabilities are scored
        against ``tgt_ids[:, 1:]`` with a fresh ``NLLLoss``. Predicted
        sequences are compared with the targets over non-PAD positions. The
        first three batches print SRC/REF/GEN lines to stdout.

        Args:
            model: called as ``model(src, tgt, is_training=False,
                att_key_feats=...)`` and expected to return
                ``(decoder_outputs, decoder_hidden, other)`` with
                ``other['sequence']``.
            batches: iterable of dicts with ``src_word_ids``,
                ``tgt_word_ids`` and optionally ``src_ddfd_probs``.
            dataset: provides ``tgt_id2word``.

        Returns:
            (resloss, accuracy, losses).

            ``resloss`` is the mean of the per-batch normalised NLL (NaN if
            ``batches`` was empty).

            ``accuracy`` is token accuracy over non-PAD targets (NaN if
            there are none).

            ``losses`` holds ``att_loss`` and ``attcls_loss``, both 0 here.
        """
        model.eval()

        resloss = 0
        resloss_norm = 0

        match = 0
        total = 0

        out_count = 0
        with torch.no_grad():
            for batch in batches:
                # A fresh accumulator per batch. normalise() runs once per
                # batch, so reusing one NLLLoss across batches would rescale
                # the earlier batches' contributions on every iteration
                # (assumes normalise() divides acc_loss by norm_term). The
                # per-batch normalised losses are averaged instead.
                loss = NLLLoss()
                loss.reset()

                src_probs = None
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = _convert_to_tensor(batch['src_ddfd_probs'],
                                                   self.use_gpu).unsqueeze(2)

                src_ids = _convert_to_tensor(batch['src_word_ids'],
                                             self.use_gpu)
                tgt_ids = _convert_to_tensor(batch['tgt_word_ids'],
                                             self.use_gpu)

                non_padding_mask_tgt = tgt_ids.data.ne(PAD)
                flat_tgt = tgt_ids[:, 1:].reshape(-1)
                flat_mask_tgt = non_padding_mask_tgt[:, 1:].reshape(-1)

                decoder_outputs, decoder_hidden, other = model(
                    src_ids,
                    tgt_ids,
                    is_training=False,
                    att_key_feats=src_probs)

                # Evaluation: keep outputs on the targets' device rather than
                # relying on a module-level ``device`` global.
                logps = torch.stack(decoder_outputs,
                                    dim=1).to(device=tgt_ids.device)
                flat_logps = logps.reshape(-1, logps.size(-1))
                if self.eval_with_mask:
                    loss.eval_batch_with_mask(flat_logps, flat_tgt,
                                              flat_mask_tgt)
                    loss.norm_term = 1.0 * torch.sum(non_padding_mask_tgt[:,
                                                                          1:])
                else:
                    loss.eval_batch(flat_logps, flat_tgt)
                    loss.norm_term = 1.0 * tgt_ids.size(
                        0) * tgt_ids[:, 1:].size(1)
                loss.normalise()
                resloss += loss.get_loss()
                resloss_norm += 1

                # token accuracy over non-PAD target positions
                seqlist = other['sequence']
                seqres = torch.stack(seqlist, dim=1).to(device=tgt_ids.device)
                correct = seqres.reshape(-1).eq(flat_tgt) \
                    .masked_select(flat_mask_tgt).sum().item()
                match += correct
                total += non_padding_mask_tgt[:, 1:].sum().item()

                if out_count < 3:
                    # NOTE(review): source words are decoded with
                    # tgt_id2word. This is only correct if source and target
                    # share a vocabulary; confirm.
                    srcwords = _convert_to_words_batchfirst(
                        src_ids, dataset.tgt_id2word)
                    refwords = _convert_to_words_batchfirst(
                        tgt_ids[:, 1:], dataset.tgt_id2word)
                    seqwords = _convert_to_words(seqlist, dataset.tgt_id2word)
                    outsrc = 'SRC: {}\n'.format(' '.join(
                        srcwords[0])).encode('utf-8')
                    outref = 'REF: {}\n'.format(' '.join(
                        refwords[0])).encode('utf-8')
                    outline = 'GEN: {}\n'.format(' '.join(
                        seqwords[0])).encode('utf-8')
                    sys.stdout.buffer.write(outsrc)
                    sys.stdout.buffer.write(outref)
                    sys.stdout.buffer.write(outline)
                    out_count += 1

        att_resloss = 0
        attcls_resloss = 0
        if resloss_norm == 0:
            resloss = float('nan')
        else:
            resloss /= (1.0 * resloss_norm)

        accuracy = match / total if total else float('nan')
        torch.cuda.empty_cache()

        losses = {}
        losses['att_loss'] = att_resloss
        losses['attcls_loss'] = attcls_resloss

        return resloss, accuracy, losses