def main(argv):
    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    hlog.flags()

    dataset = get_dataset()
    model = pick_model(dataset)
    model.prepare(dataset)
    if isinstance(model, nn.Module):
        path = os.path.join(FLAGS.model_dir, FLAGS.model)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint)

    realized = set()
    examples = pick_examples(dataset)
    while len(realized) < FLAGS.n_sample:
        try:
            templ, names = next(examples)
        except StopIteration:
            break
        datum = make_batch(
            [(templ, templ) for _ in range(10)], dataset.vocab, staged=True)
        (inps, outs), scores = model.sample(datum.inp_data, datum.out_data)
        keep = []
        for inp, out, score in zip(inps, outs, scores):
            inp_realized, inp_used = dataset.realize(inp, names)
            out_realized, out_used = dataset.realize(out, names)
            # Discard samples that fail to use any template variable.
            if ((not FLAGS.output_only) and len(inp_used) == 0) \
                    or len(out_used) == 0:
                continue
            # Require that every template variable appears somewhere.
            if len(inp_used | out_used) != len(names):
                continue
            # Keep only realizations unseen in the original data.
            if not ((FLAGS.output_only or dataset.novel(inp=inp_realized))
                    and dataset.novel(out=out_realized)):
                continue
            if (inp_realized, out_realized) in realized:
                continue
            keep.append(((inp_realized, out_realized), score))
        for (inp_realized, out_realized), score in keep:
            with hlog.task(str(len(realized))):
                hlog.value("inp", " ".join(dataset.vocab.decode(templ[0])))
                hlog.value("out", " ".join(dataset.vocab.decode(templ[1])))
                hlog.value("var", names)
                hlog.value("score", score)
                with hlog.task("realized"):
                    hlog.value("inp", " ".join(inp_realized))
                    hlog.value("out", " ".join(out_realized))
            realized.add((inp_realized, out_realized))

    data = [{"inp": inp, "out": out} for inp, out in realized]
    with open(FLAGS.write, "w") as fh:
        json.dump(data, fh, indent=2)
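# How this entry point might be wired up: a minimal sketch assuming
# absl-style flags (consistent with hlog.flags() above). The flag names
# mirror the ones main() reads; the defaults and help strings here are
# hypothetical, not taken from the original code.
from absl import app, flags

flags.DEFINE_integer("seed", 0, "random seed")
flags.DEFINE_integer("n_sample", 1000, "number of realized pairs to sample")
flags.DEFINE_string("model_dir", "models", "directory containing checkpoints")
flags.DEFINE_string("model", "model.00049.chk", "checkpoint file to load")
flags.DEFINE_string("write", "samples.json", "output path for realized pairs")
flags.DEFINE_boolean("output_only", False,
                     "only enforce variable use / novelty on outputs")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    app.run(main)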
def callback(i_epoch):
    model.eval()
    final = i_epoch == FLAGS.n_epochs - 1
    with hlog.task("eval_val", timer=False):
        val_acc = evaluate(score_utts, dataset.get_val(), dataset)
    if FLAGS.TEST and (final or FLAGS.test_curve):
        with hlog.task("eval_test", timer=False):
            evaluate(score_utts, dataset.get_test(), dataset)
    if (i_epoch + 1) % FLAGS.n_checkpoint == 0:
        torch.save(
            model.state_dict(),
            os.path.join(FLAGS.model_dir, "model.%05d.chk" % i_epoch))
    return val_acc
def callback(i_epoch):
    # One-time switch into fine-tuning mode after 20 epochs.
    if not fine_tune[0] and i_epoch >= 20:
        hlog.log("FINE_TUNE")
        fine_tune[0] = True
    model.eval()
    final = i_epoch == FLAGS.n_epochs - 1
    with hlog.task("eval_train", timer=False):
        train_data = [dataset.sample_train() for _ in range(1000)]
        evaluate(model, train_data, dataset)
    with hlog.task("eval_val", timer=False):
        val_data = dataset.get_val()
        val_acc = evaluate(model, val_data, dataset, vis=final, beam=final)
    if FLAGS.TEST and (final or FLAGS.test_curve):
        with hlog.task("eval_test", timer=False):
            test_data = dataset.get_test()
            evaluate(model, test_data, dataset, beam=final)
    if (i_epoch + 1) % FLAGS.n_checkpoint == 0:
        torch.save(
            model.state_dict(),
            os.path.join(FLAGS.model_dir, "model.%05d.chk" % i_epoch))
    return val_acc
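# A small companion sketch (hypothetical helper, not part of the original
# trainer): recover the newest "model.%05d.chk" checkpoint written by the
# callbacks above. A lexicographic sort suffices because the epoch number
# is zero-padded to a fixed width.
import glob
import os

def latest_checkpoint(model_dir):
    paths = sorted(glob.glob(os.path.join(model_dir, "model.*.chk")))
    return paths[-1] if paths else None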
def evaluate(dataset, model):
    with hlog.task("train", timer=False):
        visualize(
            make_batch([dataset.sample_comp_train()], dataset.vocab, staged=True),
            dataset.vocab,
            model)
    #with hlog.task("holdout", timer=False):
    #    visualize(
    #        make_batch([dataset.sample_comp_gen()[:2]], dataset.vocab, staged=True),
    #        dataset.vocab,
    #        model
    #    )
    print()
def mkn_main(dataset):
    model = kenlm.LanguageModel(FLAGS.lm_file)
    if FLAGS.aug_ratio > 0:
        assert FLAGS.aug_lm_file is not None
        aug_model = kenlm.LanguageModel(FLAGS.aug_lm_file)

    def score_utts(utts, baseline=False):
        scores = []
        for utt in utts:
            dec = " ".join(dataset.vocab.decode(utt))
            # kenlm returns log10 probabilities; convert to natural log so
            # the mixture below is computed in a consistent base.
            score_here = model.score(dec) * np.log(10)
            if (not baseline) and FLAGS.aug_ratio > 0:
                score_aug = aug_model.score(dec) * np.log(10)
                # Stable log-space mixture of the base and augmented LMs:
                # log((p_base + aug_ratio * p_aug) / (1 + aug_ratio)).
                score_here = np.logaddexp(
                    score_here + np.log(1 / (1 + FLAGS.aug_ratio)),
                    score_aug + np.log(FLAGS.aug_ratio / (1 + FLAGS.aug_ratio)))
            # Report a (positive) negative log-likelihood.
            scores.append(-score_here)
        scores = np.asarray(scores)
        assert (scores > 0).all()
        return scores

    with hlog.task("eval_train", timer=False):
        evaluate(score_utts, dataset.get_train(), dataset)
    with hlog.task("eval_val", timer=False):
        evaluate(score_utts, dataset.get_val(), dataset)
    if FLAGS.TEST:
        with hlog.task("eval_test", timer=False):
            evaluate(score_utts, dataset.get_test(), dataset)
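# Numeric sanity check of the log-space mixture above (standalone; the
# scores here are made-up natural-log probabilities, not real LM output):
# logaddexp of the weighted log terms should match mixing the raw
# probabilities directly.
import numpy as np

base_ln, aug_ln, ratio = -9.2, -6.9, 0.5
mixed = np.logaddexp(
    base_ln + np.log(1 / (1 + ratio)),
    aug_ln + np.log(ratio / (1 + ratio)))
direct = np.log((np.exp(base_ln) + ratio * np.exp(aug_ln)) / (1 + ratio))
assert np.isclose(mixed, direct)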
def main():
    factory = GrammarFactory()
    vocab = factory.vocab()
    model = InductorModel(vocab).to(DEVICE)
    opt = optim.RMSprop(model.parameters(), lr=0.0003)
    with hlog.task("train"):
        for i_epoch in hlog.loop("%05d", range(1000), timer=False):
            epoch_loss = 0
            for i_iter in range(10):
                opt.zero_grad()
                loss = 0
                # Accumulate the loss over n_batch sampled (ctx, inp, out)
                # triples, then take a single clipped gradient step.
                for i_batch_part in range(n_batch):
                    ctx, inp, out = sample_batch(factory, vocab)
                    loss += model(ctx, inp, out)
                loss.backward()
                clip_grad_norm_(model.parameters(), .1)
                opt.step()
                epoch_loss += loss.item() / n_batch
            hlog.value("loss", epoch_loss)
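# Why the deferred opt.step() above is safe: summing per-sample losses and
# calling backward() once produces the same gradients as backpropagating
# each sample separately. A standalone check with a hypothetical toy model:
import torch
import torch.nn as nn

lin = nn.Linear(2, 1)
xs = [torch.randn(2) for _ in range(4)]

summed = sum((lin(x) ** 2).squeeze() for x in xs)
summed.backward()
grad_once = lin.weight.grad.clone()

lin.weight.grad = None
for x in xs:
    (lin(x) ** 2).squeeze().backward()
assert torch.allclose(grad_once, lin.weight.grad)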
def evaluate(model, data, dataset, vis=False, beam=False):
    correct = 0
    total = 0
    for i in range(0, len(data), FLAGS.n_batch):
        batch = make_batch(data[i:i + FLAGS.n_batch], model.vocab, staged=False)
        preds, _ = model.sample(batch.inp_data, greedy=True, beam=beam)
        for j in range(len(preds)):
            score_here = dataset.score(preds[j], batch.out[j], batch.inp[j])
            if vis:
                with hlog.task(str(total)):
                    hlog.value("input", " ".join(model.vocab.decode(batch.inp[j])))
                    hlog.value("pred", " ".join(model.vocab.decode(preds[j])))
                    hlog.value("gold", " ".join(model.vocab.decode(batch.out[j])))
                    hlog.value("corr", score_here)
                hlog.log("")
            total += 1
            correct += score_here
    acc = 1. * correct / total
    hlog.value("acc", acc)
    return acc