예제 #1
0
파일: train.py 프로젝트: xwgeng/RNNSearch
def evaluate(batch_idx, epoch):
    """Decode the validation set with beam search and record BLEU scores.

    Puts ``model`` into eval mode (it is not switched back here; the caller
    is expected to call ``model.train()`` again), decodes every batch of
    ``valid_iter`` and keeps the top-ranked hypothesis. Two scores are
    produced: NLTK's ``corpus_bleu`` on the token sequences and the value
    returned by ``bleu_script`` for a temporary file holding one hypothesis
    per line. The scores are printed and ``(bleu2, batch_idx, epoch)`` is
    appended to ``opt.score_list``.

    Args:
        batch_idx: Index of the training batch that triggered this
            evaluation; used for reporting and stored in the score tuple.
        epoch: Current epoch; used for reporting and stored in the score
            tuple.
    """
    model.eval()
    hyp_list = []
    ref_list = []
    start_time = time.time()
    for batch in valid_iter:
        src_raw = batch[0]
        # Presumably every remaining field is one reference stream - TODO
        # confirm against the iterator that builds the batches.
        trg_raw = batch[1:]
        src, src_mask = convert_data(src_raw, src_vocab, device, True, UNK,
                                     PAD, SOS, EOS)
        with torch.no_grad():
            output = model.beamsearch(src,
                                      src_mask,
                                      opt.beam_size,
                                      normalize=True)
            # Only the hypothesis of the first beam entry is used.
            best_hyp, _ = output[0]
            best_hyp = convert_str([best_hyp], trg_vocab)
            hyp_list.append(best_hyp[0])
            # Build a real list: on Python 3 ``map`` is a one-shot iterator
            # and corpus_bleu walks each reference set more than once.
            ref_list.append([x[0] for x in trg_raw])
    elapsed = time.time() - start_time
    bleu1 = corpus_bleu(ref_list,
                        hyp_list,
                        smoothing_function=SmoothingFunction().method1)
    hyp_lines = [' '.join(tokens) for tokens in hyp_list]
    # NamedTemporaryFile creates the file atomically (tempfile.mktemp is
    # deprecated because the name can be hijacked before it is opened) and
    # the ``with`` block closes the handle even if the write fails.
    # NOTE(review): as before, the file is left on disk after scoring -
    # consider removing it once bleu_script is confirmed to be synchronous.
    with tempfile.NamedTemporaryFile('w', delete=False) as f_tmp:
        f_tmp.write('\n'.join(hyp_lines))
        p_tmp = f_tmp.name
    bleu2 = bleu_script(p_tmp)
    print('BLEU score for {}-{} is {}/{}, {}'.format(epoch, batch_idx, bleu1,
                                                     bleu2, elapsed))
    opt.score_list.append((bleu2, batch_idx, epoch))
예제 #2
0
def evaluate(batch_idx, epoch):
    """Decode the validation set and log multi-bleu plus 1- to 4-gram BLEU.

    Puts ``model`` into eval mode (it is not switched back here), decodes
    every batch of ``valid_iter`` with beam search and keeps the top-ranked
    hypothesis. The hypotheses are written one per line to a temporary file
    that is scored by ``bleu_script``; ``bleu`` is then called for n = 1..4.
    Every score is sent to ``writer``, logged as an ASCII table, and
    appended to ``opt.score_list`` as
    ``(bleu2, bleu_1, bleu_2, bleu_3, bleu_4, batch_idx, epoch)``.

    Args:
        batch_idx: Index of the training batch that triggered this
            evaluation; used in the log line and stored in the score tuple.
        epoch: Current epoch; used as the scalar step, in the log line and
            stored in the score tuple.
    """
    model.eval()
    hyp_list = []
    ref_list = []
    for batch in valid_iter:
        src_raw = batch[0]
        # Presumably every remaining field is one reference stream - TODO
        # confirm against the iterator that builds the batches.
        trg_raw = batch[1:]
        src, src_mask = convert_data(
            src_raw, src_vocab, device, True, UNK, PAD, SOS, EOS
        )
        with torch.no_grad():
            output = model.beamsearch(src, src_mask, opt.beam_size, normalize=True)
            # Only the hypothesis of the first beam entry is used.
            best_hyp, _ = output[0]
            best_hyp = convert_str([best_hyp], trg_vocab)
            hyp_list.append(best_hyp[0])
            ref_list.append([x[0] for x in trg_raw])

    hyp_list = [" ".join(x) for x in hyp_list]
    # NamedTemporaryFile creates the file atomically (tempfile.mktemp is
    # deprecated because the name can be hijacked before it is opened) and
    # the ``with`` block closes the handle even if the write fails.
    # NOTE(review): as before, the file is left on disk after scoring.
    with tempfile.NamedTemporaryFile("w", delete=False) as f_tmp:
        f_tmp.write("\n".join(hyp_list))
        p_tmp = f_tmp.name
    bleu2 = bleu_script(p_tmp)

    # NOTE(review): ``bleu`` receives the space-joined hypotheses while the
    # references are passed as collected - confirm this is what it expects.
    n_gram_scores = [
        bleu(hyp_list, ref_list, smoothing=True, n=n) for n in range(1, 5)
    ]
    for n, score in enumerate(n_gram_scores, start=1):
        writer.add_scalar("./bleu_{}_gram".format(n), score, epoch)
    writer.add_scalar("./multi-bleu", bleu2, epoch)

    bleu_result = [
        ["multi-bleu", "bleu_1-gram", "bleu_2-gram", "bleu_3-gram", "bleu_4-gram"],
        [bleu2] + n_gram_scores,
    ]
    bleu_table = AsciiTable(bleu_result)
    logger.info(
        "BLEU score for Epoch-{}-batch-{}: ".format(epoch, batch_idx)
        + "\n"
        + bleu_table.table
    )
    opt.score_list.append((bleu2,) + tuple(n_gram_scores) + (batch_idx, epoch))
예제 #3
0
    bleu = float(out)
    return bleu


# Decode every batch of the test set with beam search, then score the
# top-ranked hypotheses with NLTK's corpus_bleu and with bleu_script.
hyp_list = []
ref_list = []
start_time = time.time()
for ix, batch in enumerate(test_iter, start=1):
    src_raw = batch[0]
    # Presumably every remaining field is one reference stream - TODO confirm
    # against the iterator that builds the batches.
    trg_raw = batch[1:]
    src, src_mask = convert_data(src_raw, src_vocab, device, True, UNK, PAD,
                                 SOS, EOS)
    with torch.no_grad():
        output = model.beamsearch(src, src_mask, opt.beam_size, normalize=True)
        # Only the first beam entry is used.
        best_hyp, best_score = output[0]
        best_hyp = convert_str([best_hyp], trg_vocab)
        hyp_list.append(best_hyp[0])
        # Build a real list: on Python 3 ``map`` is a one-shot iterator and
        # corpus_bleu walks each reference set more than once.
        ref = [x[0] for x in trg_raw]
        ref_list.append(ref)
    # Progress: batches done, total batches, percentage.
    print(ix, len(test_iter), 100. * ix / len(test_iter))
elapsed = time.time() - start_time
bleu1 = corpus_bleu(ref_list,
                    hyp_list,
                    smoothing_function=SmoothingFunction().method1)
hyp_list = [' '.join(x) for x in hyp_list]
# NamedTemporaryFile creates the file atomically (tempfile.mktemp is
# deprecated because the name can be hijacked before it is opened) and the
# ``with`` block closes the handle even if the write fails.
# NOTE(review): as before, the file is left on disk after scoring.
with tempfile.NamedTemporaryFile('w', delete=False) as f_tmp:
    f_tmp.write('\n'.join(hyp_list))
    p_tmp = f_tmp.name
bleu2 = bleu_script(p_tmp)
print('BLEU score for model {} is {}/{}, {}'.format(opt.name, bleu1, bleu2,
예제 #4
0
def train(epoch):
    """Train for one pass over ``train_iter`` with periodic side tasks.

    For every batch the loss is computed (through
    ``nn.parallel.data_parallel`` when several CUDA devices are available
    and ``opt.local_rank`` is None), written to ``writer``, back-propagated
    with gradient clipping, and the optimizer is stepped. In addition:

    * every 10 batches and on the last batch a progress line is logged;
    * every ``opt.vfreq`` batches ``evaluate`` runs, the learning rate is
      optionally adjusted, and "best" / "epoch-best" checkpoints are
      refreshed when the newest score beats the previous ones;
    * every ``opt.sfreq`` batches one random sentence of the current batch
      is decoded and logged as a table;
    * every ``opt.freq`` batches (when set) a "tmp" checkpoint is rotated.

    Args:
        epoch: Current epoch; used for the scalar step, log lines and
            checkpoint names.
    """
    model.train()
    opt.epoch_best_score = -float("inf")
    opt.epoch_best_name = None
    for batch_idx, batch in enumerate(train_iter, start=1):
        batch = sort_batch(batch)
        src_raw = batch[0]
        trg_raw = batch[1]
        src, src_mask = convert_data(
            src_raw, src_vocab, device, True, UNK, PAD, SOS, EOS
        )
        f_trg, f_trg_mask = convert_data(
            trg_raw, trg_vocab, device, False, UNK, PAD, SOS, EOS
        )
        b_trg, b_trg_mask = convert_data(
            trg_raw, trg_vocab, device, True, UNK, PAD, SOS, EOS
        )
        optimizer.zero_grad()
        if opt.cuda and torch.cuda.device_count() > 1 and opt.local_rank is None:
            loss, w_loss = nn.parallel.data_parallel(
                model, (src, src_mask, f_trg, f_trg_mask, b_trg, b_trg_mask), device_ids
            )
        else:
            loss, w_loss = model(src, src_mask, f_trg, f_trg_mask, b_trg, b_trg_mask)

        # Reduce once and reuse: on the data_parallel path the loss may hold
        # one value per device, and Tensor.item() only accepts a single
        # element, so the raw loss must not be passed to .item() directly.
        mean_loss = loss.mean()
        loss_value = mean_loss.item()

        # NOTE(review): ``current_batches`` is not assigned anywhere in this
        # function; if it is not a module-level counter this was probably
        # meant to be ``batch_idx`` - confirm.
        global_batches = len(train_iter) * epoch + current_batches
        writer.add_scalar(
            "./loss", scalar_value=loss_value, global_step=global_batches,
        )
        mean_loss.backward()
        torch.nn.utils.clip_grad_norm_(param_list, opt.grad_clip)
        optimizer.step()

        # batch_idx starts at 1, so a check against 0 could never fire.
        if batch_idx % 10 == 0 or batch_idx == len(train_iter):
            logger.info(
                str(
                    "Epoch: {} batch: {}/{}({:.3%}), loss: {:.6}, lr: {}".format(
                        epoch,
                        batch_idx,
                        len(train_iter),
                        batch_idx / len(train_iter),
                        loss_value,
                        opt.cur_lr,
                    )
                )
            )

        # validation
        if batch_idx % opt.vfreq == 0:
            logger.info(str("===========validation / test START==========="))
            evaluate(batch_idx, epoch)
            # evaluate() leaves the model in eval mode.
            model.train()
            if opt.decay_lr:
                adjust_learningrate(opt.score_list)
            # The first element of each score tuple is the ranking score.
            if len(opt.score_list) == 1 or opt.score_list[-1][0] > max(
                x[0] for x in opt.score_list[:-1]
            ):
                if opt.best_name is not None:
                    os.remove(os.path.join(opt.checkpoint, opt.best_name))
                opt.best_name = save_model(model, batch_idx, epoch, "best")
            if opt.epoch_best and opt.score_list[-1][0] > opt.epoch_best_score:
                opt.epoch_best_score = opt.score_list[-1][0]
                if opt.epoch_best_name is not None:
                    os.remove(os.path.join(opt.checkpoint, opt.epoch_best_name))
                opt.epoch_best_name = save_model(model, batch_idx, epoch, "epoch-best")
            logger.info("===========validation / test DONE===========")

        # sampling
        if batch_idx % opt.sfreq == 0:
            length = len(src_raw)
            ix = np.random.randint(0, length)
            samp_src_raw = [src_raw[ix]]
            samp_trg_raw = [trg_raw[ix]]
            samp_src, samp_src_mask = convert_data(
                samp_src_raw, src_vocab, device, True, UNK, PAD, SOS, EOS
            )
            model.eval()
            with torch.no_grad():
                output = model.beamsearch(samp_src, samp_src_mask, opt.beam_size)
            # Only the first beam entry is reported.
            best_hyp, best_score = output[0]
            best_hyp = convert_str([best_hyp], trg_vocab)
            sampling_result = [
                ["Key", "Value"],
                ["Source", str(" ".join(samp_src_raw[0]))],
                ["Target", str(" ".join(samp_trg_raw[0]))],
                ["Predict", str(" ".join(best_hyp[0]))],
                ["Best Score", str(round(best_score, 5))],
            ]
            sampling_table = AsciiTable(sampling_result)
            logger.info("===========sampling START===========")
            logger.info("\n" + str(sampling_table.table))
            logger.info("===========sampling DONE===========")
            model.train()

        # saving model
        if opt.freq and batch_idx % opt.freq == 0:
            # Keep a single rolling "tmp" checkpoint.
            if opt.tmp_name is not None:
                os.remove(os.path.join(opt.checkpoint, opt.tmp_name))
            opt.tmp_name = save_model(model, batch_idx, epoch, "tmp")
예제 #5
0
파일: train.py 프로젝트: xwgeng/RNNSearch
def train(epoch):
    """Run one pass over ``train_iter``, printing progress for every batch.

    Each batch is converted to tensors, pushed through the model (via
    ``nn.parallel.data_parallel`` when several CUDA devices are available
    and ``opt.local_rank`` is None), and the mean of the first model output
    is back-propagated with gradient clipping. Depending on ``opt.vfreq``,
    ``opt.sfreq`` and ``opt.freq`` the loop also validates and refreshes the
    "best" / "epoch-best" checkpoints, prints one randomly chosen decoded
    sample, and rotates a "tmp" checkpoint.

    Args:
        epoch: Current epoch; printed and forwarded to ``evaluate`` and
            ``save_model``.
    """
    model.train()
    opt.epoch_best_score = -float('inf')
    opt.epoch_best_name = None
    for batch_idx, batch in enumerate(train_iter, start=1):
        start_time = time.time()
        batch = sort_batch(batch)
        src_raw, trg_raw = batch[0], batch[1]
        src, src_mask = convert_data(
            src_raw, src_vocab, device, True, UNK, PAD, SOS, EOS)
        f_trg, f_trg_mask = convert_data(
            trg_raw, trg_vocab, device, False, UNK, PAD, SOS, EOS)
        b_trg, b_trg_mask = convert_data(
            trg_raw, trg_vocab, device, True, UNK, PAD, SOS, EOS)
        optimizer.zero_grad()

        model_inputs = (src, src_mask, f_trg, f_trg_mask, b_trg, b_trg_mask)
        multi_gpu = (opt.cuda and torch.cuda.device_count() > 1
                     and opt.local_rank is None)
        if multi_gpu:
            outputs = nn.parallel.data_parallel(model, model_inputs,
                                                device_ids)
        else:
            outputs = model(*model_inputs)

        # The first output is the quantity that gets optimised.
        outputs[0].mean().backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(param_list, opt.grad_clip)
        optimizer.step()
        elapsed = time.time() - start_time

        # One mean value per model output, space separated.
        output_text = ' '.join(str(term.mean().item()) for term in outputs)
        percent_done = 100. * batch_idx / len(train_iter)
        print(epoch, batch_idx, len(train_iter), percent_done, output_text,
              grad_norm.item(), opt.cur_lr, elapsed)

        # validation
        if batch_idx % opt.vfreq == 0:
            evaluate(batch_idx, epoch)
            # Back to training mode after evaluation.
            model.train()
            if opt.decay_lr:
                adjust_learningrate(opt.score_list)

            # The first element of each score tuple is the ranking score.
            latest_score = opt.score_list[-1][0]
            earlier_scores = [entry[0] for entry in opt.score_list[:-1]]
            if not earlier_scores or latest_score > max(earlier_scores):
                if opt.best_name is not None:
                    os.remove(os.path.join(opt.checkpoint, opt.best_name))
                opt.best_name = save_model(model, batch_idx, epoch, 'best')

            if opt.epoch_best and latest_score > opt.epoch_best_score:
                opt.epoch_best_score = latest_score
                if opt.epoch_best_name is not None:
                    stale_path = os.path.join(opt.checkpoint,
                                              opt.epoch_best_name)
                    os.remove(stale_path)
                opt.epoch_best_name = save_model(model, batch_idx, epoch,
                                                 'epoch-best')

        # sampling
        if batch_idx % opt.sfreq == 0:
            pick = np.random.randint(0, len(src_raw))
            samp_src_raw = [src_raw[pick]]
            samp_trg_raw = [trg_raw[pick]]
            samp_src, samp_src_mask = convert_data(
                samp_src_raw, src_vocab, device, True, UNK, PAD, SOS, EOS)
            model.eval()
            with torch.no_grad():
                beams = model.beamsearch(samp_src, samp_src_mask,
                                         opt.beam_size)
            # Only the first beam entry is shown.
            best_hyp, best_score = beams[0]
            best_hyp = convert_str([best_hyp], trg_vocab)
            for tokens in (samp_src_raw[0], samp_trg_raw[0], best_hyp[0]):
                print('--', ' '.join(tokens))
            print('--', best_score)
            model.train()

        # saving model
        if opt.freq and batch_idx % opt.freq == 0:
            # Keep a single rolling "tmp" checkpoint.
            if opt.tmp_name is not None:
                os.remove(os.path.join(opt.checkpoint, opt.tmp_name))
            opt.tmp_name = save_model(model, batch_idx, epoch, 'tmp')