Example 1
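# Imports assumed by this snippet. The stdlib / torch / numpy imports are standard;
# the project-local modules (util, Logger, LossItem, n_gram_list and the
# PAD/SOS/EOS/UNK index constants) are assumptions about the surrounding code base,
# and their actual module paths may differ.
import copy
import datetime
import os

import numpy as np
import torch
from torch.autograd import Variable as Var

import util                                  # assumed: provides schedule()
from logger import Logger, LossItem          # assumed project module
from helpers import n_gram_list              # assumed helper used by weighted_loss
from constants import PAD, SOS, EOS, UNK     # assumed special-token indices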
class Trainer(object):
    def __init__(self, opt, model, dicts, data):
        weight = torch.ones(opt.full_dict_size)
        weight[PAD] = 0
        assert PAD == dicts[0].fword2idx('<pad>')
        # self.crit is required by weighted_loss(); the per-element (reduce=False)
        # variant is kept for the bigram-weighted loss. Note the PAD-zeroed weight
        # vector above is built but not passed to the criterion.
        if opt.mul_loss:
            self.crit = torch.nn.NLLLoss(size_average=False, ignore_index=PAD, reduce=False)
        else:
            self.crit = torch.nn.NLLLoss(size_average=True, ignore_index=PAD)
        self.opt = opt
        self.model = model
        self.train_bag = data
        self.n_batch = len(self.train_bag)
        parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = torch.optim.Adagrad(parameters, lr=opt.lr)  # TODO
        # self.optimizer = torch.optim.SGD(parameters,lr=opt.lr)

        self.mul_loss = opt.mul_loss
        self.add_loss = opt.add_loss

        # dicts = [word_dict, pos_dict, ner_dict]
        self.word_dict = dicts[0]
        self.pos_dict = dicts[1]
        self.ner_dict = dicts[2]

        self.clip = opt.clip
        # self.val_bag = val_data
        self.bool_test = False
        self.coverage = self.opt.coverage
        self.logger = Logger(opt.print_every, self.n_batch)
        self.histo = None

    def assert_special_chars(self):
        # Sanity checks on the special-token indices; currently disabled by the
        # early return, so the asserts below are never reached.
        return None
        assert self.word_dict.word2idx['<eos>'] == EOS
        assert self.word_dict.word2idx['<pad>'] == PAD
        assert self.word_dict.word2idx['<sos>'] == SOS
        assert self.word_dict.word2idx['<unk>'] == UNK

    def weighted_loss(self, decoder_outputs_prob, decoder_outputs, tgt_var):
        """Compute the NLL loss re-weighted by n-gram overlap with the gold target.

        :param decoder_outputs_prob: seq, batch, dict_size
        :param decoder_outputs: seq, batch
        :param tgt_var: seq, batch
        :return: (weighted loss, original unweighted loss)
        """
        gram = [2, 3]
        weight = [0.1, 0.01]
        seq_len, batch_size = tgt_var.size(0), tgt_var.size(1)
        seq_len__, batch_size__ = decoder_outputs.size(0), decoder_outputs.size(1)
        seq_len_, batch_size_, dict_size = decoder_outputs_prob.size()
        assert seq_len == seq_len_ == seq_len__
        assert batch_size == batch_size_ == batch_size__

        gold_n_grams = n_gram_list(gram, tgt_var)

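        # Build a per-position weight matrix: whenever the trailing bigram/trigram of
        # the predicted sequence also appears in the gold n-gram set for that example,
        # the loss of the tokens in that n-gram is scaled down by the corresponding
        # weight (0.1 for bigrams, 0.01 for trigrams).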
        weight_mat = torch.ones(decoder_outputs.size())
        for b in range(batch_size):
            pred_seq = decoder_outputs[:, b]
            for t in range(seq_len_):
                for j, g in enumerate(gram):
                    if t - g + 1 >= 0:
                        feat = ''
                        for x in range(t - g + 1, t + 1, 1):
                            feat += str(pred_seq[x]) + '_'
                        if feat in gold_n_grams[b][j]:
                            for x in range(t - g + 1, t + 1, 1):
                                weight_mat[x, b] *= weight[j]
        weight_mat = weight_mat.view(seq_len * batch_size, -1)
        # print(torch.mean(weight_mat))
        loss = self.crit(decoder_outputs_prob.view(seq_len * batch_size, -1),
                         Var(tgt_var).view(seq_len * batch_size).cuda())
        original_loss = torch.sum(loss)
        cu_weights = Var(weight_mat.squeeze()).cuda()
        loss = torch.sum(torch.mul(loss, cu_weights))
        # cu_weights = Var(weight_mat.transpose()).cuda()
        # loss = loss * cu_weights
        return loss, original_loss

    def train_iters(self):
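        """Main training loop.

        For each epoch: apply the length/coverage schedule, shuffle the batch
        order, rebuild each batch's masks for the scheduled lengths, run one
        optimizer step per batch via func_train, and checkpoint the model every
        opt.save_every batches.
        """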

        for epo in range(self.opt.start_epo, self.opt.n_epo + 1):
            self.logger.init_new_epo(epo)
            # Schedule: encoder/decoder length limits and the coverage-loss weight
            # depend on the current epoch.
            self.opt.max_len_enc, self.opt.max_len_dec, self.cov_loss_weight = util.schedule(epo)
            # if self.opt.max_len_enc == self.histo:
            #     need_to_recompute = False
            # else:
            #     need_to_recompute = True
            #     self.histo = self.opt.max_len_enc

            # self.train_bag = self.model.feat.update_msks(self.opt.max_len_enc, self.opt.max_len_dec, self.train_bag,
            #                                              self.ner_dict)

            # if self.opt.feat_sp or self.opt.feat_nn:
            #     self.train_bag = self.model.feat.sp.extract_feat(self.opt, self.train_bag,
            #                                                      [self.pos_dict, self.ner_dict])

            batch_order = np.arange(self.n_batch)
            np.random.shuffle(batch_order)

            for idx, batch_idx in enumerate(batch_order):
                self.logger.init_new_batch(batch_idx)
                tmp_cur_batch = self.train_bag[batch_idx]

                current_batch = copy.deepcopy(tmp_cur_batch)
                # if need_to_recompute:
                current_batch = self.model.feat.update_msks_batch(
                    self.opt, self.opt.mode, self.opt.max_len_enc,
                    self.opt.max_len_dec, current_batch, self.pos_dict,
                    self.ner_dict)

                inp_var = current_batch['cur_inp_var']
                inp_mask = current_batch['cur_inp_mask']
                out_var = current_batch['cur_out_var']
                out_mask = current_batch['cur_out_mask']
                scatter_msk = current_batch['cur_scatter_mask'].cuda()
                replacement = current_batch['replacement']
                max_oov_len = len(replacement)
                self.logger.set_oov(max_oov_len)

                if self.opt.feat_word or self.opt.feat_ent or self.opt.feat_sent:
                    features = [
                        current_batch['word_feat'], current_batch['ent_feat'],
                        current_batch['sent_feat']
                    ]
                    feature_msks = [
                        current_batch['cur_word_msk'],
                        current_batch['cur_ent_msk'],
                        current_batch['cur_sent_msk']
                    ]
                else:
                    features = None
                    feature_msks = None
                if self.opt.mul_loss or self.opt.add_loss:
                    bigram = current_batch['bigram']
                    # print(torch.sum(bigram))
                    bigram_msk = current_batch['bigram_msk']
                    bigram_dict = current_batch['bigram_dict']
                    bigram_bunch = [bigram, bigram_msk, bigram_dict]
                    # print(torch.sum(window_msk))
                else:
                    bigram_bunch = None
                # inp_var = util.truncate_mat(self.opt.max_len_enc, inp_var)
                # out_var = util.truncate_mat(self.opt.max_len_dec, out_var)

                # Sparse Feature preload
                # sparse_feat_indicator_mat = self.model.feat_sp.generate_feature_indicator(inp_var, ori_txt,
                #                                                                           self.ner_dict)

                # # Need to generate Mask for both txt and abs
                # inp_mask = torch.gt(inp_var[0], 0)
                # out_mask = torch.gt(out_var[0], 0)
                # # Need to fix Attention Supervision since some of the the raw text are lost after truncation
                # neg_mask = torch.ge(out_var[1], self.opt.max_len_enc)
                # out_var[1] = out_var[1].masked_fill_(neg_mask, -1)
                # # out_var[1] = out_var[1] * neg_mask
                #
                # scatter_mask = util.prepare_scatter_map(inp_var[0])

                inp_var = [Var(x) for x in inp_var]

                if self.opt.use_cuda:
                    inp_var = [x.contiguous().cuda() for x in inp_var]

                self.func_train(inp_var, inp_mask, out_var, out_mask, features,
                                feature_msks, max_oov_len, scatter_msk,
                                bigram_bunch)

                if idx % self.opt.save_every == 0:
                    #######
                    # Periodic checkpoint (every save_every batches)
                    print_loss_avg = (sum(self.logger.lm.history_loss['NLL']) /
                                      len(self.logger.lm.history_loss['ALL']))
                    # print_loss_avg = sum(self.logger.current_epo['loss']) / self.logger.current_epo['count']
                    os.chdir(self.opt.save_dir)

                    name_string = '%d_%.3f_%s_Cop%s_Cov%s_%dx%s_%s%s_E%d_D%d_DL%s_%01.1f_SL%s_%01.1f_Attn%s_%01.1f_Feat%s'.lower() % (
                        epo, print_loss_avg, str(self.opt.enc),
                        str(self.opt.copy), str(self.opt.coverage),
                        self.model.opt.full_dict_size,
                        datetime.datetime.now().strftime("%B%d%I%M"),
                        self.opt.data_path.split('/')[-1], self.opt.name,
                        self.opt.max_len_enc, self.opt.max_len_dec,
                        str(self.opt.mul_loss), self.opt.lw_bgdyn,
                        str(self.opt.add_loss), self.opt.lw_bgsta,
                        str(self.opt.attn_sup), self.opt.lw_attn,
                        str(self.opt.feat_sp))
                    print(name_string)
                    torch.save(self.model.emb.state_dict(),
                               name_string + '_emb')
                    torch.save(self.model.feat, name_string + '_feat')
                    torch.save(self.model.enc.state_dict(),
                               name_string + '_enc')
                    torch.save(self.model.dec.state_dict(),
                               name_string + '_dec')
                    torch.save(self.model.opt, name_string + '_opt')

                    os.chdir('..')
            # End Saving
            ########

        print('\n')

    def func_train(self, inp_var, inp_msk, out_var, tgt_msk, features,
                   feature_msks, max_oov_len, scatter_mask, bigram_bunch):
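        """One training step: run the model forward, register the enabled loss
        terms (NLL, coverage, attention supervision, dynamic/static bigram) with
        the logger's loss manager, then backprop, clip gradients and update the
        parameters."""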
        self.optimizer.zero_grad()  # clear grad

        tgt_var, attn_sup = out_var

        batch_size = inp_var[0].size(1)
        batch_size_ = tgt_var.size(1)
        assert batch_size == batch_size_ == inp_var[2].size(1) == inp_var[1].size(1)

        target_len = tgt_var.size(0)
        src_len = inp_var[0].size(0)

        self.logger.current_batch['valid_pos'] = torch.sum(tgt_msk)

        decoder_outputs_prob, decoder_outputs, attns, discount, loss_cov, p_copys = self.model.train_forward(
            inp_var, tgt_var, inp_msk, tgt_msk, features, feature_msks,
            max_oov_len, scatter_mask, bigram_bunch, self.logger)
        tgt_padding_mask = Var(tgt_msk.float(),
                               requires_grad=False).cuda().view(
                                   target_len * batch_size, 1)

        # print(decoder_outputs_prob, decoder_outputs, attns, discount, loss_cov)
        # decoder_outputs_prob: tgt,batch,full_vocab
        # decoder_outputs: tgt, batch
        # attns: tgt, batch, src_len
        # discount: tgt, batch, full_vocab
        # loss_cov: tgt, batch
        # p_copys: batch, tgt   Value only

        if self.opt.copy:
            self.logger.lm.add_LossItem(
                LossItem(name='pgen', node=torch.mean(p_copys), weight=0))

        if self.coverage:
            # loss_cov: tgt, batch
            flat_loss_cov = loss_cov.view(target_len * batch_size, 1)
            val_loss_cov = torch.mean(tgt_padding_mask * flat_loss_cov)
            self.logger.lm.add_LossItem(
                LossItem(name='cov', node=val_loss_cov,
                         weight=self.opt.lw_cov))

        if self.opt.attn_sup:
            # attn_sup: tgt_len, batch LongTensor
            # attns: tgt_len, batch, src_len floatTensor
            tgt_sz, batch_sz = attn_sup.size()
            tgt_sz_, batch_sz_, src_len_ = attns.size()
            assert batch_sz == batch_size_
            flat_attn_sup = Var(attn_sup.view(-1)).cuda()
            flat_attns = attns.view(-1, src_len_)

            valid_attn_msk = torch.gt(flat_attn_sup, 0)
            valid_tgt = torch.masked_select(flat_attn_sup, valid_attn_msk)

            extded_valid_attn_msk = valid_attn_msk.view(
                -1, 1).expand_as(flat_attns)
            valid_pred = torch.masked_select(flat_attns, extded_valid_attn_msk)

            valid_pred = valid_pred.view(-1, src_len_)
            loss_attn = -torch.log(
                torch.gather(valid_pred, 1, valid_tgt.unsqueeze(1)))
            loss_attn = torch.mean(loss_attn)
            self.logger.lm.add_LossItem(
                LossItem(name='attn', node=loss_attn, weight=self.opt.lw_attn))
            # loss_attn = 1 - torch.mean(torch.gather(valid_pred, 0, valid_tgt))

        # Compulsory NLL loss part
        pred_prob = decoder_outputs_prob.view(target_len * batch_size, -1)
        gold_dist = Var(tgt_var).view(target_len * batch_size, 1).cuda()
        losses = -torch.gather(pred_prob, 1, gold_dist)
        losses = losses * tgt_padding_mask
        # Mask out infinite losses: a word that appears neither in the vocabulary nor
        # in the source article gets probability 0, so its -log prob is inf; keep only
        # the finite (<= 1e4) entries before averaging.
        inf_mask = losses.le(10000)
        nll_loss = torch.masked_select(losses, inf_mask)
        nll_loss = torch.mean(nll_loss)
        self.logger.lm.add_LossItem(
            LossItem(name='NLL', node=nll_loss, weight=self.opt.lw_nll))

        if self.mul_loss:
            # discount: tgt, batch, full_vocab

            # discount = torch.min(discount, torch.ones_like(discount))  # should naturally <1
            # discount = -torch.masked_select(discount, torch.gt(discount, 0))
            discount = 2 - torch.sum(discount) / 200  # 200 is a hard-coded normalizer

            # discount = torch.mean(discount)
            self.logger.lm.add_LossItem(
                LossItem(name='BGdyn', node=discount,
                         weight=self.opt.lw_bgdyn))
        elif self.add_loss:
            # print(discount)
            _msk = torch.gt(discount, 0).view(target_len * batch_size,
                                              -1).float()

            # x = torch.masked_select(losses, _msk)
            discount = 1 - torch.sum(losses * _msk) / 500  # 500 is a hard-coded normalizer
            # print(x)

            # reward_msk = 1 - torch.gt(discount, 0).float().view(target_len * batch_size, -1) * 99./100.
            # sta_weighted_loss = torch.mean(losses * reward_msk)
            self.logger.lm.add_LossItem(
                LossItem(name='BGsta', node=discount,
                         weight=self.opt.lw_bgsta))

        loss = self.logger.lm.compute()
        loss.backward()
        # print('update')
        torch.nn.utils.clip_grad_norm(self.model.parameters(), self.clip)
        self.optimizer.step()
        self.logger.batch_end()
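
# A minimal usage sketch (hypothetical names: option parsing, vocabulary and
# model construction live elsewhere in the project and may differ):
#
#     opt = parse_args()                          # namespace with lr, clip, n_epo, save_every, ...
#     dicts = [word_dict, pos_dict, ner_dict]     # vocab objects exposing word2idx / fword2idx
#     model = build_model(opt, dicts)             # seq2seq model with emb / enc / dec / feat parts
#     train_bag = load_batches(opt.data_path)     # list of pre-batched training dicts
#
#     trainer = Trainer(opt, model, dicts, train_bag)
#     trainer.train_iters()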