Example 1
    def __init__(self, model, sv, tv, optim, trg_dict_size,
                 valid_data=None, tests_data=None, n_critic=1):

        self.lamda = 5
        self.eps = 1e-20
        #self.beta_KL = 0.005
        self.beta_KL = 0.
        self.beta_RLGen = 0.2
        self.clip_rate = 0.
        self.beta_RLBatch = 0.

        self.model = model
        self.decoder = model.decoder
        self.classifier = self.decoder.classifier
        self.sv, self.tv = sv, tv
        self.trg_dict_size = trg_dict_size

        self.n_critic = 1

        self.translator_sample = Translator(self.model, sv, tv, k=1, noise=False)
        #self.translator = Translator(model, sv, tv, k=10)
        if isinstance(optim, list):
            self.optim_G, self.optim_D = optim[0], optim[1]
            self.optim_G.init_optimizer(self.model.parameters())
            self.optim_D.init_optimizer(self.model.parameters())
        else:
            self.optim_G = Optim(
                'adam', 10e-05, wargs.max_grad_norm,
                learning_rate_decay=wargs.learning_rate_decay,
                start_decay_from=wargs.start_decay_from,
                last_valid_bleu=wargs.last_valid_bleu
            )
            self.optim_G.init_optimizer(self.model.parameters())
            self.optim_D = optim
            self.optim_D.init_optimizer(self.model.parameters())
            self.optim = [self.optim_G, self.optim_D]

        '''
        self.optim_RL = Optim(
            'adadelta', 1.0, wargs.max_grad_norm,
            learning_rate_decay=wargs.learning_rate_decay,
            start_decay_from=wargs.start_decay_from,
            last_valid_bleu=wargs.last_valid_bleu
        )
        self.optim_RL.init_optimizer(self.model.parameters())
        '''
        self.maskSoftmax = MaskSoftmax()
        self.valid_data = valid_data
        self.tests_data = tests_data
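
The constructor above accepts `optim` either as a ready-made `[optim_G, optim_D]` pair or as a single discriminator optimizer, in which case it builds an Adam `Optim` for the generator itself. Below is a minimal, self-contained sketch of that dispatch, using a stub (`_StubOptim`, `split_optim` are illustrative names, not the project's code) in place of the real `Optim` class:

# A minimal sketch (not the project's code) of the optimizer dispatch above:
# `optim` may be a single discriminator optimizer or a [generator, discriminator] pair.
class _StubOptim(object):
    def init_optimizer(self, params):
        self.params = list(params)

def split_optim(optim, params, default_factory=_StubOptim):
    if isinstance(optim, list):
        optim_G, optim_D = optim[0], optim[1]
    else:
        optim_G, optim_D = default_factory(), optim   # the real code builds an adam Optim here
    optim_G.init_optimizer(params)
    optim_D.init_optimizer(params)
    return optim_G, optim_D

# both call patterns yield a (G, D) pair initialized on the same parameter list
g1, d1 = split_optim(_StubOptim(), params=[])
g2, d2 = split_optim([_StubOptim(), _StubOptim()], params=[])
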
Example 2
class Trainer:

    def __init__(self, model, sv, tv, optim, trg_dict_size,
                 valid_data=None, tests_data=None, n_critic=1):

        self.lamda = 5
        self.eps = 1e-20
        #self.beta_KL = 0.005
        self.beta_KL = 0.
        self.beta_RLGen = 0.2
        self.clip_rate = 0.
        self.beta_RLBatch = 0.

        self.model = model
        self.decoder = model.decoder
        self.classifier = self.decoder.classifier
        self.sv, self.tv = sv, tv
        self.trg_dict_size = trg_dict_size

        self.n_critic = 1

        self.translator_sample = Translator(self.model, sv, tv, k=1, noise=False)
        #self.translator = Translator(model, sv, tv, k=10)
        if isinstance(optim, list):
            self.optim_G, self.optim_D = optim[0], optim[1]
            self.optim_G.init_optimizer(self.model.parameters())
            self.optim_D.init_optimizer(self.model.parameters())
        else:
            self.optim_G = Optim(
                'adam', 10e-05, wargs.max_grad_norm,
                learning_rate_decay=wargs.learning_rate_decay,
                start_decay_from=wargs.start_decay_from,
                last_valid_bleu=wargs.last_valid_bleu
            )
            self.optim_G.init_optimizer(self.model.parameters())
            self.optim_D = optim
            self.optim_D.init_optimizer(self.model.parameters())
            self.optim = [self.optim_G, self.optim_D]

        '''
        self.optim_RL = Optim(
            'adadelta', 1.0, wargs.max_grad_norm,
            learning_rate_decay=wargs.learning_rate_decay,
            start_decay_from=wargs.start_decay_from,
            last_valid_bleu=wargs.last_valid_bleu
        )
        self.optim_RL.init_optimizer(self.model.parameters())
        '''
        self.maskSoftmax = MaskSoftmax()
        self.valid_data = valid_data
        self.tests_data = tests_data

    def mt_eval(self, eid, bid, optim=None):

        if optim: self.optim = optim
        state_dict = { 'model': self.model.state_dict(), 'epoch': eid, 'batch': bid, 'optim': self.optim }

        if wargs.save_one_model: model_file = '{}.pt'.format(wargs.model_prefix)
        else: model_file = '{}_e{}_upd{}.pt'.format(wargs.model_prefix, eid, bid)
        tc.save(state_dict, model_file)
        wlog('Saving temporary model in {}'.format(model_file))

        self.model.eval()

        tor0 = Translator(self.model, self.sv, self.tv, print_att=wargs.print_att)
        BLEU = tor0.trans_eval(self.valid_data, eid, bid, model_file, self.tests_data)

        self.model.train()

        return BLEU

    # P, Q: (max_tlen_batch, batch_size, vocab_size)
    def distance(self, P, Q, y_masks, type='JS', y_gold=None):

        B = y_masks.size(1)
        hypo_N = y_masks.data.sum()

        if Q.size(0) > P.size(0): Q = Q[:P.size(0)]    # truncate Q along time to match P

        if type == 'JS':
            #D_kl = tc.mean(tc.sum((tc.log(p1) - tc.log(p2)) * p1, dim=-1).squeeze(), dim=0)
            M = (P + Q) / 2.
            D_kl1 = tc.sum((tc.log(P + self.eps) - tc.log(M + self.eps)) * P, dim=-1).squeeze()
            D_kl2 = tc.sum((tc.log(Q + self.eps) - tc.log(M + self.eps)) * Q, dim=-1).squeeze()
            Js = 0.5 * D_kl1 + 0.5 * D_kl2
            sent_batch_dist = tc.sum(Js * y_masks) / B
            word_level_dist0 = tc.sum(Js * y_masks) / hypo_N
            Js = Js / y_masks.sum(0)[None, :]
            word_level_dist1 = tc.sum(Js * y_masks) / B
            del M, D_kl1, D_kl2, Js

        elif type == 'KL':
            KL = tc.sum(P * (tc.log(P + self.eps) - tc.log(Q + self.eps)), dim=-1)
            # (L, B, V) -> (L, B)
            sent_batch_dist = tc.sum(KL * y_masks) / B
            word_level_dist0 = tc.sum(KL * y_masks) / hypo_N
            KL = KL / y_masks.sum(0)[None, :]
            #print W_KL.data
            word_level_dist1 = tc.sum(KL * y_masks) / B
            #print W_dist.data[0], y_masks.size(1)
            del KL

        elif type == 'KL-sent':

            # pick the probability assigned to the gold word at each position
            p1 = tc.gather(P, 2, y_gold[:, :, None])[:, :, 0]
            p2 = tc.gather(Q, 2, y_gold[:, :, None])[:, :, 0]
            # p1, p2: (max_tlen_batch, batch_size)
            KL = (y_masks * (tc.log(p1 + self.eps) - tc.log(p2 + self.eps))) * p1
            sent_batch_dist = tc.sum(KL) / B
            word_level_dist0 = tc.sum(KL) / hypo_N
            KL = KL / y_masks.sum(0)[None, :]
            word_level_dist1 = tc.sum(KL * y_masks) / B
            # KL: (1, batch_size)
            del p1, p2, KL

        return sent_batch_dist, word_level_dist0, word_level_dist1

    def hyps_padding_dist(self, oracle, hyps_L, y_gold_maxL, p_y_hyp):

        #hyps_dist = [None] * B
        B, hyps_dist, hyps = oracle.size(1), [], [] # oracle, w/o bos
        assert (B == len(hyps_L)) and (oracle.size(0) == p_y_hyp.size(0))
        for bidx in range(B):
            hyp_L = hyps_L[bidx] - 1    # remove bos
            if hyp_L < y_gold_maxL:
                padding = tc.ones(y_gold_maxL - hyp_L) / self.trg_dict_size
                padding = padding[:, None].expand(padding.size(0), self.trg_dict_size)
                #pad = pad[:, None].expand((pad.size(0), one_p_y_hyp.size(-1)))
                padding = Variable(padding, requires_grad=False)
                if wargs.gpu_id and not padding.is_cuda: padding = padding.cuda()
                #print one_p_y_hyp.size(0), pad.size(0)
                #print tc.cat((p_y_hyp[:hyp_L, bidx, :], padding), dim=0).size()
                hyps_dist.append(tc.cat((p_y_hyp[:hyp_L, bidx, :], padding), dim=0))
                pad_ids = Variable(PAD * tc.ones(y_gold_maxL - hyp_L).long(), requires_grad=False)
                if wargs.gpu_id and not pad_ids.is_cuda: pad_ids = pad_ids.cuda()
                hyps.append(tc.cat((oracle[:hyp_L, bidx], pad_ids), dim=0))
            else:
                hyps_dist.append(p_y_hyp[:y_gold_maxL, bidx, :])
                hyps.append(oracle[:y_gold_maxL, bidx])
            #hyps_dist[bidx] = one_p_y_hyp
        hyps_dist = tc.stack(hyps_dist, dim=1)
        hyps = tc.stack(hyps, dim=1)
        return hyps_dist, hyps

    def gumbel_sampling(self, B, y_maxL, feed_gold_out, noise=False):

        # feed_gold_out (L * B, V)
        logit = self.classifier.pred_map(feed_gold_out, noise=noise)

        if logit.is_cuda: logit = logit.cpu()
        hyps = tc.max(logit, 1)[1]
        # hyps (L*B, 1)
        hyps = hyps.view(y_maxL, B)
        hyps[0] = BOS * tc.ones(B).long()   # first words are <s>
        # hyps (L, B)
        # |hyps - EOS| is zero exactly at EOS positions; appending a zero row makes the
        # argmin along time fall back to y_maxL when a hypothesis never produced EOS,
        # so the argmin index serves as the hypothesis length
        c1 = tc.clamp((hyps.data - EOS), min=0, max=self.trg_dict_size)
        c2 = tc.clamp((EOS - hyps.data), min=0, max=self.trg_dict_size)
        _hyps = c1 + c2
        _hyps = tc.cat([_hyps, tc.zeros(B).long().unsqueeze(0)], 0)
        _hyps = tc.min(_hyps, 0)[1]
        #_hyps = tc.max(0 - _hyps, 0)[1]
        # idx: (1, B)
        hyps_L = _hyps.view(-1).tolist()
        hyps_mask = tc.zeros(y_maxL, B)
        for bid in range(B): hyps_mask[:, bid][:hyps_L[bid]] = 1.
        hyps_mask = Variable(hyps_mask, requires_grad=False)

        if wargs.gpu_id and not hyps_mask.is_cuda: hyps_mask = hyps_mask.cuda()
        if wargs.gpu_id and not hyps.is_cuda: hyps = hyps.cuda()

        return hyps, hyps_mask, hyps_L

    def try_trans(self, srcs, ref):

        # (len, 1)
        #src = sent_filter(list(srcs[:, bid].data))
        x_filter = sent_filter(list(srcs))
        y_filter = sent_filter(list(ref))
        #wlog('\n[{:3}] {}'.format('Src', idx2sent(x_filter, self.sv)))
        #wlog('[{:3}] {}'.format('Ref', idx2sent(y_filter, self.tv)))

        onebest, onebest_ids, _ = self.translator_sample.trans_onesent(x_filter)

        #wlog('[{:3}] {}'.format('Out', onebest))

        # no EOS and BOS
        return onebest_ids


    def beamsearch_sampling(self, srcs, trgs, eos=True):

        # y_masks: (trg_max_len, batch_size)
        B = srcs.size(1)
        oracles, oracles_L = [None] * B, [None] * B

        for bidx in range(B):
            onebest_ids = self.try_trans(srcs[:, bidx].data, trgs[:, bidx].data)

            if len(onebest_ids) == 0 or onebest_ids[0] != BOS: onebest_ids = [BOS] + onebest_ids
            if eos is True and onebest_ids[-1] != EOS: onebest_ids = onebest_ids + [EOS]
            oracles_L[bidx] = len(onebest_ids)
            oracles[bidx] = onebest_ids

        maxL = max(oracles_L)
        for bidx in range(B):
            cur_L, oracle = oracles_L[bidx], oracles[bidx]
            if cur_L < maxL: oracles[bidx] = oracle + [PAD] * (maxL - cur_L)

        oracles = Variable(tc.Tensor(oracles).long().t(), requires_grad=False) # -> (L, B)
        if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda()
        oracles_mask = oracles.ne(PAD).float()

        return oracles, oracles_mask, oracles_L

    def train(self, dh, dev_input, k, merge=False, name='default', percentage=0.1):

        #if (k + 1) % 1 == 0 and self.valid_data and self.tests_data:
        #    wlog('Evaluation on dev ... ')
        #    mt_eval(valid_data, self.model, self.sv, self.tv,
        #            0, 0, [self.optim, self.optim_RL, self.optim_G], self.tests_data)

        batch_count = len(dev_input)
        self.model.train()
        self.sampler = Nbs(self.model, self.tv, k=3, noise=False, print_att=False, batch_sample=True)

        for eid in range(wargs.start_epoch, wargs.max_epochs + 1):

            #self.optim_G.init_optimizer(self.model.parameters())
            #self.optim_RL.init_optimizer(self.model.parameters())

            size = int(percentage * batch_count)
            shuffled_batch_idx = tc.randperm(batch_count)

            wlog('{} NEW Epoch {}'.format('-' * 50, '-' * 50))
            wlog('{}, Epo:{:>2}/{:>2} start, random {}/{}({:.2%}) calc BLEU ... '.format(
                name, eid, wargs.max_epochs, size, batch_count, percentage), False)
            param_1, param_2, param_3, param_4, param_5, param_6 = [], [], [], [], [], []
            for k in range(size):
                bid, half_size = shuffled_batch_idx[k], wargs.batch_size

                # srcs: (max_sLen_batch, batch_size, emb), trgs: (max_tLen_batch, batch_size, emb)
                if merge is False: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dev_input[bid]
                else: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dh.merge_batch(dev_input[bid])[0]
                trgs, trgs_m = trgs[0], trgs_m[0]   # we only use the first dev reference

                if wargs.sampling == 'gumbeling':
                    # note: gumbel sampling needs B, y_gold_maxL and feed_gold_out from a
                    # forward pass, which this BLEU pre-scan loop does not compute
                    oracles, oracles_mask, oracles_L = self.gumbel_sampling(B, y_gold_maxL, feed_gold_out, True)
                elif wargs.sampling == 'truncation':
                    oracles, oracles_mask, oracles_L = self.beamsearch_sampling(srcs, trgs)
                elif wargs.sampling == 'length_limit':
                    batch_beam_trgs = self.sampler.beam_search_trans(srcs, srcs_m, trgs_m)
                    hyps = [list(zip(*b)[0]) for b in batch_beam_trgs]
                    oracles = batch_search_oracle(hyps, trgs[1:], trgs_m[1:])
                    if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda()
                    oracles_mask = oracles.ne(0).float()
                    oracles_L = oracles_mask.sum(0).data.int().tolist()

                # oracles have the same layout as trgs, with BOS and EOS, shape (L, B)
                param_1.append(BLToStrList(oracles[1:-1].t(), [l-2 for l in oracles_L]))
                param_2.append(BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist()))

                param_3.append(BLToStrList(oracles[1:-1, :half_size].t(),
                                       [l-2 for l in oracles_L[:half_size]]))
                param_4.append(BLToStrList(trgs[1:-1, :half_size].t(),
                                       trgs_m[1:-1, :half_size].sum(0).data.int().tolist()))
                param_5.append(BLToStrList(oracles[1:-1, half_size:].t(),
                                       [l-2 for l in oracles_L[half_size:]]))
                param_6.append(BLToStrList(trgs[1:-1, half_size:].t(),
                                       trgs_m[1:-1, half_size:].sum(0).data.int().tolist()))

            start_bat_bleu_hist = bleu('\n'.join(param_3), ['\n'.join(param_4)], logfun=debug)
            start_bat_bleu_new = bleu('\n'.join(param_5), ['\n'.join(param_6)], logfun=debug)
            start_bat_bleu = bleu('\n'.join(param_1), ['\n'.join(param_2)], logfun=debug)
            wlog('Random BLEU on history {}, new {}, mix {}'.format(
                start_bat_bleu_hist, start_bat_bleu_new, start_bat_bleu))

            wlog('Model selection and testing ... ')
            self.mt_eval(eid, 0, [self.optim_G, self.optim_D])
            if start_bat_bleu > 0.9:
                wlog('Better BLEU ... go to next data history ...')
                return

            s_kl_seen, w_kl_seen0, w_kl_seen1, rl_gen_seen, rl_rho_seen, rl_bat_seen, w_mle_seen, \
                    s_mle_seen, ppl_seen = 0., 0., 0., 0., 0., 0., 0., 0., 0.
            for bid in range(batch_count):

                if merge is False: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dev_input[bid]
                else: _, srcs, _, trgs, _, slens, srcs_m, trgs_m = dh.merge_batch(dev_input[bid], True)[0]
                trgs, trgs_m = trgs[0], trgs_m[0]
                gold_feed, gold_feed_mask = trgs[:-1], trgs_m[:-1]
                gold, gold_mask = trgs[1:], trgs_m[1:]
                B, y_gold_maxL = srcs.size(1), gold_feed.size(0)
                N = gold.data.ne(PAD).sum()
                debug('B:{}, gold_feed_ymaxL:{}, N:{}'.format(B, y_gold_maxL, N))

                ###################################################################################
                debug('Optimizing KL distance ................................ {}'.format(name))
                #self.model.zero_grad()
                self.optim_G.zero_grad()

                feed_gold_out, _ = self.model(srcs, gold_feed, srcs_m, gold_feed_mask)
                p_y_gold = self.classifier.logit_to_prob(feed_gold_out)
                # p_y_gold: (gold_max_len - 1, B, trg_dict_size)

                if wargs.sampling == 'gumbeling':
                    oracles, oracles_mask, oracles_L = self.gumbel_sampling(B, y_gold_maxL, feed_gold_out, True)
                elif wargs.sampling == 'truncation':
                    oracles, oracles_mask, oracles_L = self.beamsearch_sampling(srcs, trgs)
                elif wargs.sampling == 'length_limit':
                    # w/o eos
                    batch_beam_trgs = self.sampler.beam_search_trans(srcs, srcs_m, trgs_m)
                    hyps = [list(zip(*b)[0]) for b in batch_beam_trgs]
                    oracles = batch_search_oracle(hyps, trgs[1:], trgs_m[1:])
                    if wargs.gpu_id and not oracles.is_cuda: oracles = oracles.cuda()
                    oracles_mask = oracles.ne(0).float()
                    oracles_L = oracles_mask.sum(0).data.int().tolist()

                oracle_feed, oracle_feed_mask = oracles[:-1], oracles_mask[:-1]
                oracle, oracle_mask = oracles[1:], oracles_mask[1:]
                # oracles have the same layout as trgs, with BOS and EOS, shape (L, B)
                feed_oracle_out, _ = self.model(srcs, oracle_feed, srcs_m, oracle_feed_mask)
                p_y_hyp = self.classifier.logit_to_prob(feed_oracle_out)
                p_y_hyp_pad, oracle = self.hyps_padding_dist(oracle, oracles_L, y_gold_maxL, p_y_hyp)
                #wlog('feed oracle dist: {}, feed gold dist: {}, oracle: {}'.format(p_y_hyp_pad.size(), p_y_gold.size(), oracle.size()))
                #B_KL_loss = self.distance(p_y_gold, p_y_hyp_pad, hyps_mask[1:], type='KL', y_gold=gold)
                S_KL_loss, W_KL_loss0, W_KL_loss1 = self.distance(
                    p_y_gold, p_y_hyp_pad, gold_mask, type='KL', y_gold=gold)
                debug('KL: Sent-level {}, Word0-level {}, Word1-level {}'.format(
                    S_KL_loss.data[0], W_KL_loss0.data[0], W_KL_loss1.data[0]))
                s_kl_seen += S_KL_loss.data[0]
                w_kl_seen0 += W_KL_loss0.data[0]
                w_kl_seen1 += W_KL_loss1.data[0]
                del p_y_hyp, feed_oracle_out

                ###################################################################################
                debug('Optimizing RL(Gen) .......... {}'.format(name))
                hyps_list = BLToStrList(oracle[:-1].t(), [l-2 for l in oracles_L], True)
                trgs_list = BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist(), True)
                bleus_sampling = []
                for hyp, ref in zip(hyps_list, trgs_list):
                    bleus_sampling.append(bleu(hyp, [ref], logfun=debug))
                bleus_sampling = toVar(bleus_sampling, wargs.gpu_id)

                oracle_mask = oracle.ne(0).float()
                p_y_ahyp = p_y_hyp_pad.gather(2, oracle[:, :, None])[:, :, 0]
                p_y_ahyp = ((p_y_ahyp + self.eps).log() * oracle_mask).sum(0) / oracle_mask.sum(0)

                p_y_agold = p_y_gold.gather(2, gold[:, :, None])[:, :, 0]
                p_y_agold = ((p_y_agold + self.eps).log() * gold_mask).sum(0) / gold_mask.sum(0)

                r_theta = p_y_ahyp / p_y_agold
                A = 1. - bleus_sampling
                RL_Gen_loss = tc.min(r_theta * A, clip(r_theta, self.clip_rate) * A).sum()
                RL_Gen_loss = (RL_Gen_loss).div(B)
                debug('...... RL(Gen) cliped loss {}'.format(RL_Gen_loss.data[0]))
                rl_gen_seen += RL_Gen_loss.data[0]
                del p_y_agold

                ###################################################################################
                debug('Optimizing RL(Batch) -> Gap of MLE and BLEU ... rho ... feed onebest .... ')
                param_1 = BLToStrList(oracles[1:-1].t(), [l-2 for l in oracles_L])
                param_2 = BLToStrList(trgs[1:-1].t(), trgs_m[1:-1].sum(0).data.int().tolist())
                rl_bat_bleu = bleu(param_1, [param_2], logfun=debug)
                rl_avg_bleu = tc.mean(bleus_sampling).data[0]

                rl_rho = cor_coef(p_y_ahyp, bleus_sampling, eps=self.eps)
                rl_rho_seen += rl_rho.data[0]   # must use data, accumulating Variable needs more memory

                #p_y_hyp = p_y_hyp.exp()
                #p_y_hyp = (p_y_hyp * self.lamda / 3).exp()
                #p_y_hyp = self.maskSoftmax(p_y_hyp)
                p_y_ahyp = p_y_ahyp[None, :]
                p_y_ahyp_T = p_y_ahyp.t().expand(B, B)
                p_y_ahyp = p_y_ahyp.expand(B, B)
                p_y_ahyp_sum = p_y_ahyp_T + p_y_ahyp + self.eps

                #bleus_sampling = bleus_sampling[None, :].exp()
                bleus_sampling = self.maskSoftmax(self.lamda * bleus_sampling[None, :])
                bleus_T = bleus_sampling.t().expand(B, B)
                bleus = bleus_sampling.expand(B, B)
                bleus_sum = bleus_T + bleus + self.eps
                #print 'p_y_hyp_sum......................'
                #print p_y_hyp_sum.data
                RL_Batch_loss = p_y_ahyp / p_y_ahyp_sum * tc.log(bleus_T / bleus_sum) + \
                        p_y_ahyp_T / p_y_ahyp_sum * tc.log(bleus / bleus_sum)

                #RL_Batch_loss = tc.sum(-RL_Batch_loss * toVar(1 - tc.eye(B))).div(B)
                RL_Batch_loss = tc.sum(-RL_Batch_loss * toVar(1 - tc.eye(B), wargs.gpu_id))

                debug('RL(Batch) Mean BLEU: {}, rl_batch_loss: {}, rl_rho: {}, Bat BLEU: {}'.format(
                    rl_avg_bleu, RL_Batch_loss.data[0], rl_rho.data[0], rl_bat_bleu))
                rl_bat_seen += RL_Batch_loss.data[0]
                del oracles, oracles_mask, oracle_feed, oracle_feed_mask, oracle, oracle_mask,\
                        p_y_ahyp, bleus_sampling, bleus, p_y_ahyp_T, p_y_ahyp_sum, bleus_T, bleus_sum
                '''
                (self.beta_KL * S_KL_loss + self.beta_RLGen * RL_Gen_loss + \
                        self.beta_RLBatch * RL_Batch_loss).backward(retain_graph=True)

                mle_loss, grad_output, _ = memory_efficient(
                    feed_gold_out, gold, gold_mask, self.model.classifier)
                feed_gold_out.backward(grad_output)
                '''

                (self.beta_KL * W_KL_loss0 + self.beta_RLGen * RL_Gen_loss + \
                        self.beta_RLBatch * RL_Batch_loss).backward(retain_graph=True)
                self.optim_G.step()

                ###################################################### discriminator
                #mle_loss, _, _ = self.classifier(feed_gold_out, gold, gold_mask)
                #mle_loss = mle_loss.div(B)
                #mle_loss = mle_loss.data[0]

                self.optim_D.zero_grad()
                mle_loss, _, _ = self.classifier.snip_back_prop(feed_gold_out, gold, gold_mask)
                self.optim_D.step()

                w_mle_seen += ( mle_loss / N )
                s_mle_seen += ( mle_loss / B )
                ppl_seen += math.exp(mle_loss/N)
                wlog('Epo:{:>2}/{:>2}, Bat:[{}/{}], W0-KL {:4.2f}, W1-KL {:4.2f}, '
                     'S-RLGen {:4.2f}, B-rho {:4.2f}, B-RLBat {:4.2f}, W-MLE:{:4.2f}, '
                     'S-MLE:{:4.2f}, W-ppl:{:4.2f}, B-bleu:{:4.2f}, A-bleu:{:4.2f}'.format(
                         eid, wargs.max_epochs, bid, batch_count, W_KL_loss0.data[0],
                         W_KL_loss1.data[0], RL_Gen_loss.data[0], rl_rho.data[0], RL_Batch_loss.data[0],
                         mle_loss/N, mle_loss/B, math.exp(mle_loss/N), rl_bat_bleu, rl_avg_bleu))
                #wlog('=' * 100)
                del S_KL_loss, W_KL_loss0, W_KL_loss1, RL_Gen_loss, RL_Batch_loss, feed_gold_out

            wlog('End epoch: S-KL {:4.2f}, W0-KL {:4.2f}, W1-KL {:4.2f}, S-RLGen {:4.2f}, B-rho '
                 '{:4.2f}, B-RLBat {:4.2f}, W-MLE {:4.2f}, S-MLE {:4.2f}, W-ppl {:4.2f}'.format(
                s_kl_seen/batch_count, w_kl_seen0/batch_count, w_kl_seen1/batch_count, rl_gen_seen/batch_count,
                rl_rho_seen/batch_count, rl_bat_seen/batch_count, w_mle_seen/batch_count,
                s_mle_seen/batch_count, ppl_seen/batch_count))
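
For reference, here is a small numeric sketch (assuming a recent PyTorch; tensor values are made up, variable names are illustrative) of the masked KL term computed by `distance(..., type='KL')` above, showing the sentence-level quantity and the two word-level normalisations that the method returns:

import torch as tc

eps = 1e-20
L, B, V = 3, 2, 5                                    # time steps, batch size, vocab size
P = tc.softmax(tc.randn(L, B, V), dim=-1)            # "gold-feed" word distributions
Q = tc.softmax(tc.randn(L, B, V), dim=-1)            # "oracle-feed" word distributions
y_masks = tc.tensor([[1., 1.], [1., 1.], [1., 0.]])  # second sentence has length 2

KL = tc.sum(P * (tc.log(P + eps) - tc.log(Q + eps)), dim=-1)   # (L, B)
sent_batch_dist = tc.sum(KL * y_masks) / B                     # averaged per sentence
word_level_dist0 = tc.sum(KL * y_masks) / y_masks.sum()        # averaged per word
word_level_dist1 = tc.sum(KL / y_masks.sum(0)[None, :] * y_masks) / B
print(sent_batch_dist.item(), word_level_dist0.item(), word_level_dist1.item())
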
Example 3
def main():

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    vocab_data = {}
    train_srcD_file = wargs.src_vocab_from
    wlog('\nPreparing out of domain source vocabulary from {} ... '.format(
        train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict,
                              wargs.src_dict_size)
    #DANN
    train_srcD_file_domain = wargs.src_domain_vocab_from
    wlog('\nPreparing in domain source vocabulary from {} ...'.format(
        train_srcD_file_domain))
    src_vocab = updata_vocab(train_srcD_file_domain, src_vocab, wargs.src_dict,
                             wargs.src_dict_size)

    vocab_data['src'] = src_vocab

    train_trgD_file = wargs.trg_vocab_from
    wlog('\nPreparing out of domain target vocabulary from {} ... '.format(
        train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict,
                              wargs.trg_dict_size)

    #DANN
    train_trgD_file_domain = wargs.trg_domain_vocab_from
    wlog('\nPreparing in domain target vocabulary from {} ... '.format(
        train_trgD_file_domain))
    trg_vocab = updata_vocab(train_trgD_file_domain, trg_vocab, wargs.trg_dict,
                             wargs.trg_dict_size)

    vocab_data['trg'] = trg_vocab

    train_src_file = wargs.train_src
    train_trg_file = wargs.train_trg
    if wargs.fine_tune is False:
        wlog('\nPreparing out of domain training set from {} and {} ... '.
             format(train_src_file, train_trg_file))
        train_src_tlst, train_trg_tlst = wrap_data(
            train_src_file,
            train_trg_file,
            vocab_data['src'],
            vocab_data['trg'],
            max_seq_len=wargs.max_seq_len)
    else:
        wlog('\nNo out of domain training set ...')

    #DANN
    train_src_file_domain = wargs.train_src_domain
    train_trg_file_domain = wargs.train_trg_domain
    wlog('\nPreparing in domain training set from {} and {}...'.format(
        train_src_file_domain, train_trg_file_domain))
    train_src_tlst_domain, train_trg_tlst_domain = wrap_data(
        train_src_file_domain,
        train_trg_file_domain,
        vocab_data['src'],
        vocab_data['trg'],
        max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                  wargs.val_src_suffix)
    wlog('\nPreparing validation set from {} ... '.format(valid_file))
    valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab)

    if wargs.fine_tune is False:
        wlog('Out of domain Sentence-pairs count in training data: {}'.format(
            len(train_src_tlst)))
    wlog('In domain Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst_domain)))

    src_vocab_size, trg_vocab_size = vocab_data['src'].size(
    ), vocab_data['trg'].size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))

    if wargs.fine_tune is False:
        batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    else:
        batch_train = None

    batch_valid = Input(valid_src_tlst, None, 1, volatile=True)
    #DANN
    batch_train_domain = Input(train_src_tlst_domain, train_trg_tlst_domain,
                               wargs.batch_size)

    tests_data = None
    if wargs.tests_prefix is not None:
        init_dir(wargs.dir_tests)
        tests_data = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('Preparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = val_wrap_data(test_file, src_vocab)
            tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True)

    sv = vocab_data['src'].idx2key
    tv = vocab_data['trg'].idx2key

    nmtModel = NMT(src_vocab_size, trg_vocab_size)

    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train), 'Requires pre-trained model'
        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1
    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)

    if wargs.gpu_id:
        nmtModel.cuda()
        wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0]))
    else:
        nmtModel.cpu()
        wlog('Push model onto CPU ... ')

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, batch_train_domain, vocab_data,
                      optim, batch_valid, tests_data)

    trainer.train()
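
The two-stage vocabulary construction above first extracts an out-of-domain vocabulary and then merges the in-domain counts into it via `updata_vocab`. The following is a hypothetical, simplified illustration of that merge using `collections.Counter` (not the project's `extract_vocab`/`updata_vocab` implementation):

from collections import Counter

# build out-of-domain token counts, then fold in-domain counts into the same vocabulary
def build_counts(lines):
    return Counter(tok for line in lines for tok in line.split())

def merge_vocab(counts, extra_lines, max_size):
    merged = counts.copy()
    for line in extra_lines:
        merged.update(line.split())
    return [tok for tok, _ in merged.most_common(max_size)]   # keep the most frequent tokens

out_domain = ["the cat sat on the mat", "the dog ran"]
in_domain = ["the neural model translates sentences"]
vocab = merge_vocab(build_counts(out_domain), in_domain, max_size=20)
print(vocab)
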
Example 4
def main():

    #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_src_suffix))
    trg = os.path.join(wargs.dir_data, '{}.{}'.format(wargs.train_prefix, wargs.train_trg_suffix))
    vocabs = {}
    wlog('\nPreparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src, wargs.src_vcb, wargs.n_src_vcb_plan,
                              wargs.max_seq_len, char=wargs.src_char)
    wlog('\nPreparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan, wargs.max_seq_len)
    n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(n_src_vcb, n_trg_vcb))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data, wargs.train_prefix,
                                               wargs.train_src_suffix, wargs.train_trg_suffix,
                                               src_vocab, trg_vocab, shuffle=True,
                                               sort_k_batches=wargs.sort_k_batches,
                                               max_seq_len=wargs.max_seq_len,
                                               char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size,
                        batch_type=wargs.batch_type, bow=wargs.trg_bow, batch_sort=False)
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_src_suffix))
        val_trg_file = os.path.join(wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix, wargs.val_ref_suffix))
        wlog('\nPreparing validation set from {} and {} ... '.format(val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(wargs.val_tst_dir, wargs.val_prefix,
                                                   wargs.val_src_suffix, wargs.val_ref_suffix,
                                                   src_vocab, trg_vocab, shuffle=False,
                                                   max_seq_len=wargs.dev_max_seq_len,
                                                   char=wargs.src_char)
        batch_valid = Input(valid_src_tlst, valid_trg_tlst, 1, batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix, list), 'Test files should be list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix, wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file, src_vocab, char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst, None, 1, batch_sort=False)
    wlog('\n## Finished preparing the dataset ##\n')

    src_emb = WordEmbedding(n_src_vcb, wargs.d_src_emb, wargs.input_dropout,
                            wargs.position_encoding, prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb, wargs.d_trg_emb, wargs.input_dropout,
                            wargs.position_encoding, prefix='Trg')
    # share the embedding matrix - preprocess with share_vocab required.
    if wargs.embs_share_weight:
        if n_src_vcb != n_trg_vcb:
            raise AssertionError('The `-share_vocab` should be set during '
                                 'preprocess if you use share_embeddings!')
        src_emb.we.weight = trg_emb.we.weight

    nmtModel = build_NMT(src_emb, trg_emb)

    if not wargs.copy_attn:
        classifier = Classifier(wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid,
                                n_trg_vcb, trg_emb, loss_norm=wargs.loss_norm,
                                label_smoothing=wargs.label_smoothing,
                                emb_loss=wargs.emb_loss, bow_loss=wargs.bow_loss)
        nmtModel.decoder.classifier = classifier

    if wargs.gpu_id is not None:
        wlog('push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        #nmtModel = nn.DataParallel(nmtModel, device_ids=wargs.gpu_id)
        nmtModel.to(tc.device('cuda'))
    else:
        wlog('push model onto CPU ... ', 0)
        nmtModel.to(tc.device('cpu'))
    wlog('done.')

    if wargs.pre_train is not None:
        assert os.path.exists(wargs.pre_train)
        from tools.utils import load_model
        _dict = load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        elif len(_dict) == 4:
            model_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            else: init_params(param, name, init_D=wargs.param_init_D, a=float(wargs.u_gain))

        wargs.start_epoch = eid + 1

    else:
        optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm)
        #for n, p in nmtModel.named_parameters():
            # bias can not be initialized uniformly
            #if wargs.encoder_type != 'att' and wargs.decoder_type != 'att':
            #    init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('parameters number: {}/{}'.format(pcnt1, pcnt2))

    wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30)
    for n, p in nmtModel.named_parameters():
        if p.requires_grad: wlog('{:60} : {}'.format(n, p.size()))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid, batch_tests)

    trainer.train()
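
When `embs_share_weight` is set, the source and target embedding matrices are tied by assigning one parameter to both modules. Below is a minimal sketch of that tying with plain `nn.Embedding` stand-ins (the `we` attribute mirrors the project's WordEmbedding; `TinyEmb` is illustrative only):

import torch.nn as nn

# tie the embedding matrices of two toy modules, as the `embs_share_weight` branch does
class TinyEmb(nn.Module):
    def __init__(self, n_vocab, d):
        super(TinyEmb, self).__init__()
        self.we = nn.Embedding(n_vocab, d)

src_emb, trg_emb = TinyEmb(100, 8), TinyEmb(100, 8)
src_emb.we.weight = trg_emb.we.weight   # both modules now point at one shared parameter
assert src_emb.we.weight.data_ptr() == trg_emb.we.weight.data_ptr()
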
Example 5
def main():
    # if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    src, trg = os.path.abspath(src), os.path.abspath(trg)
    vocabs = {}
    if wargs.share_vocab is False:
        wlog('\nPreparing source vocabulary from {} ... '.format(src))
        src_vocab = extract_vocab(src,
                                  wargs.src_vcb,
                                  wargs.n_src_vcb_plan,
                                  wargs.max_seq_len,
                                  char=wargs.src_char)
        wlog('\nPreparing target vocabulary from {} ... '.format(trg))
        trg_vocab = extract_vocab(trg, wargs.trg_vcb, wargs.n_trg_vcb_plan,
                                  wargs.max_seq_len)
        n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
        wlog('Vocabulary size: |source|={}, |target|={}'.format(
            n_src_vcb, n_trg_vcb))
    else:
        wlog('\nPreparing the shared vocabulary from \n\t{}\n\t{}'.format(
            src, trg))
        trg_vocab = src_vocab = extract_vocab(src,
                                              wargs.src_vcb,
                                              wargs.n_src_vcb_plan,
                                              wargs.max_seq_len,
                                              share_vocab=True,
                                              trg_file=trg)
        n_src_vcb, n_trg_vcb = src_vocab.size(), trg_vocab.size()
        wlog('Shared vocabulary size: |vocab|={}'.format(src_vocab.size()))

    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(
        wargs.dir_data,
        wargs.train_prefix,
        wargs.train_src_suffix,
        wargs.train_trg_suffix,
        src_vocab,
        trg_vocab,
        shuffle=True,
        sort_k_batches=wargs.sort_k_batches,
        max_seq_len=wargs.max_seq_len,
        char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_type=wargs.batch_type,
                        bow=wargs.trg_bow,
                        batch_sort=False,
                        gpu_ids=device_ids)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = os.path.join(
            wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix,
                                              wargs.val_src_suffix))
        val_trg_file = os.path.join(
            wargs.val_tst_dir, '{}.{}'.format(wargs.val_prefix,
                                              wargs.val_ref_suffix))
        val_src_file, val_trg_file = os.path.abspath(
            val_src_file), os.path.abspath(val_trg_file)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            max_seq_len=wargs.dev_max_seq_len,
            char=wargs.src_char)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            batch_size=wargs.valid_batch_size,
                            batch_sort=False,
                            gpu_ids=device_ids)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            test_file = os.path.abspath(test_file)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file,
                                             src_vocab,
                                             char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        batch_size=wargs.test_batch_size,
                                        batch_sort=False,
                                        gpu_ids=device_ids)
    wlog('\n## Finished preparing the dataset ##\n')

    src_emb = WordEmbedding(n_src_vcb,
                            wargs.d_src_emb,
                            wargs.input_dropout,
                            wargs.position_encoding,
                            prefix='Src')
    trg_emb = WordEmbedding(n_trg_vcb,
                            wargs.d_trg_emb,
                            wargs.input_dropout,
                            wargs.position_encoding,
                            prefix='Trg')
    # share the embedding matrix between the source and target
    if wargs.share_vocab is True: src_emb.we.weight = trg_emb.we.weight

    nmtModel = build_NMT(src_emb, trg_emb)

    if device_ids is not None:
        wlog('push model onto GPU {} ... '.format(device_ids[0]), 0)
        nmtModel_par = nn.DataParallel(nmtModel, device_ids=device_ids)
        nmtModel_par.to(device)
    else:
        wlog('push model onto CPU ... ', 0)
        nmtModel_par = nmtModel.to(tc.device('cpu'))
    wlog('done.')

    if wargs.pre_train is not None:
        wlog(wargs.pre_train)
        assert os.path.exists(wargs.pre_train)
        from tools.utils import load_model
        _dict = load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 5:
            # model_dict, e_idx, e_bidx, n_steps, optim = _dict['model'], _dict['epoch'], _dict['batch'], _dict['steps'], _dict['optim']
            model_dict, e_idx, e_bidx, n_steps, optim = _dict
        elif len(_dict) == 4:
            model_dict, e_idx, e_bidx, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    # wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad, name))
            else:
                init_params(param,
                            name,
                            init_D=wargs.param_init_D,
                            a=float(wargs.u_gain))

        # wargs.start_epoch = e_idx + 1
        # # do not restart: keep the optimizer step count
        # optim.n_current_steps = 0

    else:
        optim = Optim(wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm)
        for n, p in nmtModel.named_parameters():
            # bias can not be initialized uniformly
            if 'norm' in n:
                wlog('ignore layer norm init ...')
                continue
            if 'emb' in n:
                wlog('ignore word embedding weight init ...')
                continue
            if 'vcb_proj' in n:
                wlog('ignore vcb_proj weight init ...')
                continue
            init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))
            # if wargs.encoder_type != 'att' and wargs.decoder_type != 'att':
            #    init_params(p, n, init_D=wargs.param_init_D, a=float(wargs.u_gain))

    # wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('parameters number: {}/{}'.format(pcnt1, pcnt2))

    # wlog('\n' + '*' * 30 + ' trainable parameters ' + '*' * 30)
    # for n, p in nmtModel.named_parameters():
    #     if p.requires_grad: wlog('{:60} : {}'.format(n, p.size()))
    opt_state = None
    if wargs.pre_train:
        opt_state = optim.optimizer.state_dict()

    if wargs.use_reinfore_ce is False:
        criterion = LabelSmoothingCriterion(
            trg_emb.n_vocab, label_smoothing=wargs.label_smoothing)
    else:
        word2vec = tc.load(wargs.word2vec_weight)['w2v']
        # criterion = Word2VecDistanceCriterion(word2vec)
        criterion = CosineDistance(word2vec)

    if device_ids is not None:
        wlog('push criterion onto GPU {} ... '.format(device_ids[0]), 0)
        criterion = criterion.to(device)
        wlog('done.')
    # if wargs.reinfore_type == 0 or wargs.reinfore_type == 1:
    #     param = list(nmtModel.parameters())
    # else:
    #     param = list(nmtModel.parameters()) + list(criterion.parameters())
    param = list(nmtModel.parameters())
    optim.init_optimizer(param)

    lossCompute = MultiGPULossCompute(
        nmtModel.generator,
        criterion,
        wargs.d_model if wargs.decoder_type == 'att' else 2 * wargs.d_enc_hid,
        n_trg_vcb,
        trg_emb,
        nmtModel.bowMapper,
        loss_norm=wargs.loss_norm,
        chunk_size=wargs.chunk_size,
        device_ids=device_ids)

    trainer = Trainer(nmtModel_par, batch_train, vocabs, optim, lossCompute,
                      nmtModel, batch_valid, batch_tests, writer)

    trainer.train()
    writer.close()
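
The GPU branch above wraps the model in `nn.DataParallel` before moving it to the device, while the CPU branch keeps the bare model under the same variable. A minimal sketch of the same pattern with a toy module (device ids and shapes are illustrative, not the project's configuration):

import torch as tc
import torch.nn as nn

# wrap in DataParallel and move to CUDA when a GPU is available, otherwise stay on CPU
model = nn.Linear(4, 2)
if tc.cuda.is_available():
    device_ids = [0]
    model_par = nn.DataParallel(model, device_ids=device_ids)
    model_par.to(tc.device('cuda'))
else:
    model_par = model.to(tc.device('cpu'))

out = model_par(tc.randn(3, 4).to(next(model.parameters()).device))
print(out.shape)
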
Example 6
def main():

    #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    vocabs = {}
    wlog('\n[o/Subword] Preparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src,
                              wargs.src_dict,
                              wargs.src_dict_size,
                              wargs.max_seq_len,
                              char=wargs.src_char)
    wlog('\n[o/Subword] Preparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size,
                              wargs.max_seq_len)
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data,
                                               wargs.train_prefix,
                                               wargs.train_src_suffix,
                                               wargs.train_trg_suffix,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len,
                                               char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_sort=True)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_src_suffix)
        val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_ref_suffix)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            sort_data=False,
            max_seq_len=wargs.dev_max_seq_len,
            char=wargs.src_char)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            1,
                            volatile=True,
                            batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file,
                                             src_vocab,
                                             char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        1,
                                        volatile=True,
                                        batch_sort=False)
    wlog('\n## Finished preparing the dataset ##\n')

    nmtModel = NMT(src_vocab_size, trg_vocab_size)

    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train)

        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu,
                      model=wargs.model)

    if wargs.gpu_id is not None:
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        nmtModel.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()

    wlog('done.')

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid,
                      batch_tests)

    trainer.train()
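
The pre-training branch above copies matching entries from the checkpoint into the model, optionally freezing them via `fix_pre_params`, and falls back to `init_params` for everything else. Here is a small illustrative sketch of that loop with a toy module and a pretend checkpoint (all names hypothetical, not the project's loader):

import torch as tc
import torch.nn as nn

# copy matching parameters from a (pretend) checkpoint, freeze them, init the rest
model = nn.Linear(4, 4)
checkpoint = {'weight': tc.eye(4)}      # hypothetical state dict, only covers `weight`
fix_pre_params = True

for name, param in model.named_parameters():
    if name in checkpoint:
        param.requires_grad = not fix_pre_params
        param.data.copy_(checkpoint[name])
    else:
        nn.init.uniform_(param, -0.1, 0.1)   # stand-in for the project's init_params(...)
print([(n, p.requires_grad) for n, p in model.named_parameters()])
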
Example 7
def main():

    # Check if CUDA is available
    if cuda.is_available():
        wlog('CUDA is available, specify device by gpu_id argument (e.g. gpu_id=[3])')
    else:
        wlog('Warning: CUDA is not available, try CPU')

    if wargs.gpu_id:
        cuda.set_device(wargs.gpu_id[0])
        wlog('Using GPU {}'.format(wargs.gpu_id[0]))

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)
    init_dir(wargs.dir_tests)
    for prefix in wargs.tests_prefix:
        if not prefix == wargs.val_prefix: init_dir(wargs.dir_tests + '/' + prefix)

    wlog('Preparing data ... ', 0)

    train_srcD_file = wargs.dir_data + 'train.10k.zh5'
    wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size)

    train_trgD_file = wargs.dir_data + 'train.10k.en5'
    wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size)

    train_src_file = wargs.dir_data + 'train.10k.zh0'
    train_trg_file = wargs.dir_data + 'train.10k.en0'
    wlog('\nPreparing training set from {} and {} ... '.format(train_src_file, train_trg_file))
    train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab)
    #list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...], no padding
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size))
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)

    tests_data = None
    if wargs.tests_prefix is not None:
        tests_data = {}
        for prefix in wargs.tests_prefix:
            test_file = wargs.val_tst_dir + prefix + '.src'
            test_src_tlst, _ = val_wrap_data(test_file, src_vocab)
            # the best model is selected on the nist03 test set
            if prefix == wargs.val_prefix:
                wlog('\nPreparing model-select set from {} ... '.format(test_file))
                batch_valid = Input(test_src_tlst, None, 1, volatile=True, prefix=prefix)
            else:
                wlog('\nPreparing test set from {} ... '.format(test_file))
                tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True)

    nmtModel = NMT()
    classifier = Classifier(wargs.out_size, trg_vocab_size)

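    # Warm-start from a pre-trained checkpoint if given; otherwise initialize
    # parameters uniformly and build a fresh optimizer.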
    if wargs.pre_train:

        model_dict, class_dict, eid, bid, optim = load_pytorch_model(wargs.pre_train)
        if isinstance(optim, list): _, _, optim = optim
        # initializing parameters of interactive attention model
        for p in nmtModel.named_parameters(): p[1].data = model_dict[p[0]]
        for p in classifier.named_parameters(): p[1].data = class_dict[p[0]]
        #wargs.start_epoch = eid + 1
    else:

        for p in nmtModel.parameters(): init_params(p, uniform=True)
        for p in classifier.parameters(): init_params(p, uniform=True)
        optim = Optim(
            wargs.opt_mode, wargs.learning_rate, wargs.max_grad_norm,
            learning_rate_decay=wargs.learning_rate_decay,
            start_decay_from=wargs.start_decay_from,
            last_valid_bleu=wargs.last_valid_bleu
        )

    if wargs.gpu_id:
        wlog('Push model onto GPU ... ')
        nmtModel.cuda()
        classifier.cuda()
    else:
        wlog('Push model onto CPU ... ')
        nmtModel.cpu()
        classifier.cpu()

    nmtModel.classifier = classifier
    wlog(nmtModel)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Number of parameters: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    #tor = Translator(nmtModel, src_vocab.idx2key, trg_vocab.idx2key)
    #tor.trans_tests(tests_data, pre_dict['epoch'], pre_dict['batch'])

    trainer = Trainer(nmtModel, src_vocab.idx2key, trg_vocab.idx2key, optim, trg_vocab_size)

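    # Load the five dev.1k.* chunks; they are combined into a single tuning set below.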
    dev_src0 = wargs.dir_data + 'dev.1k.zh0'
    dev_trg0 = wargs.dir_data + 'dev.1k.en0'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src0, dev_trg0))
    dev_src0, dev_trg0 = wrap_data(dev_src0, dev_trg0, src_vocab, trg_vocab)
    wlog(len(train_src_tlst))
    # keep the full training data as history; the 1k dev chunks below are added for tuning
    train_all_chunks = (train_src_tlst, train_trg_tlst)
    dh = DataHisto(train_all_chunks)

    dev_src1 = wargs.dir_data + 'dev.1k.zh1'
    dev_trg1 = wargs.dir_data + 'dev.1k.en1'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src1, dev_trg1))
    dev_src1, dev_trg1 = wrap_data(dev_src1, dev_trg1, src_vocab, trg_vocab)

    dev_src2 = wargs.dir_data + 'dev.1k.zh2'
    dev_trg2 = wargs.dir_data + 'dev.1k.en2'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src2, dev_trg2))
    dev_src2, dev_trg2 = wrap_data(dev_src2, dev_trg2, src_vocab, trg_vocab)

    dev_src3 = wargs.dir_data + 'dev.1k.zh3'
    dev_trg3 = wargs.dir_data + 'dev.1k.en3'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src3, dev_trg3))
    dev_src3, dev_trg3 = wrap_data(dev_src3, dev_trg3, src_vocab, trg_vocab)

    dev_src4 = wargs.dir_data + 'dev.1k.zh4'
    dev_trg4 = wargs.dir_data + 'dev.1k.en4'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src4, dev_trg4))
    dev_src4, dev_trg4 = wrap_data(dev_src4, dev_trg4, src_vocab, trg_vocab)
    wlog(len(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0))
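    # combine all five dev chunks into one Input and tune on them together with the training history (merge=True)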
    dev_input = Input(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0, dev_trg4+dev_trg3+dev_trg2+dev_trg1+dev_trg0, wargs.batch_size)
    trainer.train(dh, dev_input, 0, batch_valid, tests_data, merge=True, name='DH_{}'.format('dev'))

    '''
    chunk_size = 1000
    rand_ids = tc.randperm(len(train_src_tlst))[:chunk_size * 1000]
    rand_ids = rand_ids.split(chunk_size)
    #train_chunks = [(dev_src, dev_trg)]
    train_chunks = []
    for k in range(len(rand_ids)):
        rand_id = rand_ids[k]
        chunk_src_tlst = [train_src_tlst[i] for i in rand_id]
        chunk_trg_tlst = [train_trg_tlst[i] for i in rand_id]
        #wlog('Sentence-pairs count in training data: {}'.format(len(src_samples_train)))
        #batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
        #batch_train = Input(src_samples_train, trg_samples_train, wargs.batch_size)
        train_chunks.append((chunk_src_tlst, chunk_trg_tlst))

    chunk_D0 = train_chunks[0]
    dh = DataHisto(chunk_D0)
    c0_input = Input(chunk_D0[0], chunk_D0[1], wargs.batch_size)
    trainer.train(dh, c0_input, 0, batch_valid, tests_data, merge=False, name='DH_{}'.format(0))
    for k in range(1, len(train_chunks)):
        wlog('*' * 30, False)
        wlog(' Next Data {} '.format(k), False)
        wlog('*' * 30)
        chunk_Dk = train_chunks[k]
        ck_input = Input(chunk_Dk[0], chunk_Dk[1], wargs.batch_size)
        trainer.train(dh, ck_input, k, batch_valid, tests_data, merge=True, name='DH_{}'.format(k))
        dh.add_batch_data(chunk_Dk)
    '''

    if tests_data and wargs.final_test:

        bestModel = NMT()
        classifier = Classifier(wargs.out_size, trg_vocab_size)

        assert os.path.exists(wargs.best_model)
        model_dict = tc.load(wargs.best_model)

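        # strip classifier weights from the saved model state; the classifier is restored separately below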
        best_model_dict = model_dict['model']
        best_model_dict = {k: v for k, v in best_model_dict.items() if 'classifier' not in k}

        bestModel.load_state_dict(best_model_dict)
        classifier.load_state_dict(model_dict['class'])

        if wargs.gpu_id:
            wlog('Push NMT model onto GPU ... ')
            bestModel.cuda()
            classifier.cuda()
        else:
            wlog('Push NMT model onto CPU ... ')
            bestModel.cpu()
            classifier.cpu()

        bestModel.classifier = classifier

        tor = Translator(bestModel, src_vocab.idx2key, trg_vocab.idx2key)
        tor.trans_tests(tests_data, model_dict['epoch'], model_dict['batch'])
Example no. 8
0
def main():

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    vocab_data = {}
    train_srcD_file = wargs.src_vocab_from
    wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict,
                              wargs.src_dict_size)
    vocab_data['src'] = src_vocab

    train_trgD_file = wargs.trg_vocab_from
    wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict,
                              wargs.trg_dict_size)
    vocab_data['trg'] = trg_vocab

    train_src_file = wargs.train_src
    train_trg_file = wargs.train_trg
    wlog('\nPreparing training set from {} and {} ... '.format(
        train_src_file, train_trg_file))
    train_src_tlst, train_trg_tlst = wrap_data(train_src_file,
                                               train_trg_file,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    '''
    devs = {}
    dev_src = wargs.val_tst_dir + wargs.val_prefix + '.src'
    dev_trg = wargs.val_tst_dir + wargs.val_prefix + '.ref0'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src, dev_trg))
    dev_src, dev_trg = wrap_data(dev_src, dev_trg, src_vocab, trg_vocab)
    devs['src'], devs['trg'] = dev_src, dev_trg
    '''

    valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                  wargs.val_src_suffix)
    wlog('\nPreparing validation set from {} ... '.format(valid_file))
    valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab)

    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))
    src_vocab_size, trg_vocab_size = vocab_data['src'].size(), vocab_data['trg'].size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))

    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    batch_valid = Input(valid_src_tlst, None, 1, volatile=True)

    tests_data = None
    if wargs.tests_prefix is not None:
        init_dir(wargs.dir_tests)
        tests_data = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('Preparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = val_wrap_data(test_file, src_vocab)
            tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True)
    '''
    # lookup_table on cpu to save memory
    src_lookup_table = nn.Embedding(wargs.src_dict_size + 4,
                                    wargs.src_wemb_size, padding_idx=utils.PAD).cpu()
    trg_lookup_table = nn.Embedding(wargs.trg_dict_size + 4,
                                    wargs.trg_wemb_size, padding_idx=utils.PAD).cpu()

    wlog('Lookup table on CPU ... ')
    wlog(src_lookup_table)
    wlog(trg_lookup_table)
    '''

    sv = vocab_data['src'].idx2key
    tv = vocab_data['trg'].idx2key

    nmtModel = NMT(src_vocab_size, trg_vocab_size)
    #classifier = Classifier(wargs.out_size, trg_vocab_size,
    #                        nmtModel.decoder.trg_lookup_table if wargs.copy_trg_emb is True else None)

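    # Either restore parameters from a pre-trained checkpoint (copying map_vocab
    # weights from the classifier state when available) or initialize them uniformly.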
    if wargs.pre_train:

        assert os.path.exists(wargs.pre_train)
        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4:
            model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

        #tor = Translator(nmtModel, sv, tv)
        #tor.trans_tests(tests_data, eid, bid)

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        #for n, p in classifier.named_parameters(): init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)

    if wargs.gpu_id:
        nmtModel.cuda()
        #classifier.cuda()
        wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0]))
    else:
        nmtModel.cpu()
        #classifier.cpu()
        wlog('Push model onto CPU ... ')

    #nmtModel.classifier = classifier
    #nmtModel.decoder.map_vocab = classifier.map_vocab
    '''
    nmtModel.src_lookup_table = src_lookup_table
    nmtModel.trg_lookup_table = trg_lookup_table
    print nmtModel.src_lookup_table.weight.data.is_cuda

    nmtModel.classifier.init_weights(nmtModel.trg_lookup_table)
    '''

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Number of parameters: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    #tor = Translator(nmtModel, sv, tv, wargs.search_mode)
    #tor.trans_tests(tests_data, pre_dict['epoch'], pre_dict['batch'])

    trainer = Trainer(nmtModel, batch_train, vocab_data, optim, batch_valid,
                      tests_data)

    trainer.train()
Example no. 9
0
def main():

    # Check if CUDA is available
    if cuda.is_available():
        wlog(
            'CUDA is available, specify the device with the gpu_id argument (e.g. gpu_id=[3])'
        )
    else:
        wlog('Warning: CUDA is not available, falling back to CPU')

    if wargs.gpu_id:
        cuda.set_device(wargs.gpu_id[0])
        wlog('Using GPU {}'.format(wargs.gpu_id[0]))

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)
    '''
    train_srcD_file = wargs.dir_data + 'train.10k.zh5'
    wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size)

    train_trgD_file = wargs.dir_data + 'train.10k.en5'
    wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size)

    train_src_file = wargs.dir_data + 'train.10k.zh0'
    train_trg_file = wargs.dir_data + 'train.10k.en0'
    wlog('\nPreparing training set from {} and {} ... '.format(train_src_file, train_trg_file))
    train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab)
    #list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...], no padding
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size))
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    '''

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    vocabs = {}
    wlog('\nPreparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src, wargs.src_dict, wargs.src_dict_size)
    wlog('\nPreparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size)
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data,
                                               wargs.train_prefix,
                                               wargs.train_src_suffix,
                                               wargs.train_trg_suffix,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_sort=True)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_src_suffix)
        val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_ref_suffix)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            sort_data=False,
            max_seq_len=wargs.dev_max_seq_len)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            1,
                            volatile=True,
                            batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be a list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file, src_vocab)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        1,
                                        volatile=True,
                                        batch_sort=False)
    wlog('\n## Finished preparing the dataset ##\n')

    nmtModel = NMT(src_vocab_size, trg_vocab_size)
    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train)

        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
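        # a 4-tuple checkpoint carries no separate classifier state; a 5-tuple also carries class_dict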
        class_dict = None
        if len(_dict) == 4:
            model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)
        optim.init_optimizer(nmtModel.parameters())

    if wargs.gpu_id:
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        nmtModel.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()

    wlog('done.')
    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Number of parameters: {}/{}'.format(pcnt1, pcnt2))

    trainer = Trainer(nmtModel,
                      src_vocab.idx2key,
                      trg_vocab.idx2key,
                      optim,
                      trg_vocab_size,
                      valid_data=batch_valid,
                      tests_data=batch_tests)

    # keep the full training data as history for tuning
    train_all_chunks = (train_src_tlst, train_trg_tlst)
    dh = DataHisto(train_all_chunks)
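    # the history is passed to the trainer together with the dev batches so they are merged during tuning (merge=True below)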
    '''
    dev_src0 = wargs.dir_data + 'dev.1k.zh0'
    dev_trg0 = wargs.dir_data + 'dev.1k.en0'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src0, dev_trg0))
    dev_src0, dev_trg0 = wrap_data(dev_src0, dev_trg0, src_vocab, trg_vocab)
    wlog(len(train_src_tlst))

    dev_src1 = wargs.dir_data + 'dev.1k.zh1'
    dev_trg1 = wargs.dir_data + 'dev.1k.en1'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src1, dev_trg1))
    dev_src1, dev_trg1 = wrap_data(dev_src1, dev_trg1, src_vocab, trg_vocab)

    dev_src2 = wargs.dir_data + 'dev.1k.zh2'
    dev_trg2 = wargs.dir_data + 'dev.1k.en2'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src2, dev_trg2))
    dev_src2, dev_trg2 = wrap_data(dev_src2, dev_trg2, src_vocab, trg_vocab)

    dev_src3 = wargs.dir_data + 'dev.1k.zh3'
    dev_trg3 = wargs.dir_data + 'dev.1k.en3'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src3, dev_trg3))
    dev_src3, dev_trg3 = wrap_data(dev_src3, dev_trg3, src_vocab, trg_vocab)

    dev_src4 = wargs.dir_data + 'dev.1k.zh4'
    dev_trg4 = wargs.dir_data + 'dev.1k.en4'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src4, dev_trg4))
    dev_src4, dev_trg4 = wrap_data(dev_src4, dev_trg4, src_vocab, trg_vocab)
    wlog(len(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0))
    batch_dev = Input(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0, dev_trg4+dev_trg3+dev_trg2+dev_trg1+dev_trg0, wargs.batch_size)
    '''

    batch_dev = None
    assert wargs.dev_prefix is not None, 'A development set is required for tuning.'
    dev_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix,
                                    wargs.val_src_suffix)
    dev_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix,
                                    wargs.val_ref_suffix)
    wlog('\nPreparing dev set from {} and {} ... '.format(
        dev_src_file, dev_trg_file))
    valid_src_tlst, valid_trg_tlst = wrap_data(
        wargs.val_tst_dir,
        wargs.dev_prefix,
        wargs.val_src_suffix,
        wargs.val_ref_suffix,
        src_vocab,
        trg_vocab,
        shuffle=True,
        sort_data=True,
        max_seq_len=wargs.dev_max_seq_len)
    batch_dev = Input(valid_src_tlst,
                      valid_trg_tlst,
                      wargs.batch_size,
                      batch_sort=True)

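    # tune on the dev batches, merging them with the training history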
    trainer.train(dh, batch_dev, 0, merge=True, name='DH_{}'.format('dev'))
    '''