コード例 #1
0
ファイル: data_generator.py プロジェクト: varenick/seq2seq
 def test_data_generator(self, batch_size=64):
     """Return a batch iterator over the paired test source/target corpora.

     The corpus paths are recorded on ``self.dev_src`` / ``self.dev_tgt``
     before the files are read, and the zipped sentence pairs are handed
     to ``data_iter`` together with ``batch_size``.
     """
     self.dev_src, self.dev_tgt = ("data/nmt_iwslt/test.de-en.de",
                                   "data/nmt_iwslt/test.de-en.en")
     src_corpus = read_corpus(self.dev_src, source='src')
     tgt_corpus = read_corpus(self.dev_tgt, source='tgt')
     sentence_pairs = zip(src_corpus, tgt_corpus)
     return data_iter(sentence_pairs, batch_size=batch_size)
コード例 #2
0
def evaluate_loss(model, data, crit):
    """Return the accumulated criterion loss per predicted target token.

    The model is switched to eval mode, ``data`` is walked in order
    (no shuffling) in batches of ``args.batch_size``, and the scalar
    returned by ``crit`` for each batch is summed.  The sum is divided by
    the number of target tokens excluding each sentence's leading `<s>`.

    Relies on the module-level ``args`` for ``batch_size`` and ``cuda``.
    """
    model.eval()
    total_loss = 0.
    total_tgt_words = 0.

    batches = data_iter(data, batch_size=args.batch_size, shuffle=False)
    for src_batch, tgt_batch in batches:
        # the leading `<s>` of every target sentence is never predicted
        n_predicted = sum(len(sent[1:]) for sent in tgt_batch)
        src_lengths = list(map(len, src_batch))

        src_var = to_input_variable(
            src_batch, model.vocab.src, cuda=args.cuda, is_test=True)
        tgt_var = to_input_variable(
            tgt_batch, model.vocab.tgt, cuda=args.cuda, is_test=True)

        # scores: (tgt_sent_len, batch_size, tgt_vocab_size)
        scores = model(src_var, src_lengths, tgt_var[:-1])
        flat_scores = scores.view(-1, scores.size(2))
        flat_targets = tgt_var[1:].view(-1)
        batch_loss = crit(flat_scores, flat_targets)

        total_loss += batch_loss.item()
        total_tgt_words += n_predicted

    return total_loss / total_tgt_words
コード例 #3
0
ファイル: data_generator.py プロジェクト: varenick/seq2seq
 def train_data_generator(self, batch_size=64):
     """Return a batch iterator over the paired training source/target corpora.

     The corpus paths are recorded on ``self.train_src`` / ``self.train_tgt``
     before the files are read, and the zipped sentence pairs are handed
     to ``data_iter`` together with ``batch_size``.
     """
     self.train_src, self.train_tgt = (
         "data/nmt_iwslt/train.de-en.de.wmixerprep",
         "data/nmt_iwslt/train.de-en.en.wmixerprep")
     src_corpus = read_corpus(self.train_src, source='src')
     tgt_corpus = read_corpus(self.train_tgt, source='tgt')
     sentence_pairs = zip(src_corpus, tgt_corpus)
     return data_iter(sentence_pairs, batch_size=batch_size)
コード例 #4
0
def sample(args):
    """Load a saved NMT model and print five samples per training sentence.

    Reads the parallel corpus from ``args.src_file`` / ``args.tgt_file``,
    restores the checkpoint at ``args.model_bin``, moves the model to the
    GPU, and for every sentence pair (batch size 1) prints the reference
    target followed by the sampled sentences with their first and last
    tokens stripped.
    """
    corpus_src = read_corpus(args.src_file, source='src')
    corpus_tgt = read_corpus(args.tgt_file, source='tgt')
    sentence_pairs = zip(corpus_src, corpus_tgt)

    # restore the checkpoint; map_location keeps tensors on their storage
    print('load model from [%s]' % args.model_bin, file=sys.stderr)
    checkpoint = torch.load(
        args.model_bin, map_location=lambda storage, loc: storage)
    vocab = checkpoint['vocab']
    opt = checkpoint['args']
    state_dict = checkpoint['state_dict']

    # rebuild the network from the stored options and weights
    model = NMT(opt, vocab)
    model.load_state_dict(state_dict)
    model.eval()
    model = model.cuda()

    print('begin sampling')
    n_batches = 0
    n_sampled = 0
    banner = '*' * 80
    for src_sents, tgt_sents in data_iter(sentence_pairs, batch_size=1):
        n_batches += 1
        samples = model.sample(src_sents, sample_size=5, to_word=True)
        n_sampled += sum(map(len, samples))

        for idx, reference in enumerate(tgt_sents):
            print(banner)
            print('target:' + ' '.join(reference))
            candidates = samples[idx]
            print('samples:')
            for rank, candidate in enumerate(candidates, 1):
                # drop the first and last token of each sampled sentence
                inner_tokens = candidate[1:-1]
                print('[%d] %s' % (rank, ' '.join(inner_tokens)))
            print(banner)
コード例 #5
0
def eval(args, data, model, eval_loss_func, vocab_src, vocab_trg, prefix='test', save_to_file=False, is_test=False):
    """Evaluate the tagging model on ``data`` and optionally dump predictions.

    NOTE: the name shadows the built-in ``eval``; it is kept because callers
    use it.

    Args:
        args: options object; reads ``batch_size``, ``cuda`` and, when
            ``save_to_file`` is set, ``save_submission``, ``submission_name``
            and ``prediction_type``.
        data: dataset passed to ``data_iter`` (must support ``len``).
        model: network called as ``model(src, hyp, align, feat)``.
        eval_loss_func: criterion applied to the flattened predictions/tags.
        vocab_src: vocabulary used to pad the source side.
        vocab_trg: vocabulary used to map hypothesis words to ids.
        prefix: suffix appended to ``args.save_submission`` for output files.
        save_to_file: when true, write a ``.pkl`` and a ``.txt`` submission.
        is_test: forwarded to ``data_iter``.

    Returns:
        Tuple ``(mean token loss, f1_bad * f1_good, f1_bad, f1_good)``.

    The model is put back into training mode before returning.
    """
    model.eval()
    all_tags_pred, all_tags_true = [], []
    all_tags_prob = []
    all_words = []
    cum_loss = cum_examples = 0.0
    print('len of data set in eval', len(data))

    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for (src, hyp, align, tag, feat, hyp_orig) in data_iter(data, batch_size=args.batch_size, shuffle=False, sort=False, is_test=is_test):
            src_pad, src_len = pad_source(src, '<pad>', vocab_src)
            hyp = word2id(hyp, vocab_trg)
            align_pad, max_src_len, max_trg_len = pad_align(align)
            assert src_len == max_src_len, 'src_len={} != max_src_len={}'.format(src_len, max_src_len)
            assert len(hyp[0]) == len(feat[0])

            src_v = Variable(torch.LongTensor(src_pad))      # [n, T1]
            hyp_v = Variable(torch.LongTensor(hyp))          # [n, T2]
            align_v = Variable(torch.FloatTensor(align_pad)) # [n, T1, T2]
            tags_v = Variable(torch.LongTensor(tag))         # [n, T2]
            feat_v = Variable(torch.FloatTensor(feat))       # [n, T2, extra_feat_size]

            if args.cuda:
                src_v = src_v.cuda()
                hyp_v = hyp_v.cuda()
                align_v = align_v.cuda()
                tags_v = tags_v.cuda()
                feat_v = feat_v.cuda()

            tags_pred = model(src_v, hyp_v, align_v, feat_v)
            loss = eval_loss_func(tags_pred.view(-1, 2), tags_v.view(-1))

            # Weight the batch loss by its token count so the final value is
            # a per-token average over the whole data set.
            n_tokens = len(hyp) * len(hyp[0])
            cum_loss += loss.item() * n_tokens
            cum_examples += n_tokens

            all_tags_prob.extend(tags_pred.cpu().data.numpy().tolist())    # [n, T2, 2]
            all_tags_pred.extend(torch.max(tags_pred, dim=-1)[1].cpu().data.numpy().tolist())  # [n, T2]
            all_tags_true.extend(tag)                                                          # [n, T2]
            all_words.extend(hyp_orig)

    # average=None yields one F1 value per class; unpacked as (bad, good).
    f1_bad, f1_good = f1_score(flatten_list(all_tags_true), flatten_list(all_tags_pred), average=None, pos_label=None)

    if save_to_file:
        submission = {'probability': all_tags_prob, 'prediction': all_tags_pred, 'target': all_tags_true, 'words': all_words}
        # Context manager guarantees the pickle file is flushed and closed.
        with open(args.save_submission + prefix + '.pkl', 'wb') as pkl_file:
            pickle.dump(submission, pkl_file)
        save_submission(args.save_submission + prefix + '.txt', all_tags_pred, all_words, args.submission_name, args.prediction_type)

    model.train()
    return cum_loss / cum_examples, f1_bad * f1_good, f1_bad, f1_good
コード例 #6
0
def sample(args):
    """Draw samples from an NMT model for the training corpus and print them.

    The model is restored from ``args.load_model`` when that option is set;
    otherwise a fresh model is built from the vocabulary at ``args.vocab``.
    For each batch the reference targets and their samples are printed, and
    every ten batches the sampling throughput is reported.
    """
    corpus_src = read_corpus(args.train_src, source='src')
    corpus_tgt = read_corpus(args.train_tgt, source='tgt')
    sentence_pairs = zip(corpus_src, corpus_tgt)

    if not args.load_model:
        # no checkpoint given: start from the stored vocabulary only
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)
    else:
        print('load model from [%s]' % args.load_model)
        checkpoint = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        vocab = checkpoint['vocab']
        opt = checkpoint['args']
        state_dict = checkpoint['state_dict']

        model = NMT(opt, vocab)
        model.load_state_dict(state_dict)

    model.eval()

    if args.cuda:
        model = model.cuda()

    print('begin sampling')

    report_interval = 10
    n_batches = 0
    n_sampled = 0
    window_start = time.time()
    banner = '*' * 80
    batches = data_iter(sentence_pairs, batch_size=args.batch_size)
    for src_sents, tgt_sents in batches:
        n_batches += 1
        samples = model.sample(
            src_sents, sample_size=args.sample_size, to_word=True)
        n_sampled += sum(map(len, samples))

        if n_batches % report_interval == 0:
            # report throughput for the window, then start a new window
            elapsed = time.time() - window_start
            print('sampling speed: %d/s' % (n_sampled / elapsed))
            n_sampled = 0
            window_start = time.time()

        for idx, reference in enumerate(tgt_sents):
            print(banner)
            print('target:' + ' '.join(reference))
            candidates = samples[idx]
            print('samples:')
            for rank, candidate in enumerate(candidates, 1):
                # drop the first and last token of each sampled sentence
                inner_tokens = candidate[1:-1]
                print('[%d] %s' % (rank, ' '.join(inner_tokens)))
            print(banner)
コード例 #7
0
def train(args):
    """Run MLE training of the tagging model with periodic validation.

    Loops over shuffled training batches indefinitely.  Every
    ``args.log_niter`` iterations the running training loss and F1 scores are
    printed; every ``args.valid_niter`` iterations the model is evaluated on
    the dev and test sets, optionally checkpointed, the LR scheduler is
    stepped on the dev loss, and early stopping is applied.

    Args:
        args: options object; everything (paths, batch size, logging and
            validation intervals, patience, ...) is read from it.

    This function does not return normally: it ends the process with
    ``exit(0)`` once ``args.patience`` consecutive validations fail to
    improve on the best validation metric.
    """
    model, vocab, optimizer, eval_loss_func, train_loss_func = init_model(args)
    # Scheduler is stepped with the dev loss (see below), hence mode 'min'.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=args.lr_decay, patience=1, min_lr=1e-6)
    if args.share_vocab:
        vocab_src = vocab_trg = vocab.share
    else:
        vocab_src = vocab.src
        vocab_trg = vocab.trg

    train_data = load_data(args.trn_dev_path, 'train', args.suffix, args.skip_gap, args.baseline_feature_path)
    dev_data = load_data(args.trn_dev_path, 'dev', args.suffix, args.skip_gap, args.baseline_feature_path)
    test_data = load_data(args.tst_path, 'test', args.suffix, args.skip_gap, args.baseline_feature_path, is_test=True)

    print('begin MLE training')
    train_time = start_time = time.time()
    epoch = train_iter = 0
    # train_* counters are reset at every log step, cum_* at every validation.
    cum_train_loss = cum_train_example = train_loss = train_example = valid_num = 0.0
    cum_tags_pred, cum_tags_true, train_tags_pred, train_tags_true = [], [], [], []
    hist_valid_scores = []
    while True:
        epoch += 1
        for (src, hyp, align, tag, feat, hyp_orig) in data_iter(train_data, batch_size=args.batch_size, shuffle=True):
            train_iter += 1

            src_pad, src_len = pad_source(src, '<pad>', vocab_src)
            align_pad, max_src_len, max_trg_len = pad_align(align)
            hyp = word2id(hyp, vocab_trg)
            assert src_len == max_src_len, 'src_len={} != max_src_len={}'.format(src_len, max_src_len)
            assert len(hyp[0]) == len(feat[0])

            src_v = Variable(torch.LongTensor(src_pad))      # [n, T1]
            hyp_v = Variable(torch.LongTensor(hyp))           # [n, T2]
            align_v = Variable(torch.FloatTensor(align_pad))  # [n, T1, T2]
            tags_v = Variable(torch.LongTensor(tag))         # [n, T2]
            feat_v = Variable(torch.FloatTensor(feat))       # [n, T2, d]
            if args.cuda:
                src_v = src_v.cuda()
                hyp_v = hyp_v.cuda()
                align_v = align_v.cuda()
                tags_v = tags_v.cuda()
                feat_v = feat_v.cuda()
            #print('src, hyp, tags', src_v.size(), hyp_v.size(), tags_v.size(), align_v.size())
            tags_pred = model(src_v, hyp_v, align_v, feat_v)    # [n, T2, 2]
            #print('tags_pred', tags_pred.size())
            loss = train_loss_func(tags_pred.view(-1, 2), tags_v.view(-1))
            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 5.0 before the update.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Weight the batch loss by (batch size * hypothesis length) so the
            # reported averages are per token.
            train_loss += loss.item() * len(hyp) * len(hyp[0])
            train_example += len(hyp) * len(hyp[0])
            cum_train_loss += loss.item() * len(hyp) * len(hyp[0])
            cum_train_example += len(hyp) * len(hyp[0])

            # Predicted label = argmax over the last dimension.
            tags_pred_label = torch.max(tags_pred, dim=-1)[1]
            train_tags_pred.extend(tags_pred_label.cpu().data.numpy().flatten())
            train_tags_true.extend(tags_v.cpu().data.numpy().flatten())
            cum_tags_pred.extend(tags_pred_label.cpu().data.numpy().flatten())
            cum_tags_true.extend(tags_v.cpu().data.numpy().flatten())


            # Periodic training log.
            if train_iter % args.log_niter == 0:
                train_tags_true = [int(t) for t in train_tags_true]
                train_tags_pred = [int(t) for t in train_tags_pred]
                #print('true tags', type(train_tags_true), type(train_tags_true[0]), len(train_tags_true), train_tags_true)
                #print('pred tags', type(train_tags_pred), type(train_tags_pred[0]), len(train_tags_pred), train_tags_pred)
                # average=None returns one F1 per class, unpacked as (bad, good).
                output = f1_score(train_tags_true, train_tags_pred, average=None, pos_label=None)
                #print(output)
                f1_bad, f1_good = output
                gnorm = compute_grad_norm(model)
                pnorm = compute_param_norm(model)
                print("Epoch %r, iter %r: train loss=%.4f, train f1-multi=%.4f, train f1-bad=%.4f, train f1-good=%.4f," \
                      "grad_norm=%.4f, p_norm=%.4f, time=%.2fs" % (epoch, train_iter, train_loss / train_example, \
                        f1_bad * f1_good, f1_bad, f1_good, gnorm, pnorm, time.time() - start_time))
                train_loss = train_example= 0.0
                train_tags_pred, train_tags_true = [], []

            # Perform dev
            if train_iter % args.valid_niter == 0:
                valid_num += 1
                model.eval()
                print('Begin validation ...')
                # Both calls also write their predictions to disk (save_to_file=True).
                dev_loss, dev_multi, dev_bad, dev_good = eval(args, dev_data, model, eval_loss_func, vocab_src, vocab_trg, prefix='dev.iter'+str(train_iter), save_to_file=True)
                tst_loss, tst_multi, tst_bad, tst_good = eval(args, test_data, model, eval_loss_func, vocab_src, vocab_trg, prefix='test.iter'+str(train_iter), save_to_file=True, is_test=True)
                trn_loss = cum_train_loss / cum_train_example
                trn_bad, trn_good = f1_score(cum_tags_true, cum_tags_pred, average=None, pos_label=None)
                trn_multi = trn_bad * trn_good
                # trn_sp = trn_kd = 0
                print("validation: epoch %d, iter %d, train loss=%.4f, lr=%e, " \
                      "dev loss=%.4f, test loss=%.4f, time=%.2fs" % (epoch, train_iter, trn_loss, optimizer.param_groups[0]['lr'],
                                                                     dev_loss, tst_loss, time.time() - start_time))
                print("   iter %d, train multi=%.4f, dev multi=%.4f, test multi=%.4f" % (train_iter, trn_multi, dev_multi, tst_multi))
                print("   iter %d, train bad=%.4f, dev bad=%.4f, test bad=%.4f" % (train_iter, trn_bad, dev_bad, tst_bad))
                print("   iter %d, train good=%.4f, dev good=%.4f, test good=%.4f" % (train_iter, trn_good, dev_good, tst_good))
                # Select the model-selection metric; higher is always better,
                # so the loss fallback is negated.
                if args.valid_metric == 'f1-multi':
                    valid_metric = dev_multi
                elif args.valid_metric == 'f1-bad':
                    valid_metric = dev_bad
                elif args.valid_metric == 'f1-good':
                    valid_metric = dev_good
                else:
                    valid_metric = - dev_loss
                cum_train_loss = cum_train_example = 0.0
                cum_tags_pred, cum_tags_true = [], []
                model.train()

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                is_better_than_last = len(hist_valid_scores) == 0 or valid_metric > hist_valid_scores[-1]
                hist_valid_scores.append(valid_metric)

                # Checkpoints are only written after a warm-up number of validations.
                if valid_num > args.save_model_after:
                    model_file = args.save_model + '.iter{}.bin'.format(train_iter)
                    model.save(model_file)
                    print('save model to [%s]' % model_file)

                # Manual LR decay, superseded by the ReduceLROnPlateau scheduler:
                #if (not is_better_than_last) and args.lr_decay:
                #    lr = max(optimizer.param_groups[0]['lr'] * args.lr_decay, args.min_lr)
                #    print('decay learning rate to %e' % lr)
                #    optimizer.param_groups[0]['lr'] = lr
                scheduler.step(dev_loss)

                if is_better:
                    # `patience` and `best_model_iter` are first bound here; the
                    # first validation always takes this branch because
                    # hist_valid_scores is empty at that point.
                    patience = 0
                    best_model_iter = train_iter

                    if valid_num > args.save_model_after:
                        print('save current best model ... ')
                        model_file_abs_path = os.path.abspath(model_file)
                        symlin_file_abs_path = os.path.abspath(args.save_model + '.bin')
                        # Point `<save_model>.bin` at the best checkpoint via a symlink.
                        os.system('ln -sf %s %s' % (model_file_abs_path, symlin_file_abs_path))
                else:
                    patience += 1
                    print('hit patience %d' % patience)
                    # Early stopping terminates the whole process.
                    if patience == args.patience:
                        print('early stop! the best model is from [%d], best valid score %f' % (best_model_iter, max(hist_valid_scores)))
                        exit(0)
コード例 #8
0
def train_raml(args):
    """Train the NMT model with a reward-augmented maximum likelihood objective.

    For every training batch a set of target samples is built per source
    sentence -- either drawn from pre-computed samples (``pre_sample``) or
    created by substituting random vocabulary words at random positions
    (``hamming_distance`` / ``hamming_distance_impt_sample``) -- and the model
    is optimised on the per-sample weighted negative log-likelihood.

    Every ``args.valid_niter`` iterations the model is evaluated on the dev
    set, optionally checkpointed, the learning rate is decayed when the metric
    is worse than at the previous validation, and training stops after
    ``args.patience`` consecutive validations without a new best score.

    Args:
        args: options object; all paths, hyper-parameters and logging
            settings are read from it.

    This function does not return normally: it ends the process with
    ``exit(0)`` on early stopping, or after the first iteration when
    ``args.debug`` is set.
    """
    tau = args.temp

    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')
    train_data = list(zip(train_data_src, train_data_tgt))

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    dev_data = dev_data[:args.dev_limit]

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    if args.raml_sample_mode == 'pre_sample':
        # dict of (src, [tgt: (sent, prob)])
        print('read in raml training data...', file=sys.stderr, end='')
        begin_time = time.time()
        raml_samples = read_raml_train_data(args.raml_sample_file, temp=tau)
        print('done[%d s].' % (time.time() - begin_time))
    elif args.raml_sample_mode.startswith('hamming_distance'):
        print('sample from hamming distance payoff distribution')
        payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(
            max(len(sent) for sent in train_data_tgt),
            vocab_size=len(vocab.tgt) - 3,
            tau=tau)

    # report_* counters are reset at every log step, cum_* at every validation.
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    report_weighted_loss = cum_weighted_loss = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    _info = f"""
        begin RAML training
        ・学習:{len(train_data)}ペア, {args.train_log_file}
        ・テスト:{len(dev_data)}ペア, {args.valid_niter}iter毎 {args.validation_log_file}
        ・バッチサイズ:{args.batch_size}
        ・1epoch = {len(train_data)}ペア = {int(len(train_data)/args.batch_size)}iter
        """
    print(_info)

    log_data = {'args': args}

    if args.notify_slack:
        slack.post(f"""
        {_info}
        {args}
        """)

    # smoothing function for BLEU
    sm_func = None
    if args.smooth_bleu:
        sm_func = SmoothingFunction().method3

    with open(args.train_log_file, "w") as train_output, \
            open(args.validation_log_file, "w") as validation_output:

        while True:
            epoch += 1
            for src_sents, tgt_sents in data_iter(train_data,
                                                  batch_size=args.batch_size):
                train_iter += 1

                # Flattened (source, sampled target, weight) triples for the batch.
                raml_src_sents = []
                raml_tgt_sents = []
                raml_tgt_weights = []

                if args.raml_sample_mode == 'pre_sample':
                    for src_sent in src_sents:
                        sent = ' '.join(src_sent)
                        tgt_samples_all = raml_samples[sent]
                        num_candidates = len(tgt_samples_all)
                        if args.sample_size >= num_candidates:
                            tgt_samples = tgt_samples_all
                        else:
                            # Index 0 is always kept; the rest is drawn
                            # without replacement from the other candidates.
                            tgt_samples_id = np.random.choice(
                                range(1, num_candidates),
                                size=args.sample_size - 1,
                                replace=False)
                            tgt_samples = [tgt_samples_all[0]] + [
                                tgt_samples_all[i] for i in tgt_samples_id
                            ]  # make sure the ground truth y* is in the samples

                        raml_src_sents.extend([src_sent] * len(tgt_samples))
                        raml_tgt_sents.extend(
                            [['<s>'] + sent.split(' ') + ['</s>']
                             for sent, weight in tgt_samples])
                        raml_tgt_weights.extend(
                            [weight for sent, weight in tgt_samples])
                elif args.raml_sample_mode in [
                        'hamming_distance', 'hamming_distance_impt_sample'
                ]:
                    for src_sent, tgt_sent in zip(src_sents, tgt_sents):
                        tgt_samples = []
                        # remove <s> and </s> and ending period .
                        tgt_sent_len = len(tgt_sent) - 3
                        tgt_ref_tokens = tgt_sent[1:-1]
                        bleu_scores = []
                        # sample edit distances from the payoff distribution
                        e_samples = np.random.choice(
                            range(tgt_sent_len + 1),
                            p=payoff_prob[tgt_sent_len],
                            size=args.sample_size,
                            replace=True)

                        # make sure the ground truth y* is in the samples
                        if args.raml_bias_groundtruth and 0 not in e_samples:
                            e_samples[0] = 0

                        for i, e in enumerate(e_samples):
                            if e > 0:
                                # sample a new tgt_sent $y$: replace `e`
                                # distinct positions with random words whose
                                # ids are drawn from [3, len(vocab.tgt))
                                old_word_pos = np.random.choice(
                                    range(1, tgt_sent_len + 1),
                                    size=e,
                                    replace=False)
                                new_words = [
                                    vocab.tgt.id2word[wid]
                                    for wid in np.random.randint(
                                        3, len(vocab.tgt), size=e)
                                ]
                                new_tgt_sent = list(tgt_sent)
                                for pos, word in zip(old_word_pos, new_words):
                                    new_tgt_sent[pos] = word
                            else:
                                new_tgt_sent = list(tgt_sent)

                            # if enable importance sampling, compute bleu score
                            if args.raml_sample_mode == 'hamming_distance_impt_sample':
                                if e > 0:
                                    # remove <s> and </s>
                                    bleu_score = sentence_bleu(
                                        [tgt_ref_tokens],
                                        new_tgt_sent[1:-1],
                                        smoothing_function=sm_func)
                                    bleu_scores.append(bleu_score)
                                else:
                                    bleu_scores.append(1.)

                            tgt_samples.append(new_tgt_sent)

                        # if enable importance sampling, compute importance weight
                        if args.raml_sample_mode == 'hamming_distance_impt_sample':
                            tgt_sample_weights = [
                                math.exp(bleu_score / tau) / math.exp(-e / tau)
                                for e, bleu_score in zip(
                                    e_samples, bleu_scores)
                            ]
                            normalizer = sum(tgt_sample_weights)
                            tgt_sample_weights = [
                                w / normalizer for w in tgt_sample_weights
                            ]
                        else:
                            tgt_sample_weights = [1.] * args.sample_size

                        raml_src_sents.extend([src_sent] * len(tgt_samples))
                        raml_tgt_sents.extend(tgt_samples)
                        raml_tgt_weights.extend(tgt_sample_weights)

                        if args.debug:
                            # In debug mode only the first sentence pair of
                            # the batch is expanded.
                            print('*' * 30)
                            print('Target: %s' % ' '.join(tgt_sent))
                            for tgt_sample, e, bleu_score, weight in zip(
                                    tgt_samples, e_samples, bleu_scores,
                                    tgt_sample_weights):
                                print(
                                    'Sample: %s ||| e: %d ||| bleu: %f ||| weight: %f'
                                    % (' '.join(tgt_sample), e, bleu_score,
                                       weight))
                            print()
                            break

                src_sents_var = to_input_variable(raml_src_sents,
                                                  vocab.src,
                                                  cuda=args.cuda)
                tgt_sents_var = to_input_variable(raml_tgt_sents,
                                                  vocab.tgt,
                                                  cuda=args.cuda)
                weights_var = Variable(torch.FloatTensor(raml_tgt_weights),
                                       requires_grad=False)
                if args.cuda:
                    weights_var = weights_var.cuda()

                # batch_size = args.batch_size * args.sample_size
                batch_size = len(raml_src_sents)
                src_sents_len = [len(s) for s in raml_src_sents]
                # omitting leading `<s>`
                pred_tgt_word_num = sum(len(s[1:]) for s in raml_tgt_sents)
                optimizer.zero_grad()

                # (tgt_sent_len, batch_size, tgt_vocab_size)
                scores = model(src_sents_var, src_sents_len,
                               tgt_sents_var[:-1])
                # (tgt_sent_len * batch_size, tgt_vocab_size); normalise over
                # the vocabulary axis explicitly instead of relying on the
                # deprecated implicit dimension.
                log_scores = F.log_softmax(scores.view(-1, scores.size(2)),
                                           dim=1)
                # remove leading <s> in tgt sent, which is not used as the target
                flattened_tgt_sents = tgt_sents_var[1:].view(-1)

                # batch_size * tgt_sent_len
                tgt_log_scores = torch.gather(
                    log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1)
                # positions whose target id is 0 contribute no loss
                unweighted_loss = -tgt_log_scores * (
                    1. - torch.eq(flattened_tgt_sents, 0).float())
                # tile the per-sample weights once per time step so they line
                # up with the time-major flattened loss
                weighted_loss = unweighted_loss * weights_var.repeat(
                    scores.size(0))
                weighted_loss = weighted_loss.sum()
                weighted_loss_val = weighted_loss.item()
                nll_loss_val = unweighted_loss.sum().item()

                loss = weighted_loss / batch_size

                loss.backward()
                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)
                optimizer.step()

                report_weighted_loss += weighted_loss_val
                cum_weighted_loss += weighted_loss_val
                report_loss += nll_loss_val
                cum_loss += nll_loss_val
                report_tgt_words += pred_tgt_word_num
                cum_tgt_words += pred_tgt_word_num
                report_examples += batch_size
                cum_examples += batch_size
                cum_batches += batch_size

                if train_iter % args.log_every == 0 or train_iter % args.notify_slack_every == 0:
                    # Compute the statistics once so the printed line and the
                    # structured log entry agree.
                    now = time.time()
                    avg_loss = report_weighted_loss / report_examples
                    avg_ppl = np.exp(report_loss / report_tgt_words)
                    speed = report_tgt_words / (now - train_time)
                    elapsed = now - begin_time

                    _log = 'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                        epoch, train_iter, avg_loss, avg_ppl, cum_examples,
                        speed, elapsed)
                    print(_log)
                    print(_log, file=train_output)

                    _list_dict_update(
                        log_data, {
                            'epoch': epoch,
                            'train_iter': train_iter,
                            'loss': avg_loss,
                            'ppl': avg_ppl,
                            'examples': cum_examples,
                            'speed': speed,
                            'elapsed': elapsed
                        }, 'train')

                    train_time = time.time()
                    report_loss = report_weighted_loss = report_tgt_words = report_examples = 0.
                    if train_iter % args.notify_slack_every == 0 and args.notify_slack:
                        print('post slack')
                        slack.post(_log)

                # perform validation
                if train_iter % args.valid_niter == 0:
                    print('epoch %d, iter %d, cum. loss %.2f, '
                          'cum. ppl %.2f cum. examples %d' %
                          (epoch, train_iter, cum_weighted_loss / cum_batches,
                           np.exp(cum_loss / cum_tgt_words), cum_examples),
                          file=sys.stderr)

                    cum_loss = cum_weighted_loss = cum_batches = cum_tgt_words = 0.
                    valid_num += 1

                    print('begin validation ...')
                    model.eval()

                    # compute dev. ppl and bleu
                    dev_loss = evaluate_loss(model, dev_data,
                                             cross_entropy_loss)
                    dev_ppl = np.exp(dev_loss)

                    # Hypotheses are only decoded for the metrics that need
                    # them; default to None so the log entry below never hits
                    # an unbound name when validating on perplexity.
                    dev_hyps = None
                    if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                        dev_hyps = decode(model,
                                          dev_data,
                                          f=validation_output,
                                          verbose=False)
                        # keep the first hypothesis of each decoded list
                        dev_hyps = [hyps[0] for hyps in dev_hyps]
                        references = [tgt for src, tgt in dev_data]
                        if args.valid_metric == 'bleu':
                            valid_metric = get_bleu(references, dev_hyps)
                        else:
                            valid_metric = get_acc(
                                references,
                                dev_hyps,
                                acc_type=args.valid_metric)
                        _log = 'validation: iter %d, dev. ppl %f, dev. %s %f' % (
                            train_iter, dev_ppl, args.valid_metric,
                            valid_metric)
                        print(_log)
                        print(_log, file=validation_output)
                        if args.notify_slack:
                            slack.post(_log)

                    else:
                        # higher is better for model selection, so negate ppl
                        valid_metric = -dev_ppl
                        print('validation: iter %d, dev. ppl %f' %
                              (train_iter, dev_ppl),
                              file=sys.stderr)

                    # NOTE(review): log_data starts as {'args': args}, so this
                    # branch only runs if _list_dict_update adds a 'dev_data'
                    # key; the condition may have been meant as `not in` --
                    # confirm against _list_dict_update before changing.
                    if 'dev_data' in log_data:
                        log_data['dev_data'] = dev_data

                    _list_dict_update(log_data, {
                        'epoch': epoch,
                        'train_iter': train_iter,
                        'loss': dev_loss,
                        'ppl': dev_ppl,
                        args.valid_metric: valid_metric,
                        'hyps': dev_hyps,
                    },
                                      'validation',
                                      is_save=True)

                    model.train()

                    is_better = (len(hist_valid_scores) == 0
                                 or valid_metric > max(hist_valid_scores))
                    is_better_than_last = (
                        len(hist_valid_scores) == 0
                        or valid_metric > hist_valid_scores[-1])
                    hist_valid_scores.append(valid_metric)

                    # Checkpoints are only written after a warm-up number of
                    # validations.
                    if valid_num > args.save_model_after:
                        model_file = args.save_to + '.iter%d.bin' % train_iter
                        print('save model to [%s]' % model_file)
                        model.save(model_file)

                    if (not is_better_than_last) and args.lr_decay:
                        lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                        print('decay learning rate to %f' % lr)
                        optimizer.param_groups[0]['lr'] = lr

                    if is_better:
                        patience = 0
                        best_model_iter = train_iter

                        if valid_num > args.save_model_after:
                            print('save currently the best model ..')
                            model_file_abs_path = os.path.abspath(model_file)
                            symlin_file_abs_path = os.path.abspath(
                                args.save_to + '.bin')
                            # Point `<save_to>.bin` at the best checkpoint.
                            os.system(
                                'ln -sf %s %s' %
                                (model_file_abs_path, symlin_file_abs_path))
                    else:
                        patience += 1
                        print('hit patience %d' % patience)
                        if patience == args.patience:
                            _log = f"""
                            {'hit patience %d' % patience}
                            early stop!
                            {'the best model is from iteration [%d]' % best_model_iter}
                            """
                            print(_log)
                            if args.notify_slack:
                                slack.post(_log)
                            exit(0)

                if args.debug:
                    # Debug runs save the model and stop after one iteration.
                    print(f'debug epoch:{epoch} exit')
                    model_file = args.save_to + '.bin'
                    print('save model to [%s]' % model_file)
                    model.save(model_file)
                    exit(0)
コード例 #9
0
def train(args):
    """Run maximum-likelihood training with periodic validation.

    Loads the parallel train/dev corpora named in ``args`` and then iterates
    over epochs indefinitely. Every ``args.log_every`` iterations a training
    summary is printed and written to ``args.train_log_file``. Every
    ``args.valid_niter`` iterations the model is scored on the dev set, the
    result is written to stderr and ``args.validation_log_file`` (and posted
    to Slack when ``args.notify_slack`` is set), the learning rate is
    optionally decayed, checkpoints are saved, and the patience counter is
    updated.

    The function does not return normally: it leaves through ``sys.exit(0)``
    once ``args.patience`` consecutive validations fail to beat the best
    score seen so far.

    Args:
        args: parsed options. Fields read here: ``train_src``, ``train_tgt``,
            ``dev_src``, ``dev_tgt``, ``dev_limit``, ``train_log_file``,
            ``validation_log_file``, ``batch_size``, ``cuda``, ``clip_grad``,
            ``log_every``, ``valid_niter``, ``valid_metric``,
            ``save_model_after``, ``save_to``, ``lr_decay``, ``patience`` and
            ``notify_slack``.

    Raises:
        SystemExit: with code 0 when early stopping triggers.
    """
    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    # Validate on at most the first `args.dev_limit` pairs.
    dev_data = dev_data[:args.dev_limit]

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    # `report_*` counters are reset after every training log line;
    # `cum_loss`, `cum_batches` and `cum_tgt_words` after every validation.
    # `cum_examples` is never reset.
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()

    _info = f"""
        begin Maximum Likelihood training
        ・学習:{len(train_data)}ペア, {args.train_log_file}
        ・テスト:{len(dev_data)}ペア, {args.valid_niter}iter毎 {args.validation_log_file}
        ・バッチサイズ:{args.batch_size}
        ・1epoch = {len(train_data)}ペア = {int(len(train_data)/args.batch_size)}iter
        """
    print(_info)

    if args.notify_slack:
        slack.post(f"""
        {_info}
        {args}
        """)

    with open(args.train_log_file, "w") as train_output, \
            open(args.validation_log_file, "w") as validation_output:

        while True:
            epoch += 1
            for src_sents, tgt_sents in data_iter(train_data,
                                                  batch_size=args.batch_size):
                train_iter += 1

                src_sents_var = to_input_variable(src_sents,
                                                  vocab.src,
                                                  cuda=args.cuda)
                tgt_sents_var = to_input_variable(tgt_sents,
                                                  vocab.tgt,
                                                  cuda=args.cuda)

                batch_size = len(src_sents)
                src_sents_len = [len(s) for s in src_sents]
                pred_tgt_word_num = sum(
                    len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`

                optimizer.zero_grad()

                # (tgt_sent_len, batch_size, tgt_vocab_size)
                scores = model(src_sents_var, src_sents_len,
                               tgt_sents_var[:-1])

                # Predictions at step t are scored against the token at t+1.
                word_loss = cross_entropy_loss(scores.view(-1, scores.size(2)),
                                               tgt_sents_var[1:].view(-1))
                loss = word_loss / batch_size
                word_loss_val = word_loss.item()

                loss.backward()
                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_grad)
                optimizer.step()

                report_loss += word_loss_val
                cum_loss += word_loss_val
                report_tgt_words += pred_tgt_word_num
                cum_tgt_words += pred_tgt_word_num
                report_examples += batch_size
                cum_examples += batch_size
                cum_batches += batch_size

                if train_iter % args.log_every == 0:
                    _log = 'epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
                           'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                               epoch, train_iter,
                               report_loss / report_examples,
                               np.exp(report_loss / report_tgt_words),
                               cum_examples,
                               report_tgt_words / (time.time() - train_time),
                               time.time() - begin_time)
                    print(_log)
                    print(_log, file=train_output)

                    train_time = time.time()
                    report_loss = report_tgt_words = report_examples = 0.

                # perform validation
                if train_iter % args.valid_niter == 0:
                    print(
                        'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                        % (epoch, train_iter, cum_loss / cum_batches,
                           np.exp(cum_loss / cum_tgt_words), cum_examples),
                        file=sys.stderr)

                    cum_loss = cum_batches = cum_tgt_words = 0.
                    valid_num += 1

                    print('begin validation ...', file=sys.stderr)
                    model.eval()

                    # compute dev. ppl and bleu
                    dev_loss = evaluate_loss(model, dev_data,
                                             cross_entropy_loss)
                    dev_ppl = np.exp(dev_loss)

                    if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                        dev_hyps = decode(model, dev_data)
                        dev_hyps = [hyps[0] for hyps in dev_hyps]
                        if args.valid_metric == 'bleu':
                            valid_metric = get_bleu(
                                [tgt for src, tgt in dev_data], dev_hyps)
                        else:
                            valid_metric = get_acc(
                                [tgt for src, tgt in dev_data],
                                dev_hyps,
                                acc_type=args.valid_metric)
                        _log = 'validation: iter %d, dev. ppl %f, dev. %s %f' % (
                            train_iter, dev_ppl, args.valid_metric,
                            valid_metric)
                    else:
                        # Negated so that "larger is better" holds for
                        # every metric compared below.
                        valid_metric = -dev_ppl
                        _log = 'validation: iter %d, dev. ppl %f' % (
                            train_iter, dev_ppl)

                    # Report through the same channels whichever metric is
                    # in use, so the validation log file is never left empty.
                    print(_log, file=sys.stderr)
                    print(_log, file=validation_output)
                    if args.notify_slack:
                        slack.post(_log)

                    model.train()

                    is_better = len(
                        hist_valid_scores
                    ) == 0 or valid_metric > max(hist_valid_scores)
                    is_better_than_last = len(
                        hist_valid_scores
                    ) == 0 or valid_metric > hist_valid_scores[-1]
                    hist_valid_scores.append(valid_metric)

                    if valid_num > args.save_model_after:
                        model_file = args.save_to + '.iter%d.bin' % train_iter
                        print('save model to [%s]' % model_file,
                              file=sys.stderr)
                        model.save(model_file)

                    if (not is_better_than_last) and args.lr_decay:
                        lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                        print('decay learning rate to %f' % lr,
                              file=sys.stderr)
                        optimizer.param_groups[0]['lr'] = lr

                    if is_better:
                        patience = 0
                        best_model_iter = train_iter

                        # `model_file` was assigned just above under the
                        # same condition.
                        if valid_num > args.save_model_after:
                            print('save currently the best model ..',
                                  file=sys.stderr)
                            model_file_abs_path = os.path.abspath(model_file)
                            symlink_file_abs_path = os.path.abspath(
                                args.save_to + '.bin')
                            # Point `<save_to>.bin` at the newest best
                            # checkpoint without going through a shell, so
                            # paths containing spaces or shell
                            # metacharacters are handled correctly.
                            try:
                                if os.path.lexists(symlink_file_abs_path):
                                    os.remove(symlink_file_abs_path)
                                os.symlink(model_file_abs_path,
                                           symlink_file_abs_path)
                            except OSError as err:
                                # Best effort: a failed link must not abort
                                # training.
                                print('failed to link best model: %s' % err,
                                      file=sys.stderr)
                    else:
                        patience += 1
                        print('hit patience %d' % patience, file=sys.stderr)
                        if patience == args.patience:
                            print('early stop!', file=sys.stderr)
                            print('the best model is from iteration [%d]' %
                                  best_model_iter,
                                  file=sys.stderr)
                            sys.exit(0)
コード例 #10
0
ファイル: pc_nmt.py プロジェクト: hrlinlp/pytorch_NMT-1
def train_raml(args):
    """Train the model with reward-augmented maximum likelihood (RAML).

    For every source sentence in a batch, target samples and their weights
    are looked up in the table read from ``args.raml_sample_file``; the
    per-token negative log-likelihood of each sample is scaled by its weight
    before back-propagation. Every ``args.valid_niter`` iterations the model
    is validated on the dev set, the learning rate is optionally decayed,
    checkpoints are saved and the patience counter is updated.

    The function does not return normally: it leaves through ``sys.exit(0)``
    once ``args.patience`` consecutive validations fail to beat the best
    score seen so far.

    Args:
        args: parsed options. Fields read here: ``vocab``, ``train_src``,
            ``train_tgt``, ``dev_src``, ``dev_tgt``, ``raml_sample_file``,
            ``temp``, ``sample_size``, ``batch_size``, ``cuda``,
            ``clip_grad``, ``log_every``, ``valid_niter``, ``valid_metric``,
            ``save_model_after``, ``save_to``, ``lr_decay`` and ``patience``.

    Raises:
        SystemExit: with code 0 when early stopping triggers.
    """
    # NOTE: this value is replaced by the vocab returned from
    # `init_training` below.
    vocab = torch.load(args.vocab)

    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')
    # Materialise the pairs: a bare `zip` is a one-shot iterator on
    # Python 3 and would be empty from the second epoch onwards.
    train_data = list(zip(train_data_src, train_data_tgt))

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')
    # The dev set is traversed several times per validation round, so it
    # must be a list as well.
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    # dict of (src, [tgt: (sent, prob)])
    print('read in raml training data...', file=sys.stderr, end='')
    begin_time = time.time()
    raml_samples = read_raml_train_data(args.raml_sample_file, temp=args.temp)
    # Same stream as the prefix above so both halves land on one line.
    print('done[%d s].' % (time.time() - begin_time), file=sys.stderr)

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    report_weighted_loss = cum_weighted_loss = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin RAML training')

    while True:
        epoch += 1
        for src_sents, tgt_sents in data_iter(train_data, batch_size=args.batch_size):
            train_iter += 1

            # Expand every source sentence into one row per sampled target.
            raml_src_sents = []
            raml_tgt_sents = []
            raml_tgt_weights = []
            for src_sent in src_sents:
                tgt_samples_all = raml_samples[' '.join(src_sent)]

                if args.sample_size >= len(tgt_samples_all):
                    tgt_samples = tgt_samples_all
                else:
                    tgt_samples_id = np.random.choice(range(1, len(tgt_samples_all)), size=args.sample_size - 1, replace=False)
                    tgt_samples = [tgt_samples_all[0]] + [tgt_samples_all[i] for i in tgt_samples_id] # make sure the ground truth y* is in the samples

                raml_src_sents.extend([src_sent] * len(tgt_samples))
                raml_tgt_sents.extend([['<s>'] + sent.split(' ') + ['</s>'] for sent, weight in tgt_samples])
                raml_tgt_weights.extend([weight for sent, weight in tgt_samples])

            src_sents_var = to_input_variable(raml_src_sents, vocab.src, cuda=args.cuda)
            tgt_sents_var = to_input_variable(raml_tgt_sents, vocab.tgt, cuda=args.cuda)
            weights_var = Variable(torch.FloatTensor(raml_tgt_weights), requires_grad=False)
            if args.cuda:
                weights_var = weights_var.cuda()

            batch_size = len(raml_src_sents)  # batch_size = args.batch_size * args.sample_size
            src_sents_len = [len(s) for s in raml_src_sents]
            pred_tgt_word_num = sum(len(s[1:]) for s in raml_tgt_sents)  # omitting leading `<s>`
            optimizer.zero_grad()

            # (tgt_sent_len, batch_size, tgt_vocab_size)
            scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
            log_scores = F.log_softmax(scores.view(-1, scores.size(2)))
            flattened_tgt_sents = tgt_sents_var[1:].view(-1)

            # batch_size * tgt_sent_len
            tgt_log_scores = torch.gather(log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1)
            # Positions whose target id is 0 contribute nothing
            # (presumably id 0 is the padding token -- TODO confirm).
            unweighted_loss = -tgt_log_scores * (1. - torch.eq(flattened_tgt_sents, 0).float())
            # One weight per sample, tiled once per target time step.
            weighted_loss = unweighted_loss * weights_var.repeat(scores.size(0))
            weighted_loss = weighted_loss.sum()
            # NOTE(review): `.data[0]` and `clip_grad_norm` below are the
            # legacy PyTorch spellings; newer releases expect `.item()` and
            # `clip_grad_norm_` -- confirm against the installed version.
            weighted_loss_val = weighted_loss.data[0]
            nll_loss_val = unweighted_loss.sum().data[0]

            loss = weighted_loss / batch_size

            loss.backward()
            # clip gradient
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)
            optimizer.step()

            report_weighted_loss += weighted_loss_val
            cum_weighted_loss += weighted_loss_val
            report_loss += nll_loss_val
            cum_loss += nll_loss_val
            report_tgt_words += pred_tgt_word_num
            cum_tgt_words += pred_tgt_word_num
            report_examples += batch_size
            cum_examples += batch_size
            cum_batches += batch_size

            if train_iter % args.log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, '
                      'avg. ppl %.2f cum. examples %d, '
                      'speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter,
                                                                       report_weighted_loss / report_examples,
                                                                       np.exp(report_loss / report_tgt_words),
                                                                       cum_examples,
                                                                       report_tgt_words / (time.time() - train_time),
                                                                       time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_weighted_loss = report_tgt_words = report_examples = 0.

            # perform validation
            if train_iter % args.valid_niter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, '
                      'cum. ppl %.2f cum. examples %d' % (epoch, train_iter,
                                                          cum_weighted_loss / cum_batches,
                                                          np.exp(cum_loss / cum_tgt_words),
                                                          cum_examples),
                      file=sys.stderr)

                cum_loss = cum_weighted_loss = cum_batches = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)
                model.eval()

                # compute dev. ppl and bleu

                dev_loss = evaluate_loss(model, dev_data, cross_entropy_loss)
                dev_ppl = np.exp(dev_loss)

                if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                    dev_hyps = decode(model, dev_data)
                    dev_hyps = [hyps[0] for hyps in dev_hyps]
                    if args.valid_metric == 'bleu':
                        valid_metric = get_bleu([tgt for src, tgt in dev_data], dev_hyps)
                    else:
                        valid_metric = get_acc([tgt for src, tgt in dev_data], dev_hyps, acc_type=args.valid_metric)
                    print('validation: iter %d, dev. ppl %f, dev. %s %f' % (
                    train_iter, dev_ppl, args.valid_metric, valid_metric),
                          file=sys.stderr)
                else:
                    # Negated so that "larger is better" holds for every
                    # metric compared below.
                    valid_metric = -dev_ppl
                    print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl),
                          file=sys.stderr)

                model.train()

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                is_better_than_last = len(hist_valid_scores) == 0 or valid_metric > hist_valid_scores[-1]
                hist_valid_scores.append(valid_metric)

                if valid_num > args.save_model_after:
                    model_file = args.save_to + '.iter%d.bin' % train_iter
                    print('save model to [%s]' % model_file, file=sys.stderr)
                    model.save(model_file)

                if (not is_better_than_last) and args.lr_decay:
                    lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                    print('decay learning rate to %f' % lr, file=sys.stderr)
                    optimizer.param_groups[0]['lr'] = lr

                if is_better:
                    patience = 0
                    best_model_iter = train_iter

                    # `model_file` was assigned just above under the same
                    # condition.
                    if valid_num > args.save_model_after:
                        print('save currently the best model ..', file=sys.stderr)
                        model_file_abs_path = os.path.abspath(model_file)
                        symlink_file_abs_path = os.path.abspath(args.save_to + '.bin')
                        # Point `<save_to>.bin` at the newest best checkpoint
                        # without going through a shell, so paths containing
                        # spaces or shell metacharacters are handled
                        # correctly.
                        try:
                            if os.path.lexists(symlink_file_abs_path):
                                os.remove(symlink_file_abs_path)
                            os.symlink(model_file_abs_path, symlink_file_abs_path)
                        except OSError as err:
                            # Best effort: a failed link must not abort
                            # training.
                            print('failed to link best model: %s' % err, file=sys.stderr)
                else:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)
                    if patience == args.patience:
                        print('early stop!', file=sys.stderr)
                        print('the best model is from iteration [%d]' % best_model_iter, file=sys.stderr)
                        sys.exit(0)
コード例 #11
0
ファイル: nmt.py プロジェクト: wangyu1997/CSCGDual
def train(args):
    """Run maximum-likelihood training with validation and early stopping.

    Loads the parallel train/dev corpora named in ``args`` and then iterates
    over epochs indefinitely. Every ``args.valid_niter`` iterations the model
    is scored on the dev set; a checkpoint is written only when the score is
    the best so far and more than ``args.save_model_after`` validations have
    run. The learning rate is decayed after a validation that is worse than
    the previous one, but only once ``epoch`` exceeds 10.

    When ``args.load_model`` is set, the iteration counter is resumed from
    the digits following ``iter`` in that path. Loading the weights
    themselves is not done here (presumably ``init_training`` handles it --
    TODO confirm).

    The function does not return normally: it leaves through ``sys.exit(0)``
    either when ``args.patience`` consecutive validations fail to beat the
    best score, or when the learning rate has decayed to at most ``1e-5``.

    Args:
        args: parsed options. Fields read here: ``train_src``, ``train_tgt``,
            ``dev_src``, ``dev_tgt``, ``load_model``, ``batch_size``,
            ``cuda``, ``clip_grad``, ``log_every``, ``valid_niter``,
            ``valid_metric``, ``save_model_after``, ``save_to``,
            ``lr_decay`` and ``patience``.

    Raises:
        SystemExit: with code 0 on early stopping or a vanishing learning
            rate.
        AttributeError: if ``args.load_model`` is set but contains no
            ``iter<digits>`` component.
    """
    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')

    dev_data_src = read_corpus(args.dev_src, source='src')
    dev_data_tgt = read_corpus(args.dev_tgt, source='tgt')

    train_data = list(zip(train_data_src, train_data_tgt))
    dev_data = list(zip(dev_data_src, dev_data_tgt))

    vocab, model, optimizer, nll_loss, cross_entropy_loss = init_training(args)

    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0

    if args.load_model:
        import re
        # Raw string: `\d` in a plain literal is an invalid escape sequence
        # and triggers a warning on recent Python versions.
        train_iter = int(re.search(r'(?<=iter)\d+', args.load_model).group(0))
        print('start from train_iter = %d' % train_iter)

        # Keep the validation counter in step with the resumed iteration.
        valid_num = train_iter // args.valid_niter

    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin Maximum Likelihood training')

    while True:
        epoch += 1
        print('start of epoch {:d}'.format(epoch))

        for src_sents, tgt_sents in data_iter(train_data,
                                              batch_size=args.batch_size):
            train_iter += 1

            src_sents_var = to_input_variable(src_sents,
                                              vocab.src,
                                              cuda=args.cuda)
            tgt_sents_var = to_input_variable(tgt_sents,
                                              vocab.tgt,
                                              cuda=args.cuda)

            batch_size = len(src_sents)
            src_sents_len = [len(s) for s in src_sents]
            pred_tgt_word_num = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`

            optimizer.zero_grad()

            # (tgt_sent_len, batch_size, tgt_vocab_size)
            # The model's second return value is not used here.
            scores, _ = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
            word_loss = cross_entropy_loss(scores.view(-1, scores.size(2)),
                                           tgt_sents_var[1:].view(-1))
            loss = word_loss / batch_size
            # NOTE(review): `.data[0]` and `clip_grad_norm` below are the
            # legacy PyTorch spellings; newer releases expect `.item()` and
            # `clip_grad_norm_` -- confirm against the installed version.
            word_loss_val = word_loss.data[0]

            loss.backward()
            # clip gradient
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          args.clip_grad)
            optimizer.step()

            report_loss += word_loss_val
            cum_loss += word_loss_val
            report_tgt_words += pred_tgt_word_num
            cum_tgt_words += pred_tgt_word_num
            report_examples += batch_size
            cum_examples += batch_size
            cum_batches += batch_size

            if train_iter % args.log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f ' \
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (
                          epoch, train_iter,
                          report_loss / report_examples,
                          np.exp(report_loss / report_tgt_words),
                          cum_examples,
                          report_tgt_words / (time.time() - train_time),
                          time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # perform validation
            if train_iter % args.valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_batches,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_batches = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)
                model.eval()

                # compute dev. ppl and bleu

                dev_loss = evaluate_loss(model, dev_data, cross_entropy_loss)
                dev_ppl = np.exp(dev_loss)

                if args.valid_metric in ['bleu', 'word_acc', 'sent_acc']:
                    dev_hyps = decode(model, dev_data)
                    dev_hyps = [hyps[0] for hyps in dev_hyps]
                    # Show the first few hypotheses as a quick sanity check.
                    print(dev_hyps[:3])
                    if args.valid_metric == 'bleu':
                        valid_metric = get_bleu([tgt for src, tgt in dev_data],
                                                dev_hyps, 'valid')
                    else:
                        valid_metric = get_acc([tgt for src, tgt in dev_data],
                                               dev_hyps,
                                               acc_type=args.valid_metric)
                    print(
                        'validation: iter %d, dev. ppl %f, dev. %s %f' %
                        (train_iter, dev_ppl, args.valid_metric, valid_metric),
                        file=sys.stderr)
                else:
                    # Negated so that "larger is better" holds for every
                    # metric compared below.
                    valid_metric = -dev_ppl
                    print('validation: iter %d, dev. ppl %f' %
                          (train_iter, dev_ppl),
                          file=sys.stderr)

                model.train()

                is_better = len(hist_valid_scores
                                ) == 0 or valid_metric > max(hist_valid_scores)
                is_better_than_last = len(
                    hist_valid_scores
                ) == 0 or valid_metric > hist_valid_scores[-1]
                hist_valid_scores.append(valid_metric)

                # No learning-rate decay during the first 10 epochs.
                if (not is_better_than_last) and args.lr_decay and epoch > 10:
                    lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                    print('decay learning rate to %f' % lr, file=sys.stderr)
                    optimizer.param_groups[0]['lr'] = lr

                if is_better:
                    patience = 0
                    best_model_iter = train_iter

                    # Only improving checkpoints are written; no "best
                    # model" symlink is maintained by this function.
                    if valid_num > args.save_model_after:
                        model_file = args.save_to + '.iter%d.bin' % train_iter
                        print('save model to [%s]' % model_file,
                              file=sys.stderr)
                        model.save(model_file)
                else:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)
                    if patience == args.patience:
                        print('early stop!', file=sys.stderr)
                        print('the best model is from iteration [%d]' %
                              best_model_iter,
                              file=sys.stderr)
                        sys.exit(0)
                # Stop once decay has driven the learning rate to (nearly)
                # zero; further updates would be negligible.
                if abs(optimizer.param_groups[0]['lr'] - 0.0) <= 1e-5:
                    print('stop! because lr is too small', file=sys.stderr)
                    print('the best model is from iteration [%d]' %
                          best_model_iter,
                          file=sys.stderr)
                    sys.exit(0)