Example No. 1
def sample(args):
    print('loading VAE at %s' % args.load_model, file=sys.stderr)
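    # the VAE checkpoint is split across three files sharing a base name: the full model plus separate .encoder / .decoder snapshots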
    fname, ext = os.path.splitext(args.load_model)
    encoder_path = fname + '.encoder' + ext
    decoder_path = fname + '.decoder' + ext

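    # map_location=... forces all saved tensors to be loaded onto the CPU, regardless of the device they were trained on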
    vae_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    encoder_params = torch.load(encoder_path, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(decoder_path, map_location=lambda storage, loc: storage)

    transition_system = encoder_params['transition_system']
    vae_params['args'].cuda = encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    assert vae_params['args'].prior == 'lstm'
    prior = LSTMPrior.load(args.load_prior, transition_system=decoder_params['transition_system'], cuda=args.cuda)
    print('loaded prior at %s' % args.load_prior, file=sys.stderr)
    # freeze prior parameters
    for p in prior.parameters():
        p.requires_grad = False

    decoder.eval()
    prior.eval()

    if args.cuda:
        decoder.cuda()
        prior.cuda()

    err_num = 0
    total_num = 0

    # while True:
    for sample_id in xrange(10000):
        sampled_z = prior.sample()
        sampled_z = ' '.join(sampled_z)
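        # light clean-up of the detokenized sample: adjust spacing around a few keywords and replace '<unk>' before parsing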
        sampled_z = sampled_z.replace(' else :', 'else :').replace(' except ', 'except ').replace(' elif ', 'elif ').replace('<unk>', 'unk')

        print('Z: %s' % sampled_z)
        total_num += 1

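        # count the sample as well-formed only if it can be converted back into an AST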
        try:
            transition_system.surface_code_to_ast(sampled_z)
        except:
            print('Error!')
            err_num += 1
            continue

        print('Sampled NL sentences:')
        sampled_nls = decoder.sample(sampled_z)
        for i, sampled_nl in enumerate(sampled_nls):
            print('[%d] %s' % (i, ' '.join(sampled_nl)))

        print()

    print('Ratio of well-formed samples: %d/%d=%.5f' % (total_num - err_num, total_num,
                                                        (total_num - err_num) / float(total_num)), file=sys.stderr)
Example No. 2
def train_lstm_lm(args):
    grammar = ASDLGrammar.from_text(open(args.asdl_file).read())
    transition_system = TransitionSystem.get_class_by_lang('lambda_dcs')(
        grammar)

    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    vocab = pickle.load(open(args.vocab)).code

    print('train data size: %d, dev data size: %d' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)
    print('vocab size: %d' % len(vocab), file=sys.stderr)

    model = LSTMPrior(args, vocab, transition_system)
    model.train()
    if args.cuda: model.cuda()
    optimizer = torch.optim.Adam(model.parameters())

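    # dev-set perplexity: exp(total loss / total target tokens, counting the </s> boundary symbol)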
    def evaluate_ppl():
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for examples in nn_utils.batch_iter(dev_set.examples, args.batch_size):
            batch_tokens = [
                transition_system.tokenize_code(e.tgt_code) for e in examples
            ]
            batch = nn_utils.to_input_variable(batch_tokens,
                                               vocab,
                                               cuda=args.cuda,
                                               append_boundary_sym=True)
            loss = model.forward(batch).sum()
            cum_loss += loss.data[0]
            cum_tgt_words += sum(len(tokens) + 1
                                 for tokens in batch_tokens)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        model.train()
        return ppl

    print('begin training LSTM language model prior, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)),
          file=sys.stderr)

    epoch = num_trial = train_iter = patience = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
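    # training loop: keep the best checkpoint by dev perplexity; when patience runs out, decay the lr, reload the best model, and stop after max_num_trial trials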
    while True:
        epoch += 1
        epoch_begin = time.time()

        for examples in nn_utils.batch_iter(train_set.examples,
                                            batch_size=args.batch_size,
                                            shuffle=True):
            train_iter += 1
            optimizer.zero_grad()

            batch_tokens = [
                transition_system.tokenize_code(e.tgt_code) for e in examples
            ]
            batch = nn_utils.to_input_variable(batch_tokens,
                                               vocab,
                                               cuda=args.cuda,
                                               append_boundary_sym=True)
            loss = model.forward(batch)
            # print(loss.data)
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(),
                                                      args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] LM loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # model.save(model_file)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' %
              (epoch, ppl, time.time() - eval_start),
              file=sys.stderr)
        dev_acc = -ppl
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda: model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example No. 3
def train_semi(args):
    encoder_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder, map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'], transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    if args.prior == 'lstm':
        prior = LSTMPrior.load(args.load_prior, transition_system=transition_system, cuda=args.cuda)
        print('loaded prior at %s' % args.load_prior, file=sys.stderr)
        # freeze prior parameters
        for p in prior.parameters():
            p.requires_grad = False
        prior.eval()
    else:
        prior = UniformPrior()

    if args.baseline == 'mlp':
        structVAE = StructVAE(encoder, decoder, prior, args)
    elif args.baseline == 'src_lm' or args.baseline == 'src_lm_and_linear':
        src_lm = LSTMLanguageModel.load(args.load_src_lm)
        print('loaded source LM at %s' % args.load_src_lm, file=sys.stderr)
        vae_cls = StructVAE_LMBaseline if args.baseline == 'src_lm' else StructVAE_SrcLmAndLinearBaseline
        structVAE = vae_cls(encoder, decoder, prior, src_lm, args)
    else:
        raise ValueError('unknown baseline')

    structVAE.train()
    if args.cuda: structVAE.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    # labeled_data.examples = labeled_data.examples[:10]
    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)   # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)
    # dev_set.examples = dev_set.examples[:10]

    optimizer = torch.optim.Adam(ifilter(lambda p: p.requires_grad, structVAE.parameters()), lr=args.lr)

    print('*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***' %
          (len(labeled_data), len(unlabeled_data)), file=sys.stderr)
    report_encoder_loss = report_decoder_loss = report_src_sent_words_num = report_tgt_query_words_num = report_examples = 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.unsup_batch_size, shuffle=True)

        for labeled_examples in labeled_data.batch_iter(batch_size=args.batch_size, shuffle=True):
            labeled_examples = [e for e in labeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            train_iter += 1
            optimizer.zero_grad()
            report_examples += len(labeled_examples)

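            # supervised loss: negative log-likelihood of the parser (encoder) and the reconstructor (decoder) on labeled examples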
            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss
            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it
                unlabeled_examples_iter = unlabeled_data.batch_iter(batch_size=args.unsup_batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)
                unlabeled_examples = [e for e in unlabeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

            try:
                unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data = structVAE.get_unsupervised_loss(
                    unlabeled_examples)

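                # skip the whole batch if any of the supervised or unsupervised losses turned into NaN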
                nan = False
                if nn_utils.isnan(sup_loss.data):
                    print('Nan in sup_loss')
                    nan = True
                if nn_utils.isnan(unsup_encoder_loss.data):
                    print('Nan in unsup_encoder_loss!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss.data):
                    print('Nan in unsup_decoder_loss!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_baseline_loss.data):
                    print('Nan in unsup_baseline_loss!', file=sys.stderr)
                    nan = True

                if nan:
                    # torch.save((unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data), 'nan_data.bin')
                    continue

                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
                report_unsup_baseline_loss += unsup_baseline_loss.sum().data[0]
                report_unsup_examples += unsup_encoder_loss.size(0)
            except ValueError as e:
                print(e.message, file=sys.stderr)
                continue
            # except Exception as e:
            #     print('********** Error **********', file=sys.stderr)
            #     print('batch labeled examples: ', file=sys.stderr)
            #     for example in labeled_examples:
            #         print('%s %s' % (example.idx, ' '.join(example.src_sent)), file=sys.stderr)
            #     print('batch unlabeled examples: ', file=sys.stderr)
            #     for example in unlabeled_examples:
            #         print('%s %s' % (example.idx, ' '.join(example.src_sent)), file=sys.stderr)
            #     print(e.message, file=sys.stderr)
            #     traceback.print_exc(file=sys.stderr)
            #     for k, v in meta_data.iteritems():
            #         print('%s: %s' % (k, v), file=sys.stderr)
            #     print('********** Error **********', file=sys.stderr)
            #     continue

            unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(unsup_decoder_loss) + torch.mean(unsup_baseline_loss)

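            # joint objective: supervised loss plus the weighted unsupervised StructVAE loss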
            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(structVAE.parameters(), args.clip_grad)
            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f' %
                      (train_iter,
                       report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                      file=sys.stderr)

                print('[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f' %
                      (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                      file=sys.stderr)

                # print('[Iter %d] unsupervised: baseline=%.5f, raw learning signal=%.5f, learning signal=%.5f' % (train_iter,
                #                                                                        meta_data['baseline'].mean().data[0],
                #                                                                        meta_data['raw_learning_signal'].mean().data[0],
                #                                                                        meta_data['learning_signal'].mean().data[0]), file=sys.stderr)

                if isinstance(structVAE, StructVAE_LMBaseline):
                    print('[Iter %d] baseline: source LM b_lm_weight: %.3f, b: %.3f' % (train_iter,
                                                                                        structVAE.b_lm_weight.data[0],
                                                                                        structVAE.b.data[0]),
                          file=sys.stderr)

                samples = meta_data['samples']
                for k, v in meta_data.iteritems():
                    if isinstance(v, Variable): meta_data[k] = v.cpu()
                for i, sample in enumerate(samples[:15]):
                    print('\t[%s] Source: %s' % (sample.idx, ' '.join(sample.src_sent)), file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code), file=sys.stderr)
                    ref_example = [e for e in unlabeled_examples if e.idx == int(sample.idx[:sample.idx.index('-')])][0]
                    print('\t[%s] Gold Code: \n%s' % (sample.idx, ref_example.tgt_code), file=sys.stderr)
                    print('\t[%s] Log p(z|x): %f' % (sample.idx, meta_data['encoding_scores'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' % (sample.idx, meta_data['reconstruction_scores'][i].data[0]), file=sys.stderr)
                    print('\t[%s] KL term: %f' % (sample.idx, meta_data['kl_term'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Prior: %f' % (sample.idx, meta_data['prior'][i].data[0]), file=sys.stderr)
                    print('\t[%s] baseline: %f' % (sample.idx, meta_data['baseline'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Raw Learning Signal: %f' % (sample.idx, meta_data['raw_learning_signal'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Learning Signal - baseline: %f' % (sample.idx, meta_data['learning_signal'][i].data[0]), file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' % (sample.idx, unsup_encoder_loss[i].data[0]), file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' % (epoch, time.time() - epoch_begin), file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)

        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples, encoder, args, verbose=True)
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' % (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(history_dev_scores)
        history_dev_scores.append(dev_acc)

        # model_file = args.save_to + '.iter%d.bin' % train_iter
        # print('save model to [%s]' % model_file, file=sys.stderr)
        # structVAE.save(model_file)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            structVAE.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' % lr, file=sys.stderr)

            # load best model's parameters
            structVAE.load_parameters(args.save_to + '.bin')
            if args.cuda: structVAE = structVAE.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset to a new optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(ifilter(lambda p: p.requires_grad, structVAE.parameters()), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
Example No. 4
def log_semi(args):
    print('loading VAE at %s' % args.load_model, file=sys.stderr)
    fname, ext = os.path.splitext(args.load_model)
    encoder_path = fname + '.encoder' + ext
    decoder_path = fname + '.decoder' + ext

    vae_params = torch.load(args.load_model, map_location=lambda storage, loc: storage)
    encoder_params = torch.load(encoder_path, map_location=lambda storage, loc: storage)
    decoder_params = torch.load(decoder_path, map_location=lambda storage, loc: storage)

    transition_system = encoder_params['transition_system']
    vae_params['args'].cuda = encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'], transition_system)
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'], transition_system)

    if vae_params['args'].prior == 'lstm':
        prior = LSTMPrior.load(vae_params['args'].load_prior, transition_system=decoder_params['transition_system'], cuda=args.cuda)
        print('loaded prior at %s' % vae_params['args'].load_prior, file=sys.stderr)
        # freeze prior parameters
        for p in prior.parameters():
            p.requires_grad = False
        prior.eval()
    else:
        prior = UniformPrior()

    if vae_params['args'].baseline == 'mlp':
        structVAE = StructVAE(encoder, decoder, prior, vae_params['args'])
    elif vae_params['args'].baseline == 'src_lm' or vae_params['args'].baseline == 'src_lm_and_linear':
        src_lm = LSTMLanguageModel.load(vae_params['args'].load_src_lm)
        print('loaded source LM at %s' % vae_params['args'].load_src_lm, file=sys.stderr)
        Baseline = StructVAE_LMBaseline if vae_params['args'].baseline == 'src_lm' else StructVAE_SrcLmAndLinearBaseline
        structVAE = Baseline(encoder, decoder, prior, src_lm, vae_params['args'])
    else:
        raise ValueError('unknown baseline')

    structVAE.load_parameters(args.load_model)
    structVAE.train()
    if args.cuda: structVAE.cuda()

    unlabeled_data = Dataset.from_bin_file(args.unlabeled_file)  # pretend they are un-labeled!

    print('*** begin sampling ***', file=sys.stderr)
    start_time = time.time()
    train_iter = 0
    log_entries = []
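    # single pass over the unlabeled data: compute the unsupervised loss and record per-sample diagnostics for later inspection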
    for unlabeled_examples in unlabeled_data.batch_iter(batch_size=args.batch_size, shuffle=False):
        unlabeled_examples = [e for e in unlabeled_examples if len(e.tgt_actions) <= args.decode_max_time_step]

        train_iter += 1
        try:
            unsup_encoder_loss, unsup_decoder_loss, unsup_baseline_loss, meta_data = structVAE.get_unsupervised_loss(
                unlabeled_examples)

        except ValueError as e:
            print(e.message, file=sys.stderr)
            continue

        samples = meta_data['samples']
        for k, v in meta_data.iteritems():
            if isinstance(v, Variable): meta_data[k] = v.cpu()

        for i, sample in enumerate(samples):
            ref_example = [e for e in unlabeled_examples if e.idx == int(sample.idx[:sample.idx.index('-')])][0]
            log_entry = {
                'sample': sample,
                'ref_example': ref_example,
                'log_p_z_x': meta_data['encoding_scores'][i].data[0],
                'log_p_x_z': meta_data['reconstruction_scores'][i].data[0],
                'kl': meta_data['kl_term'][i].data[0],
                'prior': meta_data['prior'][i].data[0],
                'baseline': meta_data['baseline'][i].data[0],
                'learning_signal': meta_data['raw_learning_signal'][i].data[0],
                'learning_signal - baseline': meta_data['learning_signal'][i].data[0],
                'encoder_loss': unsup_encoder_loss[i].data[0],
                'decoder_loss': unsup_decoder_loss[i].data[0]
            }

            log_entries.append(log_entry)

    print('done! took %d s' % (time.time() - start_time), file=sys.stderr)
    pkl.dump(log_entries, open(args.save_to, 'wb'))
Example No. 5
def train_semi_jae(args):

    bi_direction = args.bi_direction

    encoder_params = torch.load(args.load_model,
                                map_location=lambda storage, loc: storage)
    decoder_params = torch.load(args.load_decoder,
                                map_location=lambda storage, loc: storage)

    print('loaded encoder at %s' % args.load_model, file=sys.stderr)
    print('loaded decoder at %s' % args.load_decoder, file=sys.stderr)

    transition_system = encoder_params['transition_system']
    encoder_params['args'].cuda = decoder_params['args'].cuda = args.cuda

    encoder = Parser(encoder_params['args'], encoder_params['vocab'],
                     transition_system)
    encoder.load_state_dict(encoder_params['state_dict'])
    decoder = Reconstructor(decoder_params['args'], decoder_params['vocab'],
                            transition_system)
    decoder.load_state_dict(decoder_params['state_dict'])

    zprior = LSTMPrior.load(args.load_prior,
                            transition_system=transition_system,
                            cuda=args.cuda)
    print('loaded p(z) prior at %s' % args.load_prior, file=sys.stderr)
    # freeze prior parameters
    for p in zprior.parameters():
        p.requires_grad = False
    zprior.eval()
    xprior = LSTMLanguageModel.load(args.load_src_lm)
    print('loaded p(x) prior at %s' % args.load_src_lm, file=sys.stderr)
    xprior.eval()

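    # assemble the joint autoencoder from the pretrained parser, reconstructor, code prior p(z) and source LM p(x)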
    if args.cache:
        jae = JAE_cache(encoder, decoder, zprior, xprior, args)
    else:
        jae = JAE(encoder, decoder, zprior, xprior, args)

    jae.train()
    encoder.train()
    decoder.train()
    if args.cuda: jae.cuda()

    labeled_data = Dataset.from_bin_file(args.train_file)
    # labeled_data.examples = labeled_data.examples[:10]
    unlabeled_data = Dataset.from_bin_file(
        args.unlabeled_file)  # pretend they are un-labeled!
    dev_set = Dataset.from_bin_file(args.dev_file)
    # dev_set.examples = dev_set.examples[:10]

    optimizer = torch.optim.Adam(
        [p for p in jae.parameters() if p.requires_grad], lr=args.lr)

    print(
        '*** begin semi-supervised training %d labeled examples, %d unlabeled examples ***'
        % (len(labeled_data), len(unlabeled_data)),
        file=sys.stderr)
    report_encoder_loss = report_decoder_loss = report_examples = 0.
    report_unsup_examples = report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = 0.
    patience = 0
    num_trial = 1
    epoch = train_iter = 0
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()
        unlabeled_examples_iter = unlabeled_data.batch_iter(
            batch_size=args.unsup_batch_size, shuffle=True)
        for labeled_examples in labeled_data.batch_iter(
                batch_size=args.batch_size, shuffle=True):
            labeled_examples = [
                e for e in labeled_examples
                if len(e.tgt_actions) <= args.decode_max_time_step
            ]

            train_iter += 1

            optimizer.zero_grad()

            report_examples += len(labeled_examples)

            sup_encoder_loss = -encoder.score(labeled_examples)
            sup_decoder_loss = -decoder.score(labeled_examples)

            report_encoder_loss += sup_encoder_loss.sum().data[0]
            report_decoder_loss += sup_decoder_loss.sum().data[0]

            sup_encoder_loss = torch.mean(sup_encoder_loss)
            sup_decoder_loss = torch.mean(sup_decoder_loss)

            sup_loss = sup_encoder_loss + sup_decoder_loss

            # compute unsupervised loss

            try:
                unlabeled_examples = next(unlabeled_examples_iter)
            except StopIteration:
                # if finished unlabeled data stream, restart it
                unlabeled_examples_iter = unlabeled_data.batch_iter(
                    batch_size=args.unsup_batch_size, shuffle=True)
                unlabeled_examples = next(unlabeled_examples_iter)
                unlabeled_examples = [
                    e for e in unlabeled_examples
                    if len(e.tgt_actions) <= args.decode_max_time_step
                ]

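            # unsupervised JAE loss; when bi_direction is set, also compute the loss in the reverse direction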
            unsup_encoder_loss, unsup_decoder_loss, meta_data = jae.get_unsupervised_loss(
                unlabeled_examples, args.moves)
            if bi_direction:
                unsup_encoder_loss_back, unsup_decoder_loss_back, meta_data_back = jae.get_unsupervised_loss_backward(
                    unlabeled_examples, args.moves)

            nan = False
            if nn_utils.isnan(sup_loss.data):
                print('Nan in sup_loss')
                nan = True
            if nn_utils.isnan(unsup_encoder_loss.data):
                print('Nan in unsup_encoder_loss!', file=sys.stderr)
                nan = True
            if nn_utils.isnan(unsup_decoder_loss.data):
                print('Nan in unsup_decoder_loss!', file=sys.stderr)
                nan = True
            if bi_direction:
                if nn_utils.isnan(unsup_encoder_loss_back.data):
                    print('Nan in unsup_encoder_loss_back!', file=sys.stderr)
                    nan = True
                if nn_utils.isnan(unsup_decoder_loss_back.data):
                    print('Nan in unsup_decoder_loss_back!', file=sys.stderr)
                    nan = True

            if nan:
                continue
            if bi_direction:
                report_unsup_encoder_loss += (
                    unsup_encoder_loss.sum().data[0] +
                    unsup_encoder_loss_back.sum().data[0])
                report_unsup_decoder_loss += (
                    unsup_decoder_loss.sum().data[0] +
                    unsup_decoder_loss_back.sum().data[0])
            else:
                report_unsup_encoder_loss += unsup_encoder_loss.sum().data[0]
                report_unsup_decoder_loss += unsup_decoder_loss.sum().data[0]
            report_unsup_examples += unsup_encoder_loss.size(0)

            if bi_direction:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss) + torch.mean(
                        unsup_encoder_loss_back) + torch.mean(
                            unsup_decoder_loss_back)
            else:
                unsup_loss = torch.mean(unsup_encoder_loss) + torch.mean(
                    unsup_decoder_loss)
            loss = sup_loss + args.unsup_loss_weight * unsup_loss

            loss.backward()
            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm(jae.parameters(),
                                                      args.clip_grad)
            optimizer.step()
            if train_iter % args.log_every == 0:
                print(
                    '[Iter %d] supervised: encoder loss=%.5f, decoder loss=%.5f'
                    % (train_iter, report_encoder_loss / report_examples,
                       report_decoder_loss / report_examples),
                    file=sys.stderr)

                print(
                    '[Iter %d] unsupervised: encoder loss=%.5f, decoder loss=%.5f, baseline loss=%.5f'
                    % (train_iter,
                       report_unsup_encoder_loss / report_unsup_examples,
                       report_unsup_decoder_loss / report_unsup_examples,
                       report_unsup_baseline_loss / report_unsup_examples),
                    file=sys.stderr)

                samples = meta_data['samples']
                for k, v in meta_data.items():
                    if isinstance(v, Variable): meta_data[k] = v.cpu()
                for i, sample in enumerate(samples[:1]):
                    print('\t[%s] Source: %s' %
                          (sample.idx, ' '.join(sample.src_sent)),
                          file=sys.stderr)
                    print('\t[%s] Code: \n%s' % (sample.idx, sample.tgt_code),
                          file=sys.stderr)
                    ref_example = [
                        e for e in unlabeled_examples
                        if e.idx == int(sample.idx[:sample.idx.index('-')])
                    ][0]
                    print('\t[%s] Gold Code: \n%s' %
                          (sample.idx, ref_example.tgt_code),
                          file=sys.stderr)
                    print(
                        '\t[%s] Log p(z|x): %f' %
                        (sample.idx, meta_data['encoding_scores'][i].data[0]),
                        file=sys.stderr)
                    print('\t[%s] Log p(x|z): %f' %
                          (sample.idx,
                           meta_data['reconstruction_scores'][i].data[0]),
                          file=sys.stderr)
                    print('\t[%s] Encoder Loss: %f' %
                          (sample.idx, unsup_encoder_loss[i].data[0]),
                          file=sys.stderr)
                    print('\t**************************', file=sys.stderr)

                report_encoder_loss = report_decoder_loss = report_examples = 0.
                report_unsup_encoder_loss = report_unsup_decoder_loss = report_unsup_baseline_loss = report_unsup_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)
        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)

        eval_start = time.time()
        eval_results = evaluation.evaluate(dev_set.examples,
                                           encoder,
                                           args,
                                           verbose=True)
        encoder.train()
        dev_acc = eval_results['accuracy']
        print('[Epoch %d] code generation accuracy=%.5f took %ds' %
              (epoch, dev_acc, time.time() - eval_start),
              file=sys.stderr)
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save the current best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            jae.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif epoch == args.max_epoch:
            print('reached max epoch, stop!', file=sys.stderr)
            exit(0)
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load best model's parameters
            jae.load_parameters(args.save_to + '.bin')
            if args.cuda: jae = jae.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset to a new optimizer', file=sys.stderr)
                optimizer = torch.optim.Adam(
                    [p for p in jae.parameters() if p.requires_grad], lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0