Example #1
def init_training(args):
    # Back-compat shim: force latin-1 decoding so torch.load can read
    # vocabularies that were pickled under Python 2.
    from functools import partial
    import pickle
    pickle.load = partial(pickle.load, encoding="latin1")
    pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    # model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)
    vocab = torch.load(args.vocab,
                       map_location=lambda storage, loc: storage,
                       pickle_module=pickle)

    model = NMT(args, vocab)
    model.train()

    if args.uniform_init:
        print('uniformly initialize parameters [-%f, +%f]' %
              (args.uniform_init, args.uniform_init),
              file=sys.stderr)
        for p in model.parameters():
            p.data.uniform_(-args.uniform_init, args.uniform_init)

    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt['<pad>']] = 0
    nll_loss = nn.NLLLoss(weight=vocab_mask, reduction='sum')
    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask,
                                             reduction='sum')

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()
        nll_loss = nll_loss.cuda()
        cross_entropy_loss = cross_entropy_loss.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    return vocab, model, optimizer, nll_loss, cross_entropy_loss
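
A minimal invocation sketch (not from the source): it assumes an argparse-style namespace carrying the fields the function reads, and NMT will read further model hyperparameters from the same namespace; all names and values below are illustrative.

import argparse
import torch

args = argparse.Namespace(vocab='data/vocab.bin',  # hypothetical path
                          uniform_init=0.1,
                          lr=1e-3,
                          cuda=torch.cuda.is_available())
vocab, model, optimizer, nll_loss, ce_loss = init_training(args)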
Example #2
def test(args):
    test_data_src = read_corpus(args.test_src, source='src')
    test_data_tgt = read_corpus(args.test_tgt, source='tgt')
    test_data = list(zip(test_data_src, test_data_tgt))

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model,
                            map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']

        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    hypotheses = decode(model, test_data)
    top_hypotheses = [hyps[0] for hyps in hypotheses]

    bleu_score = get_bleu([tgt for src, tgt in test_data], top_hypotheses)
    word_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses,
                       'word_acc')
    sent_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses,
                       'sent_acc')
    print('Corpus Level BLEU: %f, word level acc: %f, sentence level acc: %f' %
          (bleu_score, word_acc, sent_acc),
          file=sys.stderr)

    if args.save_to_file:
        print('save decoding results to %s' % args.save_to_file)
        with open(args.save_to_file, 'w') as f:
            for hyps in hypotheses:
                f.write(' '.join(hyps[0][1:-1]) + '\n')  # strip <s> and </s>

        if args.save_nbest:
            nbest_file = args.save_to_file + '.nbest'
            print('save nbest decoding results to %s' % nbest_file)
            with open(nbest_file, 'w') as f:
                for src_sent, tgt_sent, hyps in zip(test_data_src,
                                                    test_data_tgt, hypotheses):
                    print('Source: %s' % ' '.join(src_sent), file=f)
                    print('Target: %s' % ' '.join(tgt_sent), file=f)
                    print('Hypotheses:', file=f)
                    for i, hyp in enumerate(hyps, 1):
                        print('[%d] %s' % (i, ' '.join(hyp)), file=f)
                    print('*' * 30, file=f)
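
The loader above expects a checkpoint dict with 'vocab', 'args', and 'state_dict' keys. A sketch of how such a checkpoint could be produced on the training side (assuming `model`, `args`, and `vocab` from the training code; the filename is illustrative):

params = {
    'vocab': vocab,                    # vocabulary object
    'args': args,                      # hyperparameters used to build the model
    'state_dict': model.state_dict(),  # learned weights
}
torch.save(params, 'model.bin')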
Example #3
def sample(args):
    train_data_src = read_corpus(args.train_src, source='src')
    train_data_tgt = read_corpus(args.train_tgt, source='tgt')
    train_data = zip(train_data_src, train_data_tgt)  # one-shot iterator, consumed by data_iter

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model,
                            map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        opt = params['args']
        state_dict = params['state_dict']

        model = NMT(opt, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    print('begin sampling')

    check_every = 10
    train_iter = cum_samples = 0
    train_time = time.time()
    for src_sents, tgt_sents in data_iter(train_data,
                                          batch_size=args.batch_size):
        train_iter += 1
        samples = model.sample(src_sents,
                               sample_size=args.sample_size,
                               to_word=True)
        cum_samples += sum(len(sample) for sample in samples)

        if train_iter % check_every == 0:
            elapsed = time.time() - train_time
            print('sampling speed: %d/s' % (cum_samples / elapsed))
            cum_samples = 0
            train_time = time.time()

        for i, tgt_sent in enumerate(tgt_sents):
            print('*' * 80)
            print('target:' + ' '.join(tgt_sent))
            tgt_samples = samples[i]
            print('samples:')
            for sid, sample in enumerate(tgt_samples, 1):
                print('[%d] %s' % (sid, ' '.join(sample[1:-1])))
            print('*' * 80)
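
`data_iter` is not shown in the source; below is a minimal sketch compatible with the loop above, shuffling the pairs and yielding (src_sents, tgt_sents) batches. The real helper may sort by length or differ in other ways.

import random

def data_iter(data, batch_size):
    data = list(data)  # materialize the zip iterator so it can be shuffled
    random.shuffle(data)
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        yield [src for src, tgt in batch], [tgt for src, tgt in batch]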
Example #4
def interactive(args):
    assert args.load_model, 'You have to specify a pre-trained model'
    print('load model from [%s]' % args.load_model)
    params = torch.load(args.load_model,
                        map_location=lambda storage, loc: storage)
    vocab = params['vocab']
    saved_args = params['args']
    state_dict = params['state_dict']

    model = NMT(saved_args, vocab)
    model.load_state_dict(state_dict)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    while True:
        src_sent = input('Source Sentence:')
        src_sent = src_sent.strip().split(' ')
        hyps = model.translate(src_sent)
        for i, hyp in enumerate(hyps, 1):
            print('Hypothesis #%d: %s' % (i, ' '.join(hyp)))
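
The loop above runs until the process is interrupted. A hedged variant (not in the source) that ends the session on an empty input line:

while True:
    src_sent = input('Source Sentence: ').strip()
    if not src_sent:
        break  # empty input ends the session
    for i, hyp in enumerate(model.translate(src_sent.split(' ')), 1):
        print('Hypothesis #%d: %s' % (i, ' '.join(hyp)))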
Example #5
def train():
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.batch_size,
                              shuffle=True,
                              collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data,
                            batch_size=config.dev_batch_size,
                            shuffle=True,
                            collate_fn=utils.get_batch)
    parser = OptionParser()
    parser.add_option("--embed_size",
                      dest="embed_size",
                      default=config.embed_size)
    parser.add_option("--hidden_size",
                      dest="hidden_size",
                      default=config.hidden_size)
    parser.add_option("--window_size_d",
                      dest="window_size_d",
                      default=config.window_size_d)
    parser.add_option("--encoder_layer",
                      dest="encoder_layer",
                      default=config.encoder_layer)
    parser.add_option("--decoder_layers",
                      dest="decoder_layers",
                      default=config.decoder_layers)
    parser.add_option("--dropout_rate",
                      dest="dropout_rate",
                      default=config.dropout_rate)
    (options, args) = parser.parse_args()
    device = torch.device("cuda:0" if config.cuda else "cpu")
    #model_path = "/home/wangshuhe/shuhelearn/ShuHeLearning/NMT_attention/result/01.31_drop0.3_54_21.46508598886769_checkpoint.pth"
    #print(f"load model from {model_path}", file=sys.stderr)
    #model = NMT.load(model_path)
    model = NMT(text, options, device)
    #model = model.cuda()
    #model_path = "/home/wangshuhe/shuhelearn/ShuHeLearning/NMT_attention/result/140_164.29781984744628_checkpoint.pth"
    #print(f"load model from {model_path}", file=sys.stderr)
    #model = NMT.load(model_path)
    #model = torch.nn.DataParallel(model)
    model = model.to(device)
    model = model.cuda()
    model.train()
    optimizer = Optim(torch.optim.Adam(model.parameters()))
    #optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), config.hidden_size, config.warm_up_step)
    #print(optimizer.lr)
    epoch = 0
    valid_num = 1
    hist_valid_ppl = []

    print("begin training!")
    while True:
        epoch += 1
        max_iter = int(math.ceil(len(train_data) / config.batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for src_sents, tar_sents, tar_words_num_to_predict in train_loader:
                optimizer.zero_grad()
                batch_size = len(src_sents)

                now_loss = -model(src_sents, tar_sents)
                now_loss = now_loss.sum()
                loss = now_loss / batch_size
                loss.backward()

                _ = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   config.clip_grad)
                #optimizer.updata_lr()
                optimizer.step_and_updata_lr()

                pbar.set_postfix({
                    "epoch": epoch,
                    "avg_loss": loss.item(),
                    "ppl": math.exp(now_loss.item() / tar_words_num_to_predict),
                    "lr": optimizer.lr,
                })
                #pbar.set_postfix({"epoch": epoch, "avg_loss": loss.item(), "ppl": math.exp(now_loss.item()/tar_words_num_to_predict)})
                pbar.update(1)
        #print(optimizer.lr)
        if epoch % config.valid_iter == 0:
            #if (epoch >= config.valid_iter//2):
            if valid_num % 5 == 0:
                valid_num = 0
                optimizer.updata_lr()
            valid_num += 1
            print("now begin validation ...", file=sys.stderr)
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader)
            print("validation ppl %.2f" % eval_ppl, file=sys.stderr)
            is_best = len(hist_valid_ppl) == 0 or eval_ppl < min(hist_valid_ppl)
            if is_best:
                print("current model is the best, saving to [%s]" %
                      config.model_save_path,
                      file=sys.stderr)
                hist_valid_ppl.append(eval_ppl)
                model.save(
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eval_ppl}_checkpoint.pth"
                    ))
                torch.save(
                    optimizer.optimizer.state_dict(),
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eval_ppl}_optimizer.optim"
                    ))
        if epoch == config.max_epoch:
            print("reached the maximum number of epochs!", file=sys.stderr)
            return
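
The custom `Optim` wrapper is not shown in the source; below is a sketch consistent with the calls made above (`zero_grad`, `step_and_updata_lr`, `updata_lr`, and the `lr` attribute). The misspelled method names are kept as used, and the decay factor is an assumption.

class Optim:
    """Thin wrapper around a torch optimizer with manual lr decay."""

    def __init__(self, optimizer, decay=0.5):
        self.optimizer = optimizer
        self.decay = decay

    @property
    def lr(self):
        return self.optimizer.param_groups[0]['lr']

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step_and_updata_lr(self):
        self.optimizer.step()

    def updata_lr(self):
        # decay the learning rate (called every fifth validation above)
        for group in self.optimizer.param_groups:
            group['lr'] *= self.decay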
Example #6
def compute_lm_prob(args):
    """
    given source-target sentence pairs, compute ppl and log-likelihood
    """
    test_data_src = read_corpus(args.test_src, source='src')
    test_data_tgt = read_corpus(args.test_tgt, source='tgt')
    test_data = zip(test_data_src, test_data_tgt)

    if args.load_model:
        print('load model from [%s]' % args.load_model)
        params = torch.load(args.load_model,
                            map_location=lambda storage, loc: storage)
        vocab = params['vocab']
        saved_args = params['args']
        state_dict = params['state_dict']

        model = NMT(saved_args, vocab)
        model.load_state_dict(state_dict)
    else:
        vocab = torch.load(args.vocab)
        model = NMT(args, vocab)

    model.eval()

    if args.cuda:
        # model = nn.DataParallel(model).cuda()
        model = model.cuda()

    f = open(args.save_to_file, 'w')
    for src_sent, tgt_sent in test_data:
        src_sents = [src_sent]
        tgt_sents = [tgt_sent]

        batch_size = len(src_sents)
        src_sents_len = [len(s) for s in src_sents]
        # count predicted words, omitting the leading `<s>`
        pred_tgt_word_nums = [len(s[1:]) for s in tgt_sents]

        # (sent_len, batch_size)
        src_sents_var = to_input_variable(src_sents,
                                          model.vocab.src,
                                          cuda=args.cuda,
                                          is_test=True)
        tgt_sents_var = to_input_variable(tgt_sents,
                                          model.vocab.tgt,
                                          cuda=args.cuda,
                                          is_test=True)

        # (tgt_sent_len, batch_size, tgt_vocab_size)
        scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1])
        # (tgt_sent_len * batch_size, tgt_vocab_size)
        log_scores = F.log_softmax(scores.view(-1, scores.size(2)), dim=-1)
        # remove leading <s> in tgt sent, which is not used as the target
        # (batch_size * tgt_sent_len)
        flattened_tgt_sents = tgt_sents_var[1:].view(-1)
        # (batch_size * tgt_sent_len)
        tgt_log_scores = torch.gather(
            log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1)
        # 0-index is the <pad> symbol
        tgt_log_scores = tgt_log_scores * (
            1. - torch.eq(flattened_tgt_sents, 0).float())
        # (tgt_sent_len, batch_size)
        tgt_log_scores = tgt_log_scores.view(-1, batch_size)  # .permute(1, 0)
        # (batch_size)
        tgt_sent_scores = tgt_log_scores.sum(dim=0)  # keep 1-D so per-sentence indexing works
        tgt_sent_word_scores = [
            tgt_sent_scores[i].item() / pred_tgt_word_nums[i]
            for i in range(batch_size)
        ]

        for src_sent, tgt_sent, score in zip(src_sents, tgt_sents,
                                             tgt_sent_word_scores):
            f.write('%s ||| %s ||| %f\n' %
                    (' '.join(src_sent), ' '.join(tgt_sent), score))

    f.close()
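
The docstring mentions perplexity; given the per-word log-likelihood `score` written out above, per-sentence perplexity could be recovered as follows (a hedged follow-up, not in the source):

import math

ppl = math.exp(-score)  # score is the average log-likelihood per predicted target word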