Example #1
def self_training(args):
    """Perform self-training

    First load decoding results on disjoint data
    also load pre-trained model and perform supervised
    training on both existing training data and the
    decoded results
    """

    print('load pre-trained model from [%s]' % args.load_model,
          file=sys.stderr)
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    transition_system = params['transition_system']
    saved_args = params['args']
    saved_state = params['state_dict']

    # transfer arguments
    saved_args.cuda = args.cuda
    saved_args.save_to = args.save_to
    saved_args.train_file = args.train_file
    saved_args.unlabeled_file = args.unlabeled_file
    saved_args.dev_file = args.dev_file
    saved_args.load_decode_results = args.load_decode_results
    args = saved_args

    update_args(args)

    model = Parser(saved_args, vocab, transition_system)
    model.load_state_dict(saved_state)

    if args.cuda: model = model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('load unlabeled data [%s]' % args.unlabeled_file, file=sys.stderr)
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)

    print('load decoding results of unlabeled data [%s]' %
          args.load_decode_results,
          file=sys.stderr)
    decode_results = pickle.load(open(args.load_decode_results, 'rb'))

    labeled_data = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)

    print('Num. examples in unlabeled data: %d' % len(unlabeled_data),
          file=sys.stderr)
    assert len(unlabeled_data) == len(decode_results)
    self_train_examples = []
    for example, hyps in zip(unlabeled_data, decode_results):
        if hyps:
            hyp = hyps[0]
            sampled_example = Example(idx='self_train-%s' % example.idx,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=hyp.action_infos,
                                      tgt_ast=hyp.tree)
            self_train_examples.append(sampled_example)
    print('Num. self training examples: %d, Num. labeled examples: %d' %
          (len(self_train_examples), len(labeled_data)),
          file=sys.stderr)

    train_set = Dataset(examples=labeled_data.examples + self_train_examples)

    print('begin training, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            batch_examples = [
                e for e in batch_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(
                    model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           model,
                                           args,
                                           verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(
                    model.inference_model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
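
A minimal sketch of how self_training might be invoked. Only the attributes the function reads before it overwrites args with the checkpoint's saved arguments are needed; every path below is an assumed placeholder, and the remaining hyperparameters (learning rate, batch size, patience, ...) come from the loaded checkpoint rather than from this namespace.

import argparse

# Hypothetical invocation; all paths are assumed placeholders.
args = argparse.Namespace(
    load_model='saved_models/pretrained.bin',            # pre-trained parser checkpoint
    train_file='data/train.bin',                         # labeled training set
    unlabeled_file='data/unlabeled.bin',                 # disjoint unlabeled set
    load_decode_results='decodes/unlabeled.decode.pkl',  # decoding results on the unlabeled set
    dev_file='data/dev.bin',                             # validation set
    save_to='saved_models/self_train',                   # checkpoint prefix for saving
    cuda=False,
)
self_training(args)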
Example #2
def train(args):
    grammar = ASDLGrammar.from_text(open(args.asdl_file).read())
    transition_system = TransitionSystem.get_class_by_lang(args.lang)(grammar)
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    vocab = pickle.load(open(args.vocab, 'rb'))

    model = Parser(args, vocab, transition_system)
    model.train()
    if args.cuda: model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('begin training, %d training examples, %d dev examples' % (len(train_set), len(dev_set)), file=sys.stderr)
    print('vocab: %s' % repr(vocab), file=sys.stderr)

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size, shuffle=True):
            batch_examples = [e for e in batch_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()

            loss = -model.score(batch_examples)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            if args.clip_grad > 0.:
                grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter,
                       report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, model, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start), file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin', map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(model.inference_model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
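
Both train() above and self_training() in Example #1 share the same validation bookkeeping: a new best dev score resets patience, each exhausted patience window counts as one trial, and each trial decays the learning rate and reloads the best checkpoint until max_num_trial stops training. The following is a minimal, self-contained sketch of just that schedule; checkpoint saving/reloading is omitted, and the scores and limits are illustrative assumptions, not values from the project.

def replay_schedule(dev_scores, patience_limit=5, max_num_trial=5,
                    lr=1e-3, lr_decay=0.5, max_epoch=50):
    """Replay the early-stopping / lr-decay bookkeeping on a list of dev scores."""
    history, patience, num_trial = [], 0, 0
    for epoch, score in enumerate(dev_scores, 1):
        is_better = history == [] or score > max(history)
        history.append(score)
        if is_better:
            patience = 0                  # new best: the real loop saves the model here
        elif epoch == max_epoch:
            return 'reached max epoch at %d' % epoch
        elif patience < patience_limit:
            patience += 1
        if patience == patience_limit:
            num_trial += 1
            if num_trial == max_num_trial:
                return 'early stop at epoch %d' % epoch
            lr *= lr_decay                # the real loop also reloads the best checkpoint
            patience = 0
    return 'ran out of scores, lr=%g' % lr

# Scores improve twice, then plateau: two trials are exhausted and training stops.
print(replay_schedule([0.10, 0.20, 0.20, 0.20, 0.20, 0.20, 0.20, 0.20],
                      patience_limit=2, max_num_trial=2))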
Example #3
def train_semi_jae(args):

    bi_direction = args.bi_direction

    encoder_params = torch.load(args.load_model,
                                map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder,
                                map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'],
                     transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'],
                            transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    zprior = LSTMPrior.load(args.load_prior,
                            transition_system=transition_system,
                            cuda=args.cuda)
    print('loaded p(z) prior at %s' % args.load_prior, file=sys.stderr)
    # freeze prior parameters
    for p in zprior.parameters():
        p.requires_grad = False
    zprior.eval()
    xprior = LSTMLanguageModel.load(args.load_src_lm)
    print('loaded p(x) prior at %s' % args.load_src_lm, file=sys.stderr)
    xprior.eval()

    if args.cache:
        jae = JAE_cache(encoder, decoder, zprior, xprior, args)
    else:
        jae = JAE(encoder, decoder, zprior, xprior, args)

    jae.train()
    encoder.train()
    decoder.train()
    if args.cuda: jae.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    # labeled_data.examples = labeled_data.examples[:10]
    unlabeled_data = Dataset.from_bin_file(
        args.unlabeled_file)  # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)
    # dev_set.examples = dev_set.examples[:10]

    optimizer = torch.optim.Adam(
        [p for p in jae.parameters() if p.requires_grad], lr=args.lr)

    print(
        '*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***'
        % (len(labeled_data), len(unlabeled_data)),
        file=sys.stderr)
    report_encoder_loss = report_decoder_loss = report_examples = 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(
            batch_size=args.unsup_batch_size, shuffle=True)
        for labeled_examples in labeled_data.batch_iter(
                batch_size=args.batch_size, shuffle=True):
            labeled_examples = [
                e for e in labeled_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1

            optimizer.zero_grad()

            report_examples += len(labeled_examples)

            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss

            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it
                unlabeled_examples_iter = unlabeled_data.batch_iter(
                    batch_size=args.batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)
                unlabeled_examples = [
                    e for e in unlabeled_examples
                    if len(e.tgt_actions) <= args.decode_max_time_step
                ]

            unsup_encoder_loss, unsup_decoder_loss, meta_data = jae.get_unsupervised_loss(
                unlabeled_examples, args.moves)
            if bi_direction:
                unsup_encoder_loss_back, unsup_decoder_loss_back, meta_data_back = jae.get_unsupervised_loss_backward(
                    unlabeled_examples, args.moves)

            nan = False
            if nn_utils.isnan(sup_loss.data):
                print('Nan in sup_loss')
                nan = True
            if nn_utils.isnan(unsup_encoder_loss.data):
                print('Nan in unsup_encoder_loss!', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_decoder_loss.data):
                print('Nan in unsup_decoder_loss!', file=sys.stderr)
                nan = True
            if bi_direction:
                if nn_utils.isnan(unsup_encoder_loss_back.data):
                    print('Nan in unsup_encoder_loss_back!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss_back.data):
                    print('Nan in unsup_decoder_loss_back!', file=sys.stderr)
                    nan = True

            if nan:
                continue
            if bi_direction:
                report_unsup_encoder_loss += (
                    unsup_encoder_loss.sum().data[0] +
                    unsup_encoder_loss_back.sum().data[0])
                report_unsup_decoder_loss += (
                    unsup_decoder_loss.sum().data[0] +
                    unsup_decoder_loss_back.sum().data[0])
            else:
                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
            report_unsup_examples += unsup_encoder_loss.size(0)

            if bi_direction:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss) + torch.mean(
                        unsup_encoder_loss_back) + torch.mean(
                            unsup_decoder_loss_back)
            else:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss)
            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()
            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(jae.parameters(),
                                                      args.clip_grad)
            optimizer.step()
            if train_iter % args.log_every == 0:
                print(
                    '[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f'
                    % (train_iter, report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                    file=sys.stderr)

                print(
                    '[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f'
                    % (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                    file=sys.stderr)

                samples = meta_data['samples']
                for v in meta_data.values():
                    if isinstance(v, Variable): v.cpu()
                for i, sample in enumerate(samples[:1]):
                    print('\t[%s] Source: %s' %
                          (sample.idx, ' '.join(sample.src_sent)),
                          file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code),
                          file=sys.stderr)
                    ref_example = [
                        e for e in unlabeled_examples
                        if e.idx == int(sample.idx[:sample.idx.index('-')])
                    ][0]
                    print('\t[%s] Gold Code: \n%s' %
                          (sample.idx, ref_example.tgt_code),
                          file=sys.stderr)
                    print(
                        '\t[%s] Log p(z|x): %f' %
                        (sample.idx, meta_data['encoding_scores'][i].data[0]),
                        file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' %
                          (sample.idx,
                           meta_data['reconstruction_scores'][i].data[0]),
                          file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' %
                          (sample.idx, unsup_encoder_loss[i].data[0]),
                          file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)

        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           encoder,
                                           args,
                                           verbose=True)
        encoder.train()
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            jae.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load best model's parameters
            jae.load_parameters(args.save_to + '.bin')
            if args.cuda: jae = jae.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset to a new infer_optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(
                    [p for p in jae.parameters() if p.requires_grad], lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
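
The objective assembled in the inner loop above reduces to loss = sup_loss + args.unsup_loss_weight * unsup_loss, where the unsupervised part gains the backward-direction encoder/decoder terms when bi_direction is set. Below is a minimal sketch of just that combination on dummy per-example loss tensors; combine_losses, the 0.1 weight, and the tensor sizes are illustrative assumptions, not names or values from the project, and the baseline loss term is omitted.

import torch

def combine_losses(sup_enc, sup_dec, unsup_enc, unsup_dec,
                   unsup_loss_weight=0.1, backward_terms=None):
    """Mirror the loss assembly in train_semi_jae (baseline terms omitted)."""
    sup_loss = sup_enc.mean() + sup_dec.mean()
    unsup_loss = unsup_enc.mean() + unsup_dec.mean()
    if backward_terms is not None:  # corresponds to bi_direction=True above
        unsup_enc_back, unsup_dec_back = backward_terms
        unsup_loss = unsup_loss + unsup_enc_back.mean() + unsup_dec_back.mean()
    return sup_loss + unsup_loss_weight * unsup_loss

# Dummy per-example negative log-likelihoods, sized like small labeled/unlabeled batches.
loss = combine_losses(torch.rand(16), torch.rand(16), torch.rand(8), torch.rand(8))
print('combined loss: %.4f' % loss.item())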