Example #1
import sys
from typing import Dict

import torch

# read_corpus, NMT, beam_search and compute_corpus_level_bleu_score are
# defined elsewhere in the surrounding module.


def decode(args: Dict[str, str]):

    test_data_src_code, failed_ids_src_code = read_corpus(
        args['TEST_SOURCE_CODE_FILE'], source='src_code')
    test_data_src_nl, failed_ids_src_nl = read_corpus(
        args['TEST_SOURCE_NL_FILE'], source='src_nl')

    test_data_tgt, failed_ids_tgt = read_corpus(args['TEST_TARGET_FILE'],
                                                source='tgt')

    total_failed_ids = set(failed_ids_src_code).union(failed_ids_tgt).union(
        failed_ids_src_nl)

    test_data_src_code = [
        test_data_src_code[i] for i in range(len(test_data_src_code))
        if i not in total_failed_ids
    ]
    test_data_src_nl = [
        test_data_src_nl[i] for i in range(len(test_data_src_nl))
        if i not in total_failed_ids
    ]
    test_data_tgt = [
        test_data_tgt[i] for i in range(len(test_data_tgt))
        if i not in total_failed_ids
    ]

    print(f"load model from {args['MODEL_PATH']}", file=sys.stderr)

    model = NMT.load(args['MODEL_PATH'])

    if args['--cuda']:
        model = model.to(torch.device("cuda:1"))

    hypotheses = beam_search(model,
                             test_data_src_code,
                             test_data_src_nl,
                             beam_size=int(args['--beam-size']),
                             max_decoding_time_step=int(
                                 args['--max-decoding-time-step']))

    if args['TEST_TARGET_FILE']:
        top_hypotheses = [hyps[0] for hyps in hypotheses]
        bleu_score = compute_corpus_level_bleu_score(test_data_tgt,
                                                     top_hypotheses)
        print(f'Corpus BLEU: {bleu_score}', file=sys.stderr)

    with open(args['OUTPUT_FILE'], 'w') as f:
        for src_code_sent, src_nl_sent, hyps in zip(test_data_src_code,
                                                    test_data_src_nl,
                                                    hypotheses):
            top_hyp = hyps[0]
            hyp_sent = ' '.join(top_hyp.value)
            f.write(hyp_sent + '\n')
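
The same filtering idiom (drop every line whose index failed to parse in any of the three parallel files) recurs in each example on this page. As an aside, a small helper, not part of the original module, would express it once:

def drop_failed(sents, failed_ids):
    # hypothetical helper, not in the original code: keep only the
    # sentences whose line index parsed cleanly in every file
    return [s for i, s in enumerate(sents) if i not in failed_ids]

Each of the three comprehensions above would then collapse to a single call such as test_data_src_code = drop_failed(test_data_src_code, total_failed_ids).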
Example #2
class Vocab(object):
    def __init__(self, src_code_sents, src_nl_sents, tgt_sents, vocab_size,
                 freq_cutoff):
        # one VocabEntry per input stream: source code, source NL, target
        print('initialize source code vocabulary ..')
        self.src_code = VocabEntry.from_corpus(src_code_sents, vocab_size,
                                               freq_cutoff)

        print('initialize source nl vocabulary ..')
        self.src_nl = VocabEntry.from_corpus(src_nl_sents, vocab_size,
                                             freq_cutoff)

        print('initialize target vocabulary ..')
        self.tgt = VocabEntry.from_corpus(tgt_sents, vocab_size, freq_cutoff)

    def __repr__(self):
        return 'Vocab(source code %d words, source nl %d words, target %d words)' % (
            len(self.src_code), len(self.src_nl), len(self.tgt))


if __name__ == '__main__':
    args = docopt(__doc__)

    print('read in source code sentences: %s' % args['--train-src-code'])
    print('read in source nl sentences: %s' % args['--train-src-nl'])
    print('read in target sentences: %s' % args['--train-tgt'])

    src_sents_code, src_f_ids_code = read_corpus(args['--train-src-code'],
                                                 source='src_code')
    tgt_sents, tgt_f_ids = read_corpus(args['--train-tgt'], source='tgt')
    src_sents_nl, src_f_ids_nl = read_corpus(args['--train-src-nl'],
                                             source='src_nl')

    total_failed_ids = set(src_f_ids_code).union(set(tgt_f_ids)).union(
        set(src_f_ids_nl))

    src_sents_code = [
        src_sents_code[i] for i in range(len(src_sents_code))
        if i not in total_failed_ids
    ]
    src_sents_nl = [
        src_sents_nl[i] for i in range(len(src_sents_nl))
        if i not in total_failed_ids
    ]
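
Both the Vocab class and the __main__ block delegate to VocabEntry.from_corpus, which is not shown on this page. A minimal sketch of the interface those calls assume (a hypothetical reconstruction; the real VocabEntry in this repository may differ):

from collections import Counter
from itertools import chain


class VocabEntry(object):
    # hypothetical reconstruction, not the original class
    def __init__(self):
        # reserved tokens first; real implementations also keep an id2word map
        self.word2id = {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3}

    def __len__(self):
        return len(self.word2id)

    def __getitem__(self, word):
        return self.word2id.get(word, self.word2id['<unk>'])

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff):
        # keep the `size` most frequent words occurring at least
        # `freq_cutoff` times
        entry = VocabEntry()
        word_freq = Counter(chain(*corpus))
        frequent = [w for w, c in word_freq.items() if c >= freq_cutoff]
        for word in sorted(frequent, key=lambda w: word_freq[w],
                           reverse=True)[:size]:
            entry.word2id.setdefault(word, len(entry.word2id))
        return entry

The __getitem__ and __len__ methods cover the two uses visible in Example #3: vocab.tgt['<pad>'] and len(vocab.tgt).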
Example #3
import math
import pickle
import sys
import time
from typing import Dict

import numpy as np
import torch

# read_corpus, batch_iter, evaluate_ppl and the NMT model class are
# defined elsewhere in the surrounding module.


def train(args: Dict):

    train_data_src_code, failed_train_src_code_ids = read_corpus(
        args['--train-src-code'], source='src_code')
    train_data_src_nl, failed_train_src_nl_ids = read_corpus(
        args['--train-src-nl'], source='src_nl')
    train_data_tgt, failed_train_tgt_ids = read_corpus(args['--train-tgt'],
                                                       source='tgt')

    dev_data_src_code, failed_dev_src_code_ids = read_corpus(
        args['--dev-src-code'], source='src_code')
    dev_data_src_nl, failed_dev_src_nl_ids = read_corpus(args['--dev-src-nl'],
                                                         source='src_nl')
    dev_data_tgt, failed_dev_tgt_ids = read_corpus(args['--dev-tgt'],
                                                   source='tgt')

    total_failed_ids = set(failed_train_src_nl_ids).union(
        failed_train_tgt_ids).union(failed_train_src_code_ids)

    train_data_src_code = [
        train_data_src_code[i] for i in range(len(train_data_src_code))
        if i not in total_failed_ids
    ]
    train_data_src_nl = [
        train_data_src_nl[i] for i in range(len(train_data_src_nl))
        if i not in total_failed_ids
    ]
    train_data_tgt = [
        train_data_tgt[i] for i in range(len(train_data_tgt))
        if i not in total_failed_ids
    ]

    total_failed_ids = set(failed_dev_src_nl_ids).union(
        failed_dev_tgt_ids).union(failed_dev_src_code_ids)

    dev_data_src_code = [
        dev_data_src_code[i] for i in range(len(dev_data_src_code))
        if i not in total_failed_ids
    ]
    dev_data_src_nl = [
        dev_data_src_nl[i] for i in range(len(dev_data_src_nl))
        if i not in total_failed_ids
    ]
    dev_data_tgt = [
        dev_data_tgt[i] for i in range(len(dev_data_tgt))
        if i not in total_failed_ids
    ]

    train_data = list(
        zip(train_data_src_code, train_data_src_nl, train_data_tgt))
    dev_data = list(zip(dev_data_src_code, dev_data_src_nl, dev_data_tgt))

    train_batch_size = int(args['--batch-size'])
    clip_grad = float(args['--clip-grad'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = args['--save-to']

    # load the pickled vocabulary (presumably built by the vocab script in
    # Example #2)
    with open(args['--vocab'], 'rb') as f:
        vocab = pickle.load(f)

    model = NMT(embed_size=int(args['--embed-size']),
                hidden_size=int(args['--hidden-size']),
                dropout_rate=float(args['--dropout']),
                input_feed=True,
                label_smoothing=float(args['--label-smoothing']),
                vocab=vocab)
    model.train()

    uniform_init = float(args['--uniform-init'])
    if np.abs(uniform_init) > 0.:
        print('uniformly initialize parameters [-%f, +%f]' %
              (uniform_init, uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-uniform_init, uniform_init)

    # mask out <pad> in the target vocabulary so padding positions carry no
    # weight in the loss
    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0

    device = torch.device("cuda:1" if args['--cuda'] else "cpu")
    print('use device: %s' % device, file=sys.stderr)

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args['--lr']))

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = 0
    cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin Maximum Likelihood training')

    while True:
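        # each iteration of this outer loop is one full epoch over train_data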
        epoch += 1

        for src_code_sents, src_nl_sents, tgt_sents in batch_iter(
                train_data, batch_size=train_batch_size, shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(src_code_sents)
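            # the model is assumed to return per-example log-likelihoods,
            # so negate them to obtain losses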
            example_losses = -model(src_code_sents, src_nl_sents, tgt_sents)
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_words_num_to_predict = sum(
                len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
            report_tgt_words += tgt_words_num_to_predict
            cum_tgt_words += tgt_words_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f, '
                      'cum. examples %d, speed %.2f words/sec, '
                      'time elapsed %.2f sec' %
                      (epoch, train_iter,
                       report_loss / report_examples,
                       math.exp(report_loss / report_tgt_words),
                       cum_examples,
                       report_tgt_words / (time.time() - train_time),
                       time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_tgt_words = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d'
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / cum_tgt_words), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = cum_tgt_words = 0.
                valid_num += 1

                print('begin validation ...', file=sys.stderr)

                # compute dev. ppl and bleu
                dev_ppl = evaluate_ppl(
                    model, dev_data,
                    batch_size=16)  # dev batch size can be a bit larger
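                # lower perplexity is better, so its negation serves as the
                # validation score to maximize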
                valid_metric = -dev_ppl

                print('validation: iter %d, dev. ppl %f' %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = (len(hist_valid_scores) == 0
                             or valid_metric > max(hist_valid_scores))
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print('save the current best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)
                    model.save(model_save_path)

                    # also save the optimizer's state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * float(
                            args['--lr-decay'])
                        print(
                            'load previously best model and decay learning rate to %f'
                            % lr,
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizer',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0

            if epoch == int(args['--max-epoch']):
                print('reached maximum number of epochs!', file=sys.stderr)
                exit(0)
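
The training loop above consumes batches from a batch_iter helper that is not shown on this page. A minimal sketch of the behaviour the loop assumes (hypothetical; the real helper may differ, for instance in how it orders examples):

import math
import random


def batch_iter(data, batch_size, shuffle=False):
    # hypothetical sketch: yield (src_code_sents, src_nl_sents, tgt_sents)
    # triples of lists, one triple per batch, optionally in shuffled order
    index_array = list(range(len(data)))
    if shuffle:
        random.shuffle(index_array)
    for i in range(math.ceil(len(data) / batch_size)):
        indices = index_array[i * batch_size:(i + 1) * batch_size]
        examples = sorted((data[idx] for idx in indices),
                          key=lambda e: len(e[0]), reverse=True)
        yield ([e[0] for e in examples],   # source code sentences
               [e[1] for e in examples],   # source NL sentences
               [e[2] for e in examples])   # target sentences

Sorting each batch by source length in descending order is a common convention when the encoder packs padded sequences; whether the original helper does this is an assumption.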
Example #4
import utils_multi

data, fail1 = utils_multi.read_corpus(
    '/home/anushap/Code-Generation-lm/nmt_model/data/2code/code_test.txt',
    'tgt')
print(fail1)

data1, fail2 = utils_multi.read_corpus(
    '/home/anushap/Code-Generation-lm/nmt_model/data/2code/src_code_test.txt',
    'src_code')
print(fail2)

data2, fail3 = utils_multi.read_corpus(
    '/home/anushap/Code-Generation-lm/nmt_model/data/2code/nl_test.txt',
    'src_nl')
print(fail3)

# collect every line index that failed in any of the three files
fail = set(fail1) | set(fail2) | set(fail3)

with open(
        "/home/anushap/Code-Generation-lm/nmt_model/data/2code/code_test_bleu.txt",
        "w") as fout:
    for i, sent in enumerate(data):
        if i not in fail:
            fout.write(' '.join(sent) + '\n')
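
Every example on this page depends on read_corpus from utils_multi, which returns both the tokenized sentences and the indices of lines it could not process. A minimal sketch of that contract (a hypothetical reconstruction; the real utils_multi.read_corpus may tokenize differently):

def read_corpus(file_path, source):
    # hypothetical sketch. Returns (sentences, failed_ids): sentences is a
    # list of token lists; failed_ids collects the indices of lines that
    # could not be processed (in the real helper the failures presumably
    # come from tokenizing source code). Target sentences get <s>/</s>
    # boundary tokens, which matches the training loop's
    # "omitting leading `<s>`" word count.
    data, failed_ids = [], []
    with open(file_path, encoding='utf-8') as f:
        for i, line in enumerate(f):
            try:
                sent = line.strip().split(' ')
                if source == 'tgt':
                    sent = ['<s>'] + sent + ['</s>']
                data.append(sent)
            except Exception:
                failed_ids.append(i)
    return data, failed_ids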