Esempio n. 1
0
def main(argv):
    """Sample template realizations from a trained model and dump them as JSON.

    Seeds torch and numpy from ``FLAGS.seed`` and builds the dataset and model.
    When the model is an ``nn.Module``, its weights are restored from
    ``FLAGS.model_dir/FLAGS.model``. Templates are then drawn from
    ``pick_examples`` until ``FLAGS.n_sample`` distinct ``(inp, out)``
    realizations are collected or the templates run out. Each accepted pair
    is logged via ``hlog``, and the collection is written to ``FLAGS.write``.

    Args:
        argv: Command-line arguments; not used (configuration is read from
            ``FLAGS``).
    """
    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    hlog.flags()

    dataset = get_dataset()
    model = pick_model(dataset)

    model.prepare(dataset)
    if isinstance(model, nn.Module):
        # Only torch modules have a state dict to restore.
        path = os.path.join(FLAGS.model_dir, FLAGS.model)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint)

    realized = set()
    examples = pick_examples(dataset)
    # May overshoot FLAGS.n_sample: all pairs accepted from one batch are
    # added before the loop condition is re-checked.
    while len(realized) < FLAGS.n_sample:
        try:
            templ, names = next(examples)
        except StopIteration:
            break
        # A batch of 10 copies of the same template, so one call to
        # ``model.sample`` yields several candidates for it.
        datum = make_batch([(templ, templ) for _ in range(10)],
                           dataset.vocab,
                           staged=True)
        (inps, outs), scores = model.sample(datum.inp_data, datum.out_data)

        keep = []
        # Pairs accepted from this batch. ``realized`` is only updated after
        # the filtering loop, so without this set, identical candidates from
        # the same batch would all be kept and logged as separate samples.
        kept_pairs = set()
        for inp, out, score in zip(inps, outs, scores):
            inp_realized, inp_used = dataset.realize(inp, names)
            out_realized, out_used = dataset.realize(out, names)
            # Require the output (and, unless output_only, the input) to use
            # at least one of the template's names.
            if ((not FLAGS.output_only)
                    and len(inp_used) == 0) or len(out_used) == 0:
                continue
            # The names used across input and output must number exactly
            # len(names) (presumably: every name is used somewhere).
            if len(inp_used | out_used) != len(names):
                continue
            # Keep only realizations the dataset reports as novel; the input
            # side is exempt when FLAGS.output_only is set.
            if not ((FLAGS.output_only or dataset.novel(inp=inp_realized))
                    and dataset.novel(out=out_realized)):
                continue
            pair = (inp_realized, out_realized)
            if pair in realized or pair in kept_pairs:
                continue
            kept_pairs.add(pair)
            keep.append((pair, score))
        for (inp_realized, out_realized), score in keep:
            with hlog.task(str(len(realized))):
                hlog.value("inp", " ".join(dataset.vocab.decode(templ[0])))
                hlog.value("out", " ".join(dataset.vocab.decode(templ[1])))
                hlog.value("var", names)
                hlog.value("score", score)
                with hlog.task("realized"):
                    hlog.value("inp", " ".join(inp_realized))
                    hlog.value("out", " ".join(out_realized))
            realized.add((inp_realized, out_realized))

    # Sort so the output file does not depend on set iteration order, which
    # for strings varies between processes because of hash randomization.
    data = [{"inp": inp, "out": out} for inp, out in sorted(realized)]
    with open(FLAGS.write, "w") as fh:
        json.dump(data, fh, indent=2)
Esempio n. 2
0
 def callback(i_epoch):
     """Per-epoch hook: evaluate on validation (and optionally test), then checkpoint.

     The test split is evaluated only when ``FLAGS.TEST`` is set and either
     this is the last epoch or ``FLAGS.test_curve`` is enabled. A checkpoint
     is saved every ``FLAGS.n_checkpoint`` epochs.

     Returns:
         Whatever ``evaluate`` reports for the validation split.
     """
     model.eval()
     is_last_epoch = i_epoch + 1 == FLAGS.n_epochs

     with hlog.task("eval_val", timer=False):
         val_score = evaluate(score_utts, dataset.get_val(), dataset)

     if FLAGS.TEST and (is_last_epoch or FLAGS.test_curve):
         with hlog.task("eval_test", timer=False):
             evaluate(score_utts, dataset.get_test(), dataset)

     if (i_epoch + 1) % FLAGS.n_checkpoint == 0:
         ckpt_name = "model.%05d.chk" % i_epoch
         torch.save(model.state_dict(), os.path.join(FLAGS.model_dir, ckpt_name))

     return val_score
Esempio n. 3
0
 def callback(i_epoch):
     """Per-epoch hook: toggle fine-tuning, evaluate, and checkpoint.

     Steps performed each epoch:

     - From epoch 20 onward, sets ``fine_tune[0]`` (once, logging
       "FINE_TUNE").
     - Evaluates on 1000 freshly sampled training examples.
     - Evaluates on the validation split, with ``vis``/``beam`` enabled only
       on the final epoch.
     - Optionally evaluates on the test split.
     - Saves a checkpoint every ``FLAGS.n_checkpoint`` epochs.

     Returns:
         Whatever ``evaluate`` reports for the validation split.
     """
     start_fine_tune = not fine_tune[0] and i_epoch >= 20
     if start_fine_tune:
         hlog.log("FINE_TUNE")
         fine_tune[0] = True

     model.eval()
     last = i_epoch + 1 == FLAGS.n_epochs

     with hlog.task("eval_train", timer=False):
         train_sample = [dataset.sample_train() for _ in range(1000)]
         evaluate(model, train_sample, dataset)

     with hlog.task("eval_val", timer=False):
         val_score = evaluate(model, dataset.get_val(), dataset,
                              vis=last, beam=last)

     if FLAGS.TEST and (last or FLAGS.test_curve):
         with hlog.task("eval_test", timer=False):
             evaluate(model, dataset.get_test(), dataset, beam=last)

     if (i_epoch + 1) % FLAGS.n_checkpoint == 0:
         checkpoint_path = os.path.join(FLAGS.model_dir,
                                        "model.%05d.chk" % i_epoch)
         torch.save(model.state_dict(), checkpoint_path)

     return val_score
Esempio n. 4
0
def evaluate(dataset, model):
    """Visualize the model on one sampled compositional training example.

    Draws a single ``dataset.sample_comp_train()`` example and builds a staged
    batch from it. The batch is passed to ``visualize`` inside an hlog "train"
    task, and a blank line is printed afterwards.
    """
    with hlog.task("train", timer=False):
        example = dataset.sample_comp_train()
        vocab = dataset.vocab
        batch = make_batch([example], vocab, staged=True)
        visualize(batch, vocab, dataset.vocab, model) if False else visualize(batch, dataset.vocab, model)
    # A "holdout" visualization of dataset.sample_comp_gen()[:2] is disabled.
    print()
Esempio n. 5
0
def mkn_main(dataset):
    """Evaluate a KenLM n-gram model on the train/val(/test) splits.

    When ``FLAGS.aug_ratio > 0``, utterance probabilities are a linear mixture
    of the base LM and an augmentation LM (``FLAGS.aug_lm_file``), with
    weights ``1 / (1 + r)`` and ``r / (1 + r)`` for ``r = FLAGS.aug_ratio``.

    Raises:
        ValueError: if ``FLAGS.aug_ratio > 0`` but ``FLAGS.aug_lm_file`` is
            not set.
    """
    model = kenlm.LanguageModel(FLAGS.lm_file)
    ln10 = np.log(10)

    aug_model = None
    if FLAGS.aug_ratio > 0:
        if FLAGS.aug_lm_file is None:
            raise ValueError("FLAGS.aug_ratio > 0 requires FLAGS.aug_lm_file")
        aug_model = kenlm.LanguageModel(FLAGS.aug_lm_file)
        ratio = FLAGS.aug_ratio
        # Natural-log mixture weights: log(1/(1+r)) and log(r/(1+r)).
        log_w_base = -np.log1p(ratio)
        log_w_aug = np.log(ratio) - np.log1p(ratio)

    def score_utts(utts, baseline=False):
        """Return per-utterance negative log-likelihoods in nats.

        With ``baseline=True`` only the base LM is used, even when an
        augmentation LM is configured.
        """
        scores = []
        for utt in utts:
            dec = " ".join(dataset.vocab.decode(utt))
            # KenLM's score() is a log10 probability; convert to natural log
            # before mixing, since np.logaddexp works in base e. (The old
            # code mixed log10 scores with natural-log weights.)
            log_p = model.score(dec) * ln10
            if (not baseline) and aug_model is not None:
                log_p_aug = aug_model.score(dec) * ln10
                # log(w_base * p_base + w_aug * p_aug), computed stably.
                log_p = np.logaddexp(log_p + log_w_base, log_p_aug + log_w_aug)
            scores.append(-log_p)

        scores = np.asarray(scores)
        # Sanity check: every utterance must have probability < 1.
        assert (scores > 0).all()
        return scores

    with hlog.task("eval_train", timer=False):
        evaluate(score_utts, dataset.get_train(), dataset)

    with hlog.task("eval_val", timer=False):
        evaluate(score_utts, dataset.get_val(), dataset)

    if FLAGS.TEST:
        with hlog.task("eval_test", timer=False):
            evaluate(score_utts, dataset.get_test(), dataset)
Esempio n. 6
0
def main():
    """Train an InductorModel on batches sampled from a GrammarFactory.

    Training runs for 1000 epochs of 10 optimizer steps each. Every step:

    - sums the model loss over ``n_batch`` freshly sampled batches;
    - clips the gradient norm to 0.1;
    - applies RMSprop (lr=3e-4).

    The per-epoch sum of the per-step mean losses is logged as "loss".
    """
    factory = GrammarFactory()
    vocab = factory.vocab()
    model = InductorModel(vocab).to(DEVICE)
    optimizer = optim.RMSprop(model.parameters(), lr=0.0003)

    with hlog.task("train"):
        for _epoch in hlog.loop("%05d", range(1000), timer=False):
            running_loss = 0
            for _step in range(10):
                optimizer.zero_grad()
                step_loss = 0
                for _part in range(n_batch):
                    ctx, inp, out = sample_batch(factory, vocab)
                    step_loss += model(ctx, inp, out)
                step_loss.backward()
                clip_grad_norm_(model.parameters(), .1)
                optimizer.step()
                # Mean loss per sampled batch for this step.
                running_loss += step_loss.item() / n_batch
            hlog.value("loss", running_loss)
Esempio n. 7
0
def evaluate(model, data, dataset, vis=False, beam=False):
    """Greedily decode ``data`` with ``model`` and return the mean score.

    Examples are processed in batches of ``FLAGS.n_batch``. Each prediction
    is scored with ``dataset.score(pred, gold, inp)``, and the mean is logged
    as "acc".

    Args:
        model: Object exposing ``vocab`` and ``sample(inp_data, greedy=...,
            beam=...)``.
        data: Sliceable sequence of examples accepted by ``make_batch``.
        dataset: Provides ``score(pred, gold, inp)``; results are summed.
        vis: If true, log input / prediction / gold / score per example.
        beam: Forwarded to ``model.sample``.

    Returns:
        Mean of the per-example scores, or 0.0 when ``data`` is empty.
    """
    correct = 0
    total = 0
    for start in range(0, len(data), FLAGS.n_batch):
        batch = make_batch(data[start:start + FLAGS.n_batch],
                           model.vocab,
                           staged=False)
        preds, _ = model.sample(batch.inp_data, greedy=True, beam=beam)
        for j, pred in enumerate(preds):
            score_here = dataset.score(pred, batch.out[j], batch.inp[j])
            if vis:
                with hlog.task(str(total)):
                    hlog.value("input",
                               " ".join(model.vocab.decode(batch.inp[j])))
                    hlog.value("pred", " ".join(model.vocab.decode(pred)))
                    hlog.value("gold",
                               " ".join(model.vocab.decode(batch.out[j])))
                    hlog.value("corr", score_here)
                    hlog.log("")
            total += 1
            correct += score_here
    # An empty split would otherwise raise ZeroDivisionError.
    acc = 1. * correct / total if total else 0.
    hlog.value("acc", acc)
    return acc