Example #1
    def train(self):

        wlog('Start training ... ')
        assert wargs.sample_size < wargs.batch_size, 'sample size must be smaller than batch size'
        # [low, high)
        batch_count = len(self.train_data) if wargs.fine_tune is False else 0
        batch_count_domain = len(self.train_data_domain)
        batch_start_sample = tc.randperm(batch_count + batch_count_domain)[0]
        wlog('Randomly select {} samples in the {}th/{} batch'.format(
            wargs.sample_size, batch_start_sample,
            batch_count + batch_count_domain))
        bidx, eval_cnt = 0, [0]
        wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))

        train_start = time.time()
        wlog('')
        wlog('#' * 120)
        wlog('#' * 30, False)
        wlog(' Start Training ', False)
        wlog('#' * 30)
        wlog('#' * 120)
        # DANN: NLL loss for the domain classifier
        loss_domain = tc.nn.NLLLoss()

        for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

            epoch_start = time.time()

            # train for one epoch on the training data
            wlog('')
            wlog('$' * 30, False)
            wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs), False)
            wlog('$' * 30)

            if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch:
                self.train_data.shuffle()
                self.train_data_domain.shuffle()
            # shuffle the original batch
            shuffled_batch_idx = tc.randperm(batch_count + batch_count_domain)

            sample_size = wargs.sample_size
            epoch_loss, epoch_trg_words, epoch_num_correct = 0, 0, 0
            show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
            domain_loss = 0
            sample_spend, eval_spend, epoch_bidx = 0, 0, 0
            show_start = time.time()

            for name, par in self.model.named_parameters():
                if name.split('.')[0] != "domain_discriminator":
                    par.requires_grad = True
                else:
                    par.requires_grad = False

            for k in range(batch_count + batch_count_domain):

                bidx += 1
                epoch_bidx = k + 1
                batch_idx = shuffled_batch_idx[
                    k] if epoch >= wargs.epoch_shuffle_minibatch else k
                #p = float(k + epoch * (batch_count+batch_count_domain)) / wargs.max_epochs / (batch_count+batch_count_domain)
                #alpha = 2. / (1. + np.exp(-10 * p)) - 1
                if batch_idx < batch_count:

                    # (max_slen_batch, batch_size)
                    _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[
                        batch_idx]

                    self.model.zero_grad()
                    # (max_tlen_batch - 1, batch_size, out_size)
                    # forward pass: DANN on an out-of-domain batch
                    decoder_outputs, domain_outputs = self.model(
                        srcs,
                        trgs[:-1],
                        srcs_m,
                        trgs_m[:-1],
                        'OUT',
                        alpha=wargs.alpha)
                    if len(decoder_outputs) == 2:
                        (decoder_outputs, _checks) = decoder_outputs
                    this_bnum = decoder_outputs.size(1)

                    #batch_loss, grad_output, batch_correct_num = memory_efficient(
                    #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
                    # backward pass: snip_back_prop computes the gradients
                    batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                        decoder_outputs, trgs[1:], trgs_m[1:], wargs.snip_size)

                    #self.model.zero_grad()
                    #domain_outputs = self.model(srcs, None, srcs_m, None, 'OUT', alpha=alpha)
                    #if wargs.cross_entropy is True:
                    #   domain_label = tc.zeros(len(domain_outputs))
                    #  domain_label = domain_label.long()
                    # domain_label = domain_label.cuda()
                    #domainv_label = Variable(domain_label, volatile=False)

                    #    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                    #   domain_loss_a.backward(retain_graph = True if wargs.max_entropy is True else False)
                    if wargs.max_entropy is True:
                        # maximize the entropy of the domain classifier output
                        lam = wargs.alpha if epoch > 1 else 0
                        domain_loss_b = lam * tc.dot(tc.log(domain_outputs),
                                                     domain_outputs)
                        domain_loss_b.backward()

                    batch_src_words = srcs.data.ne(PAD).sum()
                    assert batch_src_words == slens.data.sum()
                    batch_trg_words = trgs[1:].data.ne(PAD).sum()

                elif batch_idx >= batch_count:

                    #domain_loss_out.backward(retain_graph=True)
                    # forward pass: DANN on an in-domain batch
                    _, srcs_domain, trgs_domain, slens_domain, srcs_m_domain, trgs_m_domain = self.train_data_domain[
                        batch_idx - batch_count]
                    self.model.zero_grad()
                    decoder_outputs, domain_outputs = self.model(
                        srcs_domain,
                        trgs_domain[:-1],
                        srcs_m_domain,
                        trgs_m_domain[:-1],
                        'IN',
                        alpha=wargs.alpha)
                    if len(decoder_outputs) == 2:
                        (decoder_outputs, _checks) = decoder_outputs
                    this_bnum = decoder_outputs.size(1)
                    batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                        decoder_outputs, trgs_domain[1:], trgs_m_domain[1:],
                        wargs.snip_size)
                    #domain_outputs = self.model(srcs_domain, None, srcs_m_domain, None, 'IN', alpha=alpha)
                    #if wargs.cross_entropy is True:
                    #   domain_label = tc.ones(len(domain_outputs))
                    #  domain_label = domain_label.long()
                    # domain_label = domain_label.cuda()
                    #domainv_label = Variable(domain_label, volatile=False)

                    #    domain_loss_a = loss_domain(tc.log(domain_outputs), domainv_label)
                    #   domain_loss_a.backward(retain_graph = True if wargs.max_entropy is True else False)
                    if wargs.max_entropy is True:
                        lam = wargs.alpha if epoch > 1 else 0
                        domain_loss_b = lam * tc.dot(tc.log(domain_outputs),
                                                     domain_outputs)
                        domain_loss_b.backward()

                    batch_src_words = srcs_domain.data.ne(PAD).sum()
                    assert batch_src_words == slens_domain.data.sum()
                    batch_trg_words = trgs_domain[1:].data.ne(PAD).sum()

                _grad_nan = False
                for n, p in self.model.named_parameters():
                    if p.grad is None:
                        debug('grad None | {}'.format(n))
                        continue
                    tmp_grad = p.grad.data.cpu().numpy()
                    if numpy.isnan(tmp_grad).any():  # check gradients for NaN values here
                        wlog("grad contains 'nan' | {}".format(n))
                        #wlog("gradient\n{}".format(tmp_grad))
                        _grad_nan = True
                    if n == 'decoder.l_f1_0.weight' or n == 's_init.weight' or n=='decoder.l_f1_1.weight' \
                       or n == 'decoder.l_conv.0.weight' or n == 'decoder.l_f2.weight':
                        debug('grad zeros |{:5} {}'.format(
                            str(not np.any(tmp_grad)), n))

                if _grad_nan is True and wargs.dynamic_cyk_decoding is True:
                    for _i, items in enumerate(_checks):
                        wlog('step {} Variable----------------:'.format(_i))
                        #for item in items: wlog(item.cpu().data.numpy())
                        wlog('wen _check_tanh_sa ---------------')
                        wlog(items[0].cpu().data.numpy())
                        wlog('wen _check_a1_weight ---------------')
                        wlog(items[1].cpu().data.numpy())
                        wlog('wen _check_a1 ---------------')
                        wlog(items[2].cpu().data.numpy())
                        wlog('wen alpha_ij---------------')
                        wlog(items[3].cpu().data.numpy())
                        wlog('wen before_mask---------------')
                        wlog(items[4].cpu().data.numpy())
                        wlog('wen after_mask---------------')
                        wlog(items[5].cpu().data.numpy())

                #outputs.backward(grad_output)
                self.optim.step()
                #del outputs, grad_output

                show_loss += batch_loss
                #domain_loss += domain_loss_a.data.clone()[0]
                #domain_loss += domain_loss_out
                show_correct_num += batch_correct_num
                epoch_loss += batch_loss
                epoch_num_correct += batch_correct_num
                show_src_words += batch_src_words
                show_trg_words += batch_trg_words
                epoch_trg_words += batch_trg_words

                batch_log_norm = tc.mean(tc.abs(batch_log_norm))

                if epoch_bidx % wargs.display_freq == 0:
                    #print show_correct_num, show_loss, show_trg_words, show_loss/show_trg_words
                    ud = time.time() - show_start - sample_spend - eval_spend
                    wlog(
                        'Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                        '| |logZ|:{:.2f} '
                        '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                        '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'
                        .format(epoch, wargs.max_epochs, epoch_bidx, batch_idx,
                                bidx / 1000,
                                (show_correct_num / show_trg_words) * 100,
                                math.exp(show_loss / show_trg_words),
                                batch_log_norm, batch_src_words, this_bnum,
                                int(batch_src_words / this_bnum),
                                int(batch_trg_words / this_bnum),
                                show_src_words / ud, show_trg_words / ud, ud,
                                (time.time() - train_start) / 60.))
                    show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                    sample_spend, eval_spend = 0, 0
                    show_start = time.time()

                if epoch_bidx % wargs.sampling_freq == 0:

                    sample_start = time.time()
                    self.model.eval()
                    #self.model.classifier.eval()
                    tor = Translator(self.model, self.sv, self.tv)

                    if batch_idx < batch_count:
                        # (max_len_batch, batch_size)
                        sample_src_tensor = srcs.t()[:sample_size]
                        sample_trg_tensor = trgs.t()[:sample_size]
                        tor.trans_samples(sample_src_tensor, sample_trg_tensor,
                                          "OUT")

                    elif batch_idx >= batch_count:
                        sample_src_tensor = srcs_domain.t()[:sample_size]
                        sample_trg_tensor = trgs_domain.t()[:sample_size]
                        tor.trans_samples(sample_src_tensor, sample_trg_tensor,
                                          "IN")
                    wlog('')
                    sample_spend = time.time() - sample_start
                    self.model.train()

                # Just watch the translation of some source sentences in training data
                if wargs.if_fixed_sampling and bidx == batch_start_sample:
                    # randomly select sample_size samples from the current batch
                    rand_rows = np.random.choice(this_bnum,
                                                 sample_size,
                                                 replace=False)
                    sample_src_tensor = tc.Tensor(sample_size,
                                                  srcs.size(0)).long()
                    sample_src_tensor.fill_(PAD)
                    sample_trg_tensor = tc.Tensor(sample_size,
                                                  trgs.size(0)).long()
                    sample_trg_tensor.fill_(PAD)

                    for id in xrange(sample_size):
                        sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                        sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

                if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                   bidx % wargs.eval_valid_freq == 0:
                    eval_start = time.time()
                    eval_cnt[0] += 1
                    wlog('\nWithin epoch, batch [{}], eval [{}], saving model ...'.
                         format(epoch_bidx, eval_cnt[0]))

                    self.mt_eval(epoch, epoch_bidx, "IN")

                    eval_spend = time.time() - eval_start

            for name, par in self.model.named_parameters():
                if name.split('.')[0] != "domain_discriminator":
                    par.requires_grad = False
                else:
                    par.requires_grad = True

            for k in range(batch_count + batch_count_domain):
                epoch_bidx = k + 1
                batch_idx = shuffled_batch_idx[
                    k] if epoch >= wargs.epoch_shuffle_minibatch else k

                if batch_idx < batch_count:

                    _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[
                        batch_idx]

                    self.model.zero_grad()
                    # (max_tlen_batch - 1, batch_size, out_size)
                    # forward pass: DANN on an out-of-domain batch (discriminator update)
                    domain_outputs = self.model(srcs,
                                                trgs[:-1],
                                                srcs_m,
                                                trgs_m[:-1],
                                                'OUT',
                                                alpha=wargs.alpha,
                                                adv=True)
                    #if len(decoder_outputs) == 2: (decoder_outputs, _checks) = decoder_outputs
                    #this_bnum = decoder_outputs.size(1)

                    #batch_loss, grad_output, batch_correct_num = memory_efficient(
                    #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
                    #backward compute, now we have the grad
                    #batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                    #   decoder_outputs, trgs[1:], trgs_m[1:], wargs.snip_size)

                    #self.model.zero_grad()
                    #domain_outputs = self.model(srcs, None, srcs_m, None, 'OUT', alpha=alpha)
                    if wargs.cross_entropy is True:
                        domain_label = tc.zeros(len(domain_outputs))
                        domain_label = domain_label.long()
                        domain_label = domain_label.cuda()
                        domainv_label = Variable(domain_label, volatile=False)

                        domain_loss_a = loss_domain(tc.log(domain_outputs),
                                                    domainv_label)
                        domain_loss_a.backward()
                    #if wargs.max_entropy is True:
                    #try to max this entropy
                    #   domain_loss_b = -wargs.alpha*tc.dot(tc.log(domain_outputs), domain_outputs)
                    #  domain_loss_b.backward()

                    batch_src_words = srcs.data.ne(PAD).sum()
                    assert batch_src_words == slens.data.sum()
                    batch_trg_words = trgs[1:].data.ne(PAD).sum()

                elif batch_idx >= batch_count:

                    #domain_loss_out.backward(retain_graph=True)
                    # forward pass: DANN on an in-domain batch (discriminator update)
                    _, srcs_domain, trgs_domain, slens_domain, srcs_m_domain, trgs_m_domain = self.train_data_domain[
                        batch_idx - batch_count]
                    self.model.zero_grad()
                    domain_outputs = self.model(srcs_domain,
                                                trgs_domain[:-1],
                                                srcs_m_domain,
                                                trgs_m_domain[:-1],
                                                'IN',
                                                alpha=wargs.alpha,
                                                adv=True)
                    #if len(decoder_outputs) == 2: (decoder_outputs, _checks) = decoder_outputs
                    #this_bnum = decoder_outputs.size(1)
                    #batch_loss, batch_correct_num, batch_log_norm = self.model.classifier.snip_back_prop(
                    #  decoder_outputs, trgs_domain[1:], trgs_m_domain[1:], wargs.snip_size)
                    #domain_outputs = self.model(srcs_domain, None, srcs_m_domain, None, 'IN', alpha=alpha)
                    if wargs.cross_entropy is True:
                        domain_label = tc.ones(len(domain_outputs))
                        domain_label = domain_label.long()
                        domain_label = domain_label.cuda()
                        domainv_label = Variable(domain_label, volatile=False)

                        domain_loss_a = loss_domain(tc.log(domain_outputs),
                                                    domainv_label)
                        domain_loss_a.backward(
                            retain_graph=True
                            if wargs.max_entropy is True else False)
                    #if wargs.max_entropy is True:
                    #   domain_loss_b = -wargs.alpha*tc.dot(tc.log(domain_outputs), domain_outputs)
                    #  domain_loss_b.backward()

                    batch_src_words = srcs_domain.data.ne(PAD).sum()
                    assert batch_src_words == slens_domain.data.sum()
                    batch_trg_words = trgs_domain[1:].data.ne(PAD).sum()

                self.optim.step()

                show_loss += batch_loss
                #domain_loss += domain_loss_a.data.clone()[0]
                #domain_loss += domain_loss_out
                show_correct_num += batch_correct_num
                epoch_loss += batch_loss
                epoch_num_correct += batch_correct_num
                show_src_words += batch_src_words
                show_trg_words += batch_trg_words
                epoch_trg_words += batch_trg_words

                #batch_log_norm = tc.mean(tc.abs(batch_log_norm))

                if epoch_bidx % wargs.display_freq == 0:
                    #print show_correct_num, show_loss, show_trg_words, show_loss/show_trg_words
                    ud = time.time() - show_start - sample_spend - eval_spend
                    wlog(
                        'Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                        '| |logZ|:{:.2f} '
                        '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                        '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'
                        .format(epoch, wargs.max_epochs, epoch_bidx, batch_idx,
                                bidx / 1000,
                                (show_correct_num / show_trg_words) * 100,
                                math.exp(show_loss / show_trg_words), 0,
                                batch_src_words, this_bnum,
                                int(batch_src_words / this_bnum),
                                int(batch_trg_words / this_bnum),
                                show_src_words / ud, show_trg_words / ud, ud,
                                (time.time() - train_start) / 60.))
                    show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                    sample_spend, eval_spend = 0, 0
                    show_start = time.time()

                if epoch_bidx % wargs.sampling_freq == 0:

                    sample_start = time.time()
                    self.model.eval()
                    #self.model.classifier.eval()
                    tor = Translator(self.model, self.sv, self.tv)

                    if batch_idx < batch_count:
                        # (max_len_batch, batch_size)
                        sample_src_tensor = srcs.t()[:sample_size]
                        sample_trg_tensor = trgs.t()[:sample_size]
                        tor.trans_samples(sample_src_tensor, sample_trg_tensor,
                                          "OUT")

                    elif batch_idx >= batch_count:
                        sample_src_tensor = srcs_domain.t()[:sample_size]
                        sample_trg_tensor = trgs_domain.t()[:sample_size]
                        tor.trans_samples(sample_src_tensor, sample_trg_tensor,
                                          "IN")
                    wlog('')
                    sample_spend = time.time() - sample_start
                    self.model.train()

                # Just watch the translation of some source sentences in training data
                if wargs.if_fixed_sampling and bidx == batch_start_sample:
                    # randomly select sample_size samples from the current batch
                    rand_rows = np.random.choice(this_bnum,
                                                 sample_size,
                                                 replace=False)
                    sample_src_tensor = tc.Tensor(sample_size,
                                                  srcs.size(0)).long()
                    sample_src_tensor.fill_(PAD)
                    sample_trg_tensor = tc.Tensor(sample_size,
                                                  trgs.size(0)).long()
                    sample_trg_tensor.fill_(PAD)

                    for id in xrange(sample_size):
                        sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                        sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

            avg_epoch_loss = epoch_loss / epoch_trg_words
            avg_epoch_acc = epoch_num_correct / epoch_trg_words
            wlog('\nEnd epoch [{}]'.format(epoch))
            wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
            wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
            wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))
            #wlog('Epoch domain loss is {:4.2f}'.format(float(domain_loss)))

            wlog('End epoch, batch [{}], [{}] eval save model ...'.format(
                epoch_bidx, eval_cnt[0]))
            mteval_bleu = self.mt_eval(epoch, epoch_bidx, "IN")
            self.optim.update_learning_rate(mteval_bleu, epoch)

            epoch_time_consume = time.time() - epoch_start
            wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

        wlog('Train finished, consuming {:6.2f} hours'.format(
            (time.time() - train_start) / 3600))
        wlog('Congratulations!')
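
Note: the alpha argument threaded through the model calls above, together with the alternating freeze/unfreeze of the domain_discriminator parameters, follows the DANN (domain-adversarial training) recipe, whose central piece is a gradient reversal layer. That layer is not visible in the snippet; the sketch below is a minimal, generic version written against the current torch.autograd.Function API (not the legacy Variable API used in these examples) and is not taken from this codebase.

import torch

class GradReverse(torch.autograd.Function):
    # identity in the forward pass; gradient multiplied by -alpha in the backward pass
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # reverse (and scale) the gradient that flows back into the shared encoder
        return grad_output.neg() * ctx.alpha, None

def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)

# usage sketch (hypothetical names): domain_logits = domain_discriminator(grad_reverse(enc_states, alpha))
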
Example #2
        #        "问题 交换 了 意见 。"
        #t = "( beijing , syndicated news ) the sino - us relation that was heated momentarily " \
        #        "by the us president bush 's visit to china is cooling down rapidly . china " \
        #        "confirmed yesterday that it has called off its naval fleet visit to the us " \
        #        "ports this year and refused to confirm whether the country 's vice president " \
        #        "hu jintao will visit the united states as planned ."

        s = '当 林肯 去 新奥尔良 时 , 我 听到 密西 西比 河 的 歌声 。'
        t = "When Lincoln goes to New Orleans, I hear Mississippi river's singing sound"
        #s = '新奥尔良 是 爵士 音乐 的 发源 地 。'
        #s = '新奥尔良 以 其 美食 而 闻名 。'
        # = '休斯顿 是 仅 次于 新奥尔良 和 纽约 的 美国 第三 大 港 。'

        s = [[src_vocab.key2idx[x] if x in src_vocab.key2idx else UNK for x in s.split(' ')]]
        t = [[trg_vocab.key2idx[x] if x in trg_vocab.key2idx else UNK for x in t.split(' ')]]
        tor.trans_samples(s, t)
        sys.exit(0)

    input_file = '{}{}.{}'.format(wargs.val_tst_dir, args.input_file, wargs.val_src_suffix)
    input_abspath = os.path.realpath(input_file)
    wlog('Translating test file {} ... '.format(input_abspath))
    ref_file = '{}{}.{}'.format(wargs.val_tst_dir, args.input_file, wargs.val_ref_suffix)
    test_src_tlst, _ = wrap_tst_data(input_abspath, src_vocab, char=wargs.src_char)
    test_input_data = Input(test_src_tlst, None, batch_size=wargs.test_batch_size, batch_sort=False,
                           gpu_ids=args.gpu_ids)

    batch_tst_data = None
    if os.path.exists(ref_file) and False:
        wlog('Force decoding test file {} ... to get alignments'.format(input_file))
        wlog('\t\tRef file {}'.format(ref_file))
        from tools.inputs_handler import wrap_data
Example #3
    def train(self):

        wlog('Start training ... ')
        assert wargs.sample_size < wargs.batch_size, 'sample size must be smaller than batch size'
        # [low, high)
        batch_count = len(self.train_data)
        batch_start_sample = tc.randperm(batch_count)[-1]
        wlog('Randomly select {} samples in the {}th/{} batch'.format(
            wargs.sample_size, batch_start_sample, batch_count))
        bidx, eval_cnt = 0, [0]
        wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))
        tor_hook = Translator(self.model, self.sv, self.tv)

        train_start = time.time()
        wlog('\n' + '#' * 120 + '\n' + '#' * 30 + ' Start Training ' +
             '#' * 30 + '\n' + '#' * 120)

        batch_oracles, _checks = None, None
        for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

            epoch_start = time.time()

            # train for one epoch on the training data
            wlog('\n' + '$' * 30, 0)
            wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs) + '$' * 30)
            if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch:
                self.train_data.shuffle()
            # shuffle the original batch
            shuffled_batch_idx = tc.randperm(batch_count)

            sample_size = wargs.sample_size
            epoch_loss, epoch_trg_words, epoch_num_correct, \
                    epoch_batch_logZ, epoch_n_sents = 0, 0, 0, 0, 0
            show_loss, show_src_words, show_trg_words, show_correct_num, \
                    show_batch_logZ, show_n_sents = 0, 0, 0, 0, 0, 0
            sample_spend, eval_spend, epoch_bidx = 0, 0, 0
            show_start = time.time()

            for k in range(batch_count):

                bidx += 1
                epoch_bidx = k + 1
                batch_idx = shuffled_batch_idx[
                    k] if epoch >= wargs.epoch_shuffle_minibatch else k

                # (max_slen_batch, batch_size)
                #_, srcs, ttrgs_for_files, slens, srcs_m, trg_mask_for_files = self.train_data[batch_idx]
                _, srcs, spos, ttrgs_for_files, ttpos_for_files, slens, srcs_m, trg_mask_for_files = self.train_data[
                    batch_idx]
                trgs, tpos, trgs_m = ttrgs_for_files[0], ttpos_for_files[
                    0], trg_mask_for_files[0]
                #self.model.zero_grad()
                self.optim.zero_grad()
                # (max_tlen_batch - 1, batch_size, out_size)
                if wargs.model == 8:
                    gold, gold_mask = trgs[1:], trgs_m[1:]
                    # (L, B) -> (B, L)
                    T_srcs, T_spos, T_trgs, T_tpos = srcs.t(), spos.t(
                    ), trgs.t(), tpos.t()
                    src, trg = (T_srcs, T_spos), (T_trgs, T_tpos)
                    gold, gold_mask = gold.t(), gold_mask.t()
                    #gold, gold_mask = trg[0][:, 1:], trgs_m.t()[:, 1:]
                    if gold.is_contiguous() is False: gold = gold.contiguous()
                    if gold_mask.is_contiguous() is False:
                        gold_mask = gold_mask.contiguous()
                    outputs = self.model(src, trg)
                else:
                    gold, gold_mask = trgs[1:], trgs_m[1:]
                    outputs = self.model(srcs, trgs[:-1], srcs_m, trgs_m[:-1])
                    if len(outputs) == 2: (outputs, _checks) = outputs
                    if len(outputs) == 2: (outputs, attends) = outputs

                this_bnum = outputs.size(1)
                epoch_n_sents += this_bnum
                show_n_sents += this_bnum
                #batch_loss, batch_correct_num, batch_log_norm = self.classifier(outputs, trgs[1:], trgs_m[1:])
                #batch_loss.div(this_bnum).backward()
                #batch_loss = batch_loss.data[0]
                #batch_correct_num = batch_correct_num.data[0]
                batch_loss, batch_correct_num, batch_Z = self.classifier.snip_back_prop(
                    outputs, gold, gold_mask, wargs.snip_size)

                self.optim.step()
                grad_checker(self.model, _checks)

                batch_src_words = srcs.data.ne(PAD).sum()
                assert batch_src_words == slens.data.sum()
                batch_trg_words = trgs[1:].data.ne(PAD).sum()

                show_loss += batch_loss
                show_correct_num += batch_correct_num
                epoch_loss += batch_loss
                epoch_num_correct += batch_correct_num
                show_src_words += batch_src_words
                show_trg_words += batch_trg_words
                epoch_trg_words += batch_trg_words

                show_batch_logZ += batch_Z
                epoch_batch_logZ += batch_Z

                if epoch_bidx % wargs.display_freq == 0:
                    #print show_correct_num, show_loss, show_trg_words, show_loss/show_trg_words
                    ud = time.time() - show_start - sample_spend - eval_spend
                    wlog(
                        'Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                        '||w-logZ|:{:.2f} ||s-logZ|:{:.2f} '
                        '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                        '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'
                        .format(epoch, wargs.max_epochs, epoch_bidx, batch_idx,
                                bidx / 1000,
                                (show_correct_num / show_trg_words) * 100,
                                math.exp(show_loss / show_trg_words),
                                show_batch_logZ / show_trg_words,
                                show_batch_logZ / show_n_sents,
                                batch_src_words, this_bnum,
                                int(batch_src_words / this_bnum),
                                int(batch_trg_words / this_bnum),
                                show_src_words / ud, show_trg_words / ud, ud,
                                (time.time() - train_start) / 60.))
                    show_loss, show_src_words, show_trg_words, show_correct_num, \
                            show_batch_logZ, show_n_sents = 0, 0, 0, 0, 0, 0
                    sample_spend, eval_spend = 0, 0
                    show_start = time.time()

                if epoch_bidx % wargs.sampling_freq == 0:

                    sample_start = time.time()
                    self.model.eval()
                    # (max_len_batch, batch_size)
                    sample_src_tensor = srcs.t()[:sample_size]
                    sample_trg_tensor = trgs.t()[:sample_size]
                    sample_src_tensor_pos = T_spos[:
                                                   sample_size] if wargs.model == 8 else None
                    tor_hook.trans_samples(sample_src_tensor,
                                           sample_trg_tensor,
                                           sample_src_tensor_pos)
                    wlog('')
                    sample_spend = time.time() - sample_start
                    self.model.train()

                # Just watch the translation of some source sentences in training data
                if wargs.if_fixed_sampling and bidx == batch_start_sample:
                    # randomly select sample_size samples from the current batch
                    rand_rows = np.random.choice(this_bnum,
                                                 sample_size,
                                                 replace=False)
                    sample_src_tensor = tc.Tensor(sample_size,
                                                  srcs.size(0)).long()
                    sample_src_tensor.fill_(PAD)
                    sample_trg_tensor = tc.Tensor(sample_size,
                                                  trgs.size(0)).long()
                    sample_trg_tensor.fill_(PAD)

                    for id in xrange(sample_size):
                        sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                        sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

                if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                   bidx % wargs.eval_valid_freq == 0:

                    eval_start = time.time()
                    eval_cnt[0] += 1
                    wlog('\nWithin epoch, batch [{}], eval [{}], saving model ...'.
                         format(epoch_bidx, eval_cnt[0]))

                    self.mt_eval(epoch, epoch_bidx)

                    eval_spend = time.time() - eval_start

            avg_epoch_loss = epoch_loss / epoch_trg_words
            avg_epoch_acc = epoch_num_correct / epoch_trg_words
            wlog('\nEnd epoch [{}]'.format(epoch))
            wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
            wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
            wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))
            wlog('Train average |w-logZ|: {}/{}={} |s-logZ|: {}/{}={}'.format(
                epoch_batch_logZ, epoch_trg_words,
                epoch_batch_logZ / epoch_trg_words, epoch_batch_logZ,
                epoch_n_sents, epoch_batch_logZ / epoch_n_sents))
            wlog('End epoch, batch [{}], [{}] eval save model ...'.format(
                epoch_bidx, eval_cnt[0]))

            mteval_bleu = self.mt_eval(epoch, epoch_bidx)
            self.optim.update_learning_rate(mteval_bleu, epoch)
            epoch_time_consume = time.time() - epoch_start
            wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

        wlog('Finished training, consuming {:6.2f} hours'.format(
            (time.time() - train_start) / 3600))
        wlog('Congratulations!')
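
Note: snip_back_prop is not defined in these snippets. The name suggests the shard-wise ("snip") loss computation commonly used to bound peak memory when the target vocabulary is large; the sketch below is only a rough illustration of that pattern under explicit assumptions: classifier is assumed to map decoder states to vocabulary logits, PAD is assumed to be index 0, and the log-normalizer term that the real method also returns is omitted. It is written against a recent PyTorch API, unlike the legacy code above.

import torch
import torch.nn.functional as F

PAD = 0  # assumption for this sketch only

def sharded_nll_backward(dec_states, classifier, gold, gold_mask, shard_size):
    # detach so each shard's backward stops at the decoder states
    detached = dec_states.detach().requires_grad_(True)
    total_nll, total_correct = 0.0, 0
    for states, g, m in zip(detached.split(shard_size, dim=0),
                            gold.split(shard_size, dim=0),
                            gold_mask.split(shard_size, dim=0)):
        log_probs = F.log_softmax(classifier(states), dim=-1)
        nll = F.nll_loss(log_probs.reshape(-1, log_probs.size(-1)),
                         g.reshape(-1), ignore_index=PAD, reduction='sum')
        nll.backward()                      # accumulates into detached.grad shard by shard
        total_nll += nll.item()
        total_correct += ((log_probs.argmax(dim=-1) == g) & m.bool()).sum().item()
    # push the accumulated gradient back through the real decoder graph in one go
    dec_states.backward(detached.grad)
    return total_nll, total_correct
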
Example #4
class Trainer(object):
    def __init__(self,
                 model_par,
                 train_data,
                 vocab_data,
                 optim,
                 lossCompute,
                 model,
                 valid_data=None,
                 tests_data=None):

        self.model_par, self.model, self.lossCompute, self.optim = model_par, model, lossCompute, optim
        self.sv, self.tv = vocab_data['src'].idx2key, vocab_data['trg'].idx2key
        self.train_data, self.valid_data, self.tests_data = train_data, valid_data, tests_data
        self.max_epochs, self.start_epoch = wargs.max_epochs, wargs.start_epoch

        self.n_look = wargs.n_look
        assert self.n_look <= wargs.batch_size, 'eyeball count > batch size'
        self.n_batches = len(train_data)  # [low, high)

        self.look_xs, self.look_ys = None, None
        if wargs.fix_looking is True:
            rand_idxs = random.sample(range(train_data.n_sent), self.n_look)
            wlog(
                'randomly look at {} samples from the whole training data'.format(
                    self.n_look))
            self.look_xs = [train_data.x_list[i][0] for i in rand_idxs]
            self.look_ys = [train_data.y_list_files[i][0] for i in rand_idxs]
        self.tor = Translator(model, self.sv, self.tv, gpu_ids=wargs.gpu_ids)
        self.n_eval = 1

        self.grad_accum_count = wargs.grad_accum_count

        self.epoch_shuffle_train = wargs.epoch_shuffle_train
        self.epoch_shuffle_batch = wargs.epoch_shuffle_batch
        self.ss_cur_prob = wargs.ss_prob_begin
        if wargs.ss_type is not None:
            wlog(
                'word-level optimizing bias between training and decoding ...')
            if wargs.bleu_sampling is True:
                wlog('sentence-level optimizing ...')
            wlog('schedule sampling value {}'.format(self.ss_cur_prob))
            if self.ss_cur_prob < 1. and wargs.bleu_sampling is True:
                self.sampler = Nbs(self.model,
                                   self.tv,
                                   k=3,
                                   noise=wargs.bleu_gumbel_noise,
                                   batch_sample=True)
        if self.grad_accum_count > 1:
            assert (
                wargs.chunk_size == 0
            ), 'to accumulate grads, disable target sequence truncating'

    def accum_matrics(self, batch_size, xtoks, ytoks, nll, ok_ytoks, bow_loss):

        self.look_sents += batch_size
        self.e_sents += batch_size
        self.look_nll += nll
        self.look_bow_loss += bow_loss
        self.look_ok_ytoks += ok_ytoks
        self.e_nll += nll
        self.e_ok_ytoks += ok_ytoks
        self.look_xtoks += xtoks
        self.look_ytoks += ytoks
        self.e_ytoks += ytoks

    def grad_accumulate(self, real_batches, e_idx, n_upds):

        #if self.grad_accum_count > 1:
        #    self.model_par.zero_grad()

        for batch in real_batches:

            # (batch_size, max_slen_batch)
            _, xs, y_for_files, bows, x_lens, xs_mask, y_mask_for_files, bows_mask = batch
            _batch_size = xs.size(0)
            ys, ys_mask = y_for_files[0], y_mask_for_files[0]
            #wlog('x: {}, x_mask: {}, y: {}, y_mask: {}'.format(
            #    xs.size(), xs_mask.size(), ys.size(), ys_mask.size()))
            if bows is not None:
                bows, bows_mask = bows[0], bows_mask[0]
                #wlog('bows: {}, bows_mask: {})'.format(bows.size(), bows_mask.size()))
            _xtoks = xs.data.ne(PAD).sum().item()
            assert _xtoks == x_lens.data.sum().item()
            _ytoks = ys[:, 1:].data.ne(PAD).sum().item()

            #if self.grad_accum_count == 1: self.model_par.zero_grad()
            # exclude last target word from inputs
            results = self.model_par(xs, ys[:, :-1], xs_mask, ys_mask[:, :-1],
                                     self.ss_cur_prob)
            logits, alphas, contexts = results['logit'], results[
                'attend'], results['context']
            # (batch_size, y_Lm1, out_size)

            gold, gold_mask = ys[:, 1:].contiguous(), ys_mask[:,
                                                              1:].contiguous()
            # 3. Compute loss in shards for memory efficiency.
            _nll, _ok_ytoks, _bow_loss = self.lossCompute(
                logits, e_idx, n_upds, gold, gold_mask, None, bows, bows_mask,
                contexts)

            self.accum_matrics(_batch_size, _xtoks, _ytoks, _nll, _ok_ytoks,
                               _bow_loss)
        # 4. Update the parameters and statistics.
        self.optim.step()
        self.optim.optimizer.zero_grad()
        #tc.cuda.empty_cache()

    def look_samples(self, n_steps):

        if n_steps % wargs.look_freq == 0:

            look_start = time.time()
            self.model_par.eval()  # eval mode disables dropout !!!
            self.model.eval()
            if self.look_xs is not None and self.look_ys is not None:
                _xs, _ys = self.look_xs, self.look_ys
            else:
                rand_idxs = random.sample(range(self.train_data.n_sent),
                                          self.n_look)
                wlog('randomly look at {} samples from the whole training data'.
                     format(self.n_look))
                _xs = [self.train_data.x_list[i][0] for i in rand_idxs]
                _ys = [self.train_data.y_list_files[i][0] for i in rand_idxs]
            self.tor.trans_samples(_xs, _ys)
            wlog('')
            self.look_spend = time.time() - look_start
            self.model_par.train()
            self.model.train()

    def try_valid(self, e_idx, e_bidx, n_steps):

        if n_steps > 150000: wargs.eval_valid_freq = 1000
        if wargs.epoch_eval is not True and n_steps > wargs.eval_valid_from and \
           n_steps % wargs.eval_valid_freq == 0:
            eval_start = time.time()
            wlog('\nWithin epoch, e_batch:{}, n_steps:{}, {}-th validation ...'.
                 format(e_bidx, n_steps, self.n_eval))
            self.mt_eval(e_idx, e_bidx, n_steps)
            self.eval_spend = time.time() - eval_start

    def mt_eval(self, e_idx, e_bidx, n_steps):

        state_dict = {
            'model': self.model.state_dict(),
            'epoch': e_idx,
            'batch': e_bidx,
            'steps': n_steps,
            'optim': self.optim
        }

        if wargs.save_one_model:
            model_file = '{}.pt'.format(wargs.model_prefix)
        else:
            model_file = '{}_e{}_upd{}.pt'.format(wargs.model_prefix, e_idx,
                                                  n_steps)
        tc.save(state_dict, model_file)
        wlog('Saving temporary model in {}'.format(model_file))

        self.model_par.eval()
        self.model.eval()
        self.tor.trans_eval(self.valid_data, e_idx, e_bidx, n_steps,
                            model_file, self.tests_data)
        self.model_par.train()
        self.model.train()
        self.n_eval += 1

    def train(self):

        wlog('start training ... ')
        train_start = time.time()
        wlog('\n' + '#' * 120 + '\n' + '#' * 30 + ' Start Training ' +
             '#' * 30 + '\n' + '#' * 120)
        batch_oracles, _checks, accum_batches, real_batches = None, None, 0, []
        current_steps = self.optim.n_current_steps
        self.model_par.train()
        self.model.train()

        show_start = time.time()
        self.look_nll, self.look_ytoks, self.look_ok_ytoks, self.look_sents, self.look_bow_loss = 0, 0, 0, 0, 0
        for e_idx in range(self.start_epoch, self.max_epochs + 1):

            wlog('\n{} Epoch [{}/{}] {}'.format('$' * 30, e_idx,
                                                self.max_epochs, '$' * 30))
            if wargs.bow_loss is True:
                wlog('bow: {}'.format(schedule_bow_lambda(e_idx, 5, 0.5, 0.2)))
            # shuffle the training data for each epoch
            if self.epoch_shuffle_train: self.train_data.shuffle()
            self.e_nll, self.e_ytoks, self.e_ok_ytoks, self.e_sents = 0, 0, 0, 0
            self.look_xtoks, self.look_spend, b_counter, self.eval_spend = 0, 0, 0, 0
            epo_start = time.time()
            if self.epoch_shuffle_batch:
                shuffled_bidx = tc.randperm(self.n_batches)

            #for bidx in range(self.n_batches):
            bidx = 0
            cond = True if wargs.lr_update_way != 'invsqrt' else self.optim.learning_rate > wargs.min_lr
            while cond:
                if self.train_data.eos() is True: break
                if current_steps >= wargs.max_update:
                    wlog('Touch the max update {}'.format(wargs.max_update))
                    sys.exit(0)
                b_counter += 1
                e_bidx = shuffled_bidx[
                    bidx] if self.epoch_shuffle_batch else bidx
                if wargs.ss_type is not None and self.ss_cur_prob < 1. and wargs.bleu_sampling:
                    batch_beam_trgs = self.sampler.beam_search_trans(
                        xs, xs_mask, ys_mask)
                    batch_beam_trgs = [
                        list(zip(*b)[0]) for b in batch_beam_trgs
                    ]
                    #wlog(batch_beam_trgs)
                    batch_oracles = batch_search_oracle(
                        batch_beam_trgs, ys[1:], ys_mask[1:])
                    #wlog(batch_oracles)
                    batch_oracles = batch_oracles[:-1].cuda()
                    batch_oracles = self.model.decoder.trg_lookup_table(
                        batch_oracles)

                batch = self.train_data[e_bidx]
                real_batches.append(batch)
                accum_batches += 1
                if accum_batches == self.grad_accum_count:

                    self.grad_accumulate(real_batches, e_idx, current_steps)
                    current_steps = self.optim.n_current_steps
                    del real_batches
                    accum_batches, real_batches = 0, []
                    tc.cuda.empty_cache()
                    #grad_checker(self.model, _checks)
                    if current_steps % wargs.display_freq == 0:
                        #wlog('look_ok_ytoks:{}, look_nll:{}, look_ytoks:{}'.format(self.look_ok_ytoks, self.look_nll, self.look_ytoks))
                        ud = time.time(
                        ) - show_start - self.look_spend - self.eval_spend
                        wlog(
                            'Epo:{:>2}/{:>2} |[{:^5}/{} {:^5}] |acc:{:5.2f}% |{:4.2f}/{:4.2f}=nll:{:4.2f} |bow:{:4.2f}'
                            ' |w-ppl:{:4.2f} |x(y)/s:{:>4}({:>4})/{}={}({}) |x(y)/sec:{}({}) |lr:{:7.6f}'
                            ' |{:4.2f}s/{:4.2f}m'.format(
                                e_idx, self.max_epochs, b_counter,
                                len(self.train_data), current_steps,
                                (self.look_ok_ytoks / self.look_ytoks) * 100,
                                self.look_nll, self.look_ytoks,
                                self.look_nll / self.look_ytoks,
                                self.look_bow_loss / self.look_ytoks,
                                math.exp(self.look_nll / self.look_ytoks),
                                self.look_xtoks, self.look_ytoks,
                                self.look_sents,
                                int(round(self.look_xtoks / self.look_sents)),
                                int(round(self.look_ytoks / self.look_sents)),
                                int(round(self.look_xtoks / ud)),
                                int(round(self.look_ytoks / ud)),
                                self.optim.learning_rate, ud,
                                (time.time() - train_start) / 60.))
                        self.look_nll, self.look_xtoks, self.look_ytoks, self.look_ok_ytoks, self.look_sents, self.look_bow_loss = 0, 0, 0, 0, 0, 0
                        self.look_spend, self.eval_spend = 0, 0
                        show_start = time.time()

                    self.look_samples(current_steps)
                    self.try_valid(e_idx, e_bidx, current_steps)
                bidx += 1

            avg_epo_acc, avg_epo_nll = self.e_ok_ytoks / self.e_ytoks, self.e_nll / self.e_ytoks
            wlog('\nEnd epoch [{}]'.format(e_idx))
            wlog('avg. w-acc: {:4.2f}%, w-nll: {:4.2f}, w-ppl: {:4.2f}'.format(
                avg_epo_acc * 100, avg_epo_nll, math.exp(avg_epo_nll)))
            if wargs.epoch_eval is True:
                wlog(
                    '\nEnd epoch, e_batch:{}, n_steps:{}, {}-th validation ...'
                    .format(e_bidx, current_steps, self.n_eval))
                self.mt_eval(e_idx, e_bidx, self.optim.n_current_steps)
            # decay the probability value epsilon of scheduled sampling per epoch
            if wargs.ss_type is not None:
                self.ss_cur_prob = ss_prob_decay(e_idx)  # start from 1.
            epo_time_consume = time.time() - epo_start
            wlog('Consuming: {:4.2f}s'.format(epo_time_consume))

        wlog('Finished training, consuming {:6.2f} hours'.format(
            (time.time() - train_start) / 3600))
        wlog('Congratulations!')
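
Note: grad_accum_count above drives a standard gradient-accumulation loop (collect grad_accum_count batches, backpropagate each, then take a single optimizer step). Stripped of the bookkeeping, the pattern is the one below; model, loader, loss_fn and optimizer are toy stand-ins, not objects from this repository.

import torch

# toy setup so the pattern runs end to end
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
loader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(8)]

accum_steps = 4                      # plays the role of wargs.grad_accum_count
optimizer.zero_grad()
for step, (x, y) in enumerate(loader, start=1):
    loss_fn(model(x), y).backward()  # gradients add up across the accumulated batches
    if step % accum_steps == 0:
        optimizer.step()             # one parameter update per accum_steps mini-batches
        optimizer.zero_grad()
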
Example #5
    def train(self):

        wlog('Start training ... ')
        assert wargs.sample_size < wargs.batch_size, 'sample size must be smaller than batch size'
        # [low, high)
        batch_count = len(self.train_data)
        batch_start_sample = tc.randperm(batch_count)[0]
        wlog('Randomly select {} samples in the {}th/{} batch'.format(wargs.sample_size, batch_start_sample, batch_count))
        bidx, eval_cnt, ss_eps_cur = 0, [0], wargs.ss_eps_begin
        wlog('Self-normalization alpha -> {}'.format(wargs.self_norm_alpha))

        train_start = time.time()
        wlog('')
        wlog('#' * 120)
        wlog('#' * 30, False)
        wlog(' Start Training ', False)
        wlog('#' * 30)
        wlog('#' * 120)

        for epoch in range(wargs.start_epoch, wargs.max_epochs + 1):

            epoch_start = time.time()

            # train for one epoch on the training data
            wlog('')
            wlog('$' * 30, False)
            wlog(' Epoch [{}/{}] '.format(epoch, wargs.max_epochs), False)
            wlog('$' * 30)
            wlog('Schedule sampling value {}'.format(ss_eps_cur))

            if wargs.epoch_shuffle and epoch > wargs.epoch_shuffle_minibatch: self.train_data.shuffle()
            # shuffle the original batch
            shuffled_batch_idx = tc.randperm(batch_count)

            sample_size = wargs.sample_size
            epoch_loss, epoch_trg_words, epoch_num_correct = 0, 0, 0
            show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
            sample_spend, eval_spend, epoch_bidx = 0, 0, 0
            show_start = time.time()

            for k in range(batch_count):

                bidx += 1
                epoch_bidx = k + 1
                batch_idx = shuffled_batch_idx[k] if epoch >= wargs.epoch_shuffle_minibatch else k

                # (max_slen_batch, batch_size)
                _, srcs, trgs, slens, srcs_m, trgs_m = self.train_data[batch_idx]

                self.model.zero_grad()
                # (max_tlen_batch - 1, batch_size, out_size)
                outputs = self.model(srcs, trgs[:-1], srcs_m, trgs_m[:-1], ss_eps=ss_eps_cur)
                if len(outputs) == 2: (outputs, _checks) = outputs
                this_bnum = outputs.size(1)

                #batch_loss, grad_output, batch_correct_num = memory_efficient(
                #    outputs, trgs[1:], trgs_m[1:], self.model.classifier)
                batch_loss, batch_correct_num, batch_log_norm = self.model.decoder.classifier.snip_back_prop(
                    outputs, trgs[1:], trgs_m[1:], wargs.snip_size)

                _grad_nan = False
                for n, p in self.model.named_parameters():
                    if p.grad is None:
                        debug('grad None | {}'.format(n))
                        continue
                    tmp_grad = p.grad.data.cpu().numpy()
                    if numpy.isnan(tmp_grad).any(): # check gradients for NaN values here
                        wlog("grad contains 'nan' | {}".format(n))
                        #wlog("gradient\n{}".format(tmp_grad))
                        _grad_nan = True
                    if n == 'decoder.l_f1_0.weight' or n == 's_init.weight' or n=='decoder.l_f1_1.weight' \
                       or n == 'decoder.l_conv.0.weight' or n == 'decoder.l_f2.weight':
                        debug('grad zeros |{:5} {}'.format(str(not np.any(tmp_grad)), n))

                if _grad_nan is True and wargs.dynamic_cyk_decoding is True:
                    for _i, items in enumerate(_checks):
                        wlog('step {} Variable----------------:'.format(_i))
                        #for item in items: wlog(item.cpu().data.numpy())
                        wlog('wen _check_tanh_sa ---------------')
                        wlog(items[0].cpu().data.numpy())
                        wlog('wen _check_a1_weight ---------------')
                        wlog(items[1].cpu().data.numpy())
                        wlog('wen _check_a1 ---------------')
                        wlog(items[2].cpu().data.numpy())
                        wlog('wen alpha_ij---------------')
                        wlog(items[3].cpu().data.numpy())
                        wlog('wen before_mask---------------')
                        wlog(items[4].cpu().data.numpy())
                        wlog('wen after_mask---------------')
                        wlog(items[5].cpu().data.numpy())

                #outputs.backward(grad_output)
                self.optim.step()
                #del outputs, grad_output

                batch_src_words = srcs.data.ne(PAD).sum()
                assert batch_src_words == slens.data.sum()
                batch_trg_words = trgs[1:].data.ne(PAD).sum()

                show_loss += batch_loss
                show_correct_num += batch_correct_num
                epoch_loss += batch_loss
                epoch_num_correct += batch_correct_num
                show_src_words += batch_src_words
                show_trg_words += batch_trg_words
                epoch_trg_words += batch_trg_words

                batch_log_norm = tc.mean(tc.abs(batch_log_norm))

                if epoch_bidx % wargs.display_freq == 0:
                    #print show_correct_num, show_loss, show_trg_words, show_loss/show_trg_words
                    ud = time.time() - show_start - sample_spend - eval_spend
                    wlog(
                        'Epo:{:>2}/{:>2} |[{:^5} {:^5} {:^5}k] |acc:{:5.2f}% |ppl:{:4.2f} '
                        '| |logZ|:{:.2f} '
                        '|stok/s:{:>4}/{:>2}={:>2} |ttok/s:{:>2} '
                        '|stok/sec:{:6.2f} |ttok/sec:{:6.2f} |elapsed:{:4.2f}/{:4.2f}m'.format(
                            epoch, wargs.max_epochs, epoch_bidx, batch_idx, bidx/1000,
                            (show_correct_num / show_trg_words) * 100,
                            math.exp(show_loss / show_trg_words), batch_log_norm,
                            batch_src_words, this_bnum, int(batch_src_words / this_bnum),
                            int(batch_trg_words / this_bnum),
                            show_src_words / ud, show_trg_words / ud, ud,
                            (time.time() - train_start) / 60.)
                    )
                    show_loss, show_src_words, show_trg_words, show_correct_num = 0, 0, 0, 0
                    sample_spend, eval_spend = 0, 0
                    show_start = time.time()

                if epoch_bidx % wargs.sampling_freq == 0:

                    sample_start = time.time()
                    self.model.eval()
                    #self.model.classifier.eval()
                    tor = Translator(self.model, self.sv, self.tv)

                    # (max_len_batch, batch_size)
                    sample_src_tensor = srcs.t()[:sample_size]
                    sample_trg_tensor = trgs.t()[:sample_size]
                    tor.trans_samples(sample_src_tensor, sample_trg_tensor)
                    wlog('')
                    sample_spend = time.time() - sample_start
                    self.model.train()

                # Just watch the translation of some source sentences in training data
                if wargs.if_fixed_sampling and bidx == batch_start_sample:
                    # randomly select sample_size samples from the current batch
                    rand_rows = np.random.choice(this_bnum, sample_size, replace=False)
                    sample_src_tensor = tc.Tensor(sample_size, srcs.size(0)).long()
                    sample_src_tensor.fill_(PAD)
                    sample_trg_tensor = tc.Tensor(sample_size, trgs.size(0)).long()
                    sample_trg_tensor.fill_(PAD)

                    for id in xrange(sample_size):
                        sample_src_tensor[id, :] = srcs.t()[rand_rows[id], :]
                        sample_trg_tensor[id, :] = trgs.t()[rand_rows[id], :]

                if wargs.epoch_eval is not True and bidx > wargs.eval_valid_from and \
                   bidx % wargs.eval_valid_freq == 0:

                    eval_start = time.time()
                    eval_cnt[0] += 1
                    wlog('\nWithin epoch, batch [{}], eval [{}], saving model ...'.format(
                        epoch_bidx, eval_cnt[0]))

                    self.mt_eval(epoch, epoch_bidx)

                    eval_spend = time.time() - eval_start

            avg_epoch_loss = epoch_loss / epoch_trg_words
            avg_epoch_acc = epoch_num_correct / epoch_trg_words
            wlog('\nEnd epoch [{}]'.format(epoch))
            wlog('Train accuracy {:4.2f}%'.format(avg_epoch_acc * 100))
            wlog('Average loss {:4.2f}'.format(avg_epoch_loss))
            wlog('Train perplexity: {0:4.2f}'.format(math.exp(avg_epoch_loss)))

            wlog('End epoch, batch [{}], [{}] eval save model ...'.format(epoch_bidx, eval_cnt[0]))
            mteval_bleu = self.mt_eval(epoch, epoch_bidx)
            self.optim.update_learning_rate(mteval_bleu, epoch)

            # decay the probability value epsilon of scheduled sampling per epoch
            ss_eps_cur = schedule_sample_eps_decay(epoch)   # start from 1

            epoch_time_consume = time.time() - epoch_start
            wlog('Consuming: {:4.2f}s'.format(epoch_time_consume))

        wlog('Finished training, consuming {:6.2f} hours'.format((time.time() - train_start) / 3600))
        wlog('Congratulations!')
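
Note: schedule_sample_eps_decay is not shown in these snippets. One common choice for decaying the scheduled-sampling probability is the inverse-sigmoid schedule of Bengio et al. (2015), sketched here purely as an illustration; the actual decay function in this codebase may differ.

import math

def inverse_sigmoid_eps(epoch, k=12.0):
    # close to 1 for small epochs, decaying towards 0 as training proceeds
    return k / (k + math.exp(epoch / k))

# with k = 12, epochs 1..5 give roughly 0.92, 0.91, 0.90, 0.90, 0.89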