import json
import os
from collections import Counter, defaultdict

import numpy as np
import scipy.stats
import torch
from torch import nn, optim
from torch.nn.utils import clip_grad_norm_
from torch.optim import lr_scheduler as opt_sched

# hlog (structured logging), FLAGS (command-line flags), and helpers such as
# make_batch, sample_batch, get_dataset, pick_model, pick_examples, Vocab,
# DefaultTrie, GrammarFactory, InductorModel, DEVICE, sep, wug_template,
# _wugs, and the constant n_batch are project-local and assumed in scope.


def train(dataset, model, sample, callback, staged):
    # Non-neural baselines are trained elsewhere; nothing to do here.
    if not isinstance(model, nn.Module):
        return
    opt = optim.Adam(model.parameters(), lr=FLAGS.lr)
    #sched = opt_sched.CosineAnnealingLR(opt, T_max=FLAGS.n_epochs)
    if FLAGS.sched_factor < 1:
        sched = opt_sched.ReduceLROnPlateau(
            opt, mode='max', factor=FLAGS.sched_factor, verbose=True
        )
    for i_epoch in hlog.loop("%05d", range(FLAGS.n_epochs)):
        model.train()
        epoch_loss = 0
        for i_batch in range(FLAGS.n_epoch_batches):
            #sched.step()
            opt.zero_grad()
            datum = make_batch(
                [sample() for _ in range(FLAGS.n_batch)], dataset.vocab, staged
            )
            loss = model(
                datum.inp_data, datum.out_data, datum.direct_out_data,
                datum.copy_out_data, *datum.extra
            )
            loss.backward()
            clip_grad_norm_(model.parameters(), FLAGS.clip)
            opt.step()
            epoch_loss += loss.item()
        epoch_loss /= FLAGS.n_epoch_batches
        hlog.value("loss", epoch_loss)
        # The callback runs validation; the plateau scheduler keys off its score.
        val_score = callback(i_epoch)
        if FLAGS.sched_factor < 1:
            sched.step(val_score)
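
# A minimal, self-contained sketch of the optimizer/scheduler pattern used in
# train(): clip gradients every step, then step ReduceLROnPlateau once per
# epoch on a validation score (mode='max', so the lr decays when the score
# stops improving). The toy model and random data are hypothetical, not part
# of this codebase.
def _sched_demo():
    toy = nn.Linear(4, 1)
    toy_opt = optim.Adam(toy.parameters(), lr=1e-3)
    toy_sched = opt_sched.ReduceLROnPlateau(toy_opt, mode='max', factor=0.5)
    for _ in range(3):  # epochs
        for _ in range(8):  # batches
            toy_opt.zero_grad()
            loss = toy(torch.randn(16, 4)).pow(2).mean()
            loss.backward()
            clip_grad_norm_(toy.parameters(), 1.0)
            toy_opt.step()
        val_score = -loss.item()  # stand-in for callback(i_epoch)
        toy_sched.step(val_score)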
def evaluate(score_utts, data, dataset):
    _, utts = zip(*data)
    baseline_nll = score_utts(utts, baseline=True)
    nll = score_utts(utts)
    # Paired t-test: does the model's per-utterance NLL differ significantly
    # from the baseline's?
    tval, pval = scipy.stats.ttest_rel(nll, baseline_nll)
    # Normalize by token count, excluding one boundary symbol per utterance.
    n_toks = sum(len(utt) - 1 for utt in utts)
    nll_norm = nll.sum() / n_toks
    ppl = np.exp(nll_norm)
    hlog.value("ppl", ppl)
    hlog.value("t/p", "%s %s" % (tval, pval))
    # Negated so that higher is better for callers that maximize this score.
    return -ppl
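
# Worked example (hypothetical numbers) of the perplexity arithmetic in
# evaluate(): summed NLLs are normalized by the number of scored tokens
# (len(utt) - 1 drops one boundary symbol per utterance), then exponentiated.
def _ppl_demo():
    nll = np.array([4.2, 3.1, 5.0])        # summed NLL per utterance
    utt_lens = [3, 2, 4]                   # lengths incl. one boundary symbol
    n_toks = sum(l - 1 for l in utt_lens)  # 2 + 1 + 3 = 6 scored tokens
    ppl = np.exp(nll.sum() / n_toks)       # exp(12.3 / 6) ~= 7.77
    assert abs(ppl - np.exp(2.05)) < 1e-9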
def main():
    factory = GrammarFactory()
    vocab = factory.vocab()
    model = InductorModel(vocab).to(DEVICE)
    opt = optim.RMSprop(model.parameters(), lr=0.0003)
    with hlog.task("train"):
        for i_epoch in hlog.loop("%05d", range(1000), timer=False):
            epoch_loss = 0
            for i_iter in range(10):
                opt.zero_grad()
                loss = 0
                for i_batch_part in range(n_batch):
                    ctx, inp, out = sample_batch(factory, vocab)
                    loss += model(ctx, inp, out)
                loss.backward()
                clip_grad_norm_(model.parameters(), .1)
                opt.step()
                epoch_loss += loss.item() / n_batch
            hlog.value("loss", epoch_loss)
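
# The inner loop above sums n_batch losses and calls backward() once. A
# memory-friendlier sketch (hypothetical toy model) backpropagates each part
# separately; gradients accumulate across backward() calls, so the final
# update matches backpropagating the summed loss.
def _accum_demo():
    toy = nn.Linear(4, 1)
    toy_opt = optim.RMSprop(toy.parameters(), lr=0.0003)
    toy_opt.zero_grad()
    for _ in range(6):  # n_batch parts
        part_loss = toy(torch.randn(8, 4)).pow(2).mean()
        part_loss.backward()  # per-part graph freed here
    clip_grad_norm_(toy.parameters(), .1)
    toy_opt.step()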
def __init__(
    self, train_utts, val_utts, test_utts, aug_data=(), invert=False,
):
    # Build the vocabulary: wug placeholders, the separator, then every word
    # observed in any split.
    vocab = Vocab()
    for i in range(FLAGS.wug_count):
        vocab.add(wug_template % i)
    vocab.add(sep)
    for utts in (train_utts, val_utts, test_utts):
        for inp, out in utts:
            for seq in (inp, out):
                for word in seq:
                    vocab.add(word)
    aug_utts = [(tuple(d["inp"]), tuple(d["out"])) for d in aug_data]
    if FLAGS.dedup:
        train_utts = [(tuple(i), tuple(o)) for i, o in train_utts]
        train_utts = sorted(set(train_utts))
    hlog.value("train", len(train_utts))
    hlog.value("aug", len(aug_utts))
    if invert:
        # Swap input/output roles in every split.
        train_utts = [(o, i) for i, o in train_utts]
        aug_utts = [(o, i) for i, o in aug_utts]
        val_utts = [(o, i) for i, o in val_utts]
        test_utts = [(o, i) for i, o in test_utts]
    self.vocab = vocab
    self.sep = sep
    self.train_utts = train_utts
    self.aug_utts = aug_utts
    self.val_utts = val_utts
    self.test_utts = test_utts
    if FLAGS.compute_adjacency:
        self._compute_adjacency(train_utts)
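
# Toy illustration (hypothetical data) of the dedup and invert transformations
# in __init__: utterances become hashable tuples, duplicates collapse
# deterministically, and invert swaps input/output roles in every pair.
def _prep_demo():
    utts = [(("walk", "twice"), ("WALK", "WALK")),
            (("walk", "twice"), ("WALK", "WALK")),
            (("jump",), ("JUMP",))]
    deduped = sorted(set((tuple(i), tuple(o)) for i, o in utts))
    assert len(deduped) == 2
    inverted = [(o, i) for i, o in deduped]
    assert inverted[0] == (("JUMP",), ("jump",))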
def visualize(datum, vocab, model):
    for (inp, out) in zip(*datum.inp):
        hlog.value("inp", " ".join(vocab.decode(inp + out)))
    for (inp, out) in zip(*datum.out):
        hlog.value("out", " ".join(vocab.decode(inp + out)))
    samples, _ = model.sample(datum.inp_data, greedy=True)
    for (inp, out) in zip(*samples):
        hlog.value("???", " ".join(vocab.decode(inp + out)))
    print()
def evaluate(model, data, dataset, vis=False, beam=False):
    correct = 0
    total = 0
    for i in range(0, len(data), FLAGS.n_batch):
        batch = make_batch(
            data[i:i + FLAGS.n_batch], model.vocab, staged=False
        )
        preds, _ = model.sample(batch.inp_data, greedy=True, beam=beam)
        for j in range(len(preds)):
            score_here = dataset.score(preds[j], batch.out[j], batch.inp[j])
            if vis:
                with hlog.task(str(total)):
                    hlog.value("input", " ".join(model.vocab.decode(batch.inp[j])))
                    hlog.value("pred", " ".join(model.vocab.decode(preds[j])))
                    hlog.value("gold", " ".join(model.vocab.decode(batch.out[j])))
                    hlog.value("corr", score_here)
                hlog.log("")
            total += 1
            correct += score_here
    acc = 1. * correct / total
    hlog.value("acc", acc)
    return acc
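
# Small sketch of the batching used in evaluate(): stepping through data in
# n_batch-sized slices leaves a short final batch rather than dropping it.
def _chunk_demo():
    data = list(range(7))
    n_batch = 3
    chunks = [data[i:i + n_batch] for i in range(0, len(data), n_batch)]
    assert chunks == [[0, 1, 2], [3, 4, 5], [6]]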
def main(argv):
    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    hlog.flags()

    dataset = get_dataset()
    model = pick_model(dataset)
    model.prepare(dataset)
    if isinstance(model, nn.Module):
        path = os.path.join(FLAGS.model_dir, FLAGS.model)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint)

    realized = set()
    examples = pick_examples(dataset)
    while len(realized) < FLAGS.n_sample:
        try:
            templ, names = next(examples)
        except StopIteration:
            break
        datum = make_batch(
            [(templ, templ) for _ in range(10)], dataset.vocab, staged=True
        )
        (inps, outs), scores = model.sample(datum.inp_data, datum.out_data)
        keep = []
        for inp, out, score in zip(inps, outs, scores):
            inp_realized, inp_used = dataset.realize(inp, names)
            out_realized, out_used = dataset.realize(out, names)
            # Discard samples that use no variables, leave a variable unused,
            # are not novel, or have already been emitted.
            if ((not FLAGS.output_only) and len(inp_used) == 0) or len(out_used) == 0:
                continue
            if len(inp_used | out_used) != len(names):
                continue
            if not (
                (FLAGS.output_only or dataset.novel(inp=inp_realized))
                and dataset.novel(out=out_realized)
            ):
                continue
            if (inp_realized, out_realized) in realized:
                continue
            keep.append(((inp_realized, out_realized), score))
        for (inp_realized, out_realized), score in keep:
            with hlog.task(str(len(realized))):
                hlog.value("inp", " ".join(dataset.vocab.decode(templ[0])))
                hlog.value("out", " ".join(dataset.vocab.decode(templ[1])))
                hlog.value("var", names)
                hlog.value("score", score)
                with hlog.task("realized"):
                    hlog.value("inp", " ".join(inp_realized))
                    hlog.value("out", " ".join(out_realized))
            realized.add((inp_realized, out_realized))

    data = [{"inp": inp, "out": out} for inp, out in realized]
    with open(FLAGS.write, "w") as fh:
        json.dump(data, fh, indent=2)
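
# Toy check (hypothetical values) of one of the filters in main(): a sample is
# dropped when the realized input and output together fail to use every
# template variable.
def _filter_demo():
    names = ("WUG0", "WUG1")
    inp_used, out_used = {"WUG0"}, {"WUG0"}  # WUG1 never surfaced
    assert len(inp_used | out_used) != len(names)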
def _compute_adjacency(self, utts):
    # Count every n-gram (up to wug_size tokens) in the encoded data; only
    # sufficiently rare fragments are eligible to become wug arguments.
    counts = Counter()
    for utt in utts:
        inp, out = utt
        for seq in (inp, out):
            enc = self.vocab.encode(seq)[1:-1]
            for span in range(1, FLAGS.wug_size + 1):
                for i in range(len(enc) + 1 - span):
                    counts[tuple(enc[i:i + span])] += 1
    if FLAGS.wug_limit is None:
        keep_args = set(counts.keys())
    else:
        keep_args = {c for c, n in counts.items() if n <= FLAGS.wug_limit}

    def make_store(initializer):
        if FLAGS.use_trie:
            return DefaultTrie(initializer)
        else:
            return defaultdict(initializer)

    def compute_templ_sim(templates):
        # Two templates count as similar if their wugs appear in identical
        # sim_window_size-token windows.
        wugs = {self.vocab[w] for w in _wugs()}
        size = FLAGS.sim_window_size

        def wug_indices(templ):
            return tuple(i for i, t in enumerate(templ) if t in wugs)

        templ_to_sig = make_store(set)
        sig_to_templ = make_store(set)
        for templ in templates:
            indices = wug_indices(templ)
            # Clamp the window start so wugs near the beginning of a template
            # don't produce negative slice indices (which would silently yield
            # wrong windows).
            sig = tuple(
                templ[max(0, i - size):i + size + 1] for i in indices
            )
            templ_to_sig[templ].add(sig)
            sig_to_templ[sig].add(templ)
        templ_sim = make_store(set)
        for templ1 in templates:
            for sig in templ_to_sig[templ1]:
                for templ2 in sig_to_templ[sig]:
                    templ_sim[templ1].add(templ2)
        return templ_sim

    def enumerate_templates():
        # Yield (template, argument) pairs for every utterance, treating the
        # concatenated input/separator/output sequence as one string.
        for i, utt in enumerate(utts):
            inp, out = utt
            seq = inp + (sep,) + out
            if FLAGS.max_comp_len is not None and len(seq) >= FLAGS.max_comp_len:
                continue
            if i % 1000 == 0:
                hlog.value("template_utt", "%d/%d" % (i, len(utts)))
            for generic in self._make_generic(seq, keep_args):
                yield generic, utt

    arg_to_templ = make_store(set)
    templ_to_arg = make_store(set)
    templ_to_templ = make_store(set)
    #sim_templ = FuzzyIndex(tfidf=True)
    #templ_to_orig = defaultdict(set)
    for (templ, args), orig in enumerate_templates():
        arg_to_templ[args].add(templ)
        templ_to_arg[templ].add(args)
        #sim_templ.put(templ, args)
        #templ_to_orig[templ].add(orig)

    if FLAGS.template_sim == "window":
        templ_sim = compute_templ_sim(templ_to_arg.keys())
    else:
        templ_sim = {t: {t} for t in templ_to_arg.keys()}

    # Templates that share an argument (or a similar template's argument)
    # become adjacent.
    multiplicity = make_store(lambda: 0)
    for i_arg, args1 in enumerate(arg_to_templ.keys()):
        if i_arg % 10000 == 0:
            hlog.value("template_arg", "%d/%d" % (i_arg, len(arg_to_templ)))
        for templ1 in arg_to_templ[args1]:
            multiplicity[templ1] += 1
            c = 0
            for templ2_pre in arg_to_templ[args1]:
                for templ2 in templ_sim[templ2_pre]:
                    if templ1 == templ2:
                        continue
                    #if (templ1, templ2) in templ_to_templ:
                    #    continue
                    templ_to_templ[templ2].add(templ1)
                    c += 1
                    if (
                        FLAGS.max_adjacencies is not None
                        and c >= FLAGS.max_adjacencies
                    ):
                        break

    self.templ_to_arg = templ_to_arg
    #self.arg_to_templ = arg_to_templ
    self.templ_to_templ = templ_to_templ
    self.multiplicity = multiplicity

    # Candidate composition pairs: templ1 must occur with at least two
    # distinct arguments.
    comp_pairs = []
    for templ1 in self.templ_to_templ:
        if self.multiplicity[templ1] <= 1:
            continue
        for templ2 in self.templ_to_templ[templ1]:
            comp_pairs.append((templ1, templ2))
    self.comp_pairs = sorted(comp_pairs)
    self.templates = sorted(self.templ_to_arg.keys())
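
# Tiny end-to-end illustration (hypothetical data) of the adjacency structure
# built in _compute_adjacency: two templates that occur with the same argument
# become mutually adjacent, and a pair enters comp_pairs only if the first
# template also occurs with more than one distinct argument.
def _adjacency_demo():
    occurrences = [
        ("the WUG0 sang", ("cat",)),
        ("the WUG0 sang", ("dog",)),
        ("the WUG0 slept", ("cat",)),
    ]
    arg_to_templ = defaultdict(set)
    for templ, args in occurrences:
        arg_to_templ[args].add(templ)
    multiplicity = Counter(t for ts in arg_to_templ.values() for t in ts)
    templ_to_templ = defaultdict(set)
    for templs in arg_to_templ.values():
        for t1 in templs:
            for t2 in templs:
                if t1 != t2:
                    templ_to_templ[t2].add(t1)
    comp_pairs = sorted(
        (t1, t2)
        for t1 in templ_to_templ if multiplicity[t1] > 1
        for t2 in templ_to_templ[t1]
    )
    assert comp_pairs == [("the WUG0 sang", "the WUG0 slept")]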