Example #1
def train(dataset, model, sample, callback, staged):
    # Nothing to train unless the model is a torch module.
    if not isinstance(model, nn.Module):
        return

    opt = optim.Adam(model.parameters(), lr=FLAGS.lr)
    # Alternative: sched = opt_sched.CosineAnnealingLR(opt, T_max=FLAGS.n_epochs)
    if FLAGS.sched_factor < 1:
        # Decay the learning rate when the validation score plateaus.
        sched = opt_sched.ReduceLROnPlateau(opt,
                                            mode='max',
                                            factor=FLAGS.sched_factor,
                                            verbose=True)

    for i_epoch in hlog.loop("%05d", range(FLAGS.n_epochs)):
        model.train()
        epoch_loss = 0
        for i_batch in range(FLAGS.n_epoch_batches):
            opt.zero_grad()
            datum = make_batch([sample() for _ in range(FLAGS.n_batch)],
                               dataset.vocab, staged)
            loss = model(datum.inp_data, datum.out_data, datum.direct_out_data,
                         datum.copy_out_data, *datum.extra)
            loss.backward()
            clip_grad_norm_(model.parameters(), FLAGS.clip)
            opt.step()
            epoch_loss += loss.item()
        epoch_loss /= FLAGS.n_epoch_batches
        hlog.value("loss", epoch_loss)
        val_score = callback(i_epoch)
        if FLAGS.sched_factor < 1:
            # mode='max': step on the validation score, reduce LR on plateau.
            sched.step(val_score)
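
This snippet assumes module-level imports and absl-style flags; a minimal sketch of that boilerplate, with illustrative defaults (hlog and make_batch are project-local helpers not shown):

import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as opt_sched
from torch.nn.utils import clip_grad_norm_
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_float("lr", 1e-3, "learning rate")
flags.DEFINE_float("sched_factor", 0.5, "LR decay factor (>= 1 disables the scheduler)")
flags.DEFINE_float("clip", 5.0, "gradient-norm clipping threshold")
flags.DEFINE_integer("n_epochs", 100, "training epochs")
flags.DEFINE_integer("n_epoch_batches", 100, "batches per epoch")
flags.DEFINE_integer("n_batch", 64, "examples per batch")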
Example #2
def enumerate_templates():
    # Closure over utts, sep, and keep_args from the enclosing scope;
    # see _compute_adjacency in Example #9 for the full context.
    for i, utt in enumerate(utts):
        inp, out = utt
        seq = inp + (sep,) + out
        if FLAGS.max_comp_len is not None and len(seq) >= FLAGS.max_comp_len:
            continue
        if i % 1000 == 0:
            hlog.value("template_utt", "%d/%d" % (i, len(utts)))
        for generic in self._make_generic(seq, keep_args):
            yield generic, utt
Example #3
def evaluate(score_utts, data, dataset):
    _, utts = zip(*data)
    # Per-utterance negative log-likelihoods under the model and a baseline.
    baseline_nll = score_utts(utts, baseline=True)
    nll = score_utts(utts)
    # Paired t-test: both models score the same utterances.
    tval, pval = scipy.stats.ttest_rel(nll, baseline_nll)
    # Number of scored tokens per utterance is the length minus one.
    n_toks = sum(len(utt) - 1 for utt in utts)
    nll_norm = nll.sum() / n_toks
    ppl = np.exp(nll_norm)
    hlog.value("ppl", ppl)
    hlog.value("t/p", "%s %s" % (tval, pval))
    # Negated so callers that maximize this score minimize perplexity.
    return -ppl
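
evaluate expects score_utts to return array-like per-utterance NLLs (it calls .sum() and feeds both arrays to scipy.stats.ttest_rel). A hypothetical sketch of such a callable; the logprob method is invented here for illustration:

import numpy as np

def make_score_utts(model, baseline_model):
    def score_utts(utts, baseline=False):
        scorer = baseline_model if baseline else model
        # One negative log-likelihood per utterance, as a numpy array.
        return np.array([-scorer.logprob(utt) for utt in utts])
    return score_utts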
Example #4
def main():
    factory = GrammarFactory()
    vocab = factory.vocab()
    model = InductorModel(vocab).to(DEVICE)
    opt = optim.RMSprop(model.parameters(), lr=0.0003)

    with hlog.task("train"):
        for i_epoch in hlog.loop("%05d", range(1000), timer=False):
            epoch_loss = 0
            for i_iter in range(10):
                opt.zero_grad()
                loss = 0
                # Accumulate loss over n_batch sampled problems (n_batch is
                # defined outside this function) before a single update.
                for i_batch_part in range(n_batch):
                    ctx, inp, out = sample_batch(factory, vocab)
                    loss += model(ctx, inp, out)
                loss.backward()
                clip_grad_norm_(model.parameters(), .1)
                opt.step()
                epoch_loss += loss.item() / n_batch
            hlog.value("loss", epoch_loss)
Example #5
    def __init__(
            self,
            train_utts,
            val_utts,
            test_utts,
            aug_data=(),
            invert=False,
    ):
        vocab = Vocab()
        # Reserve the wug placeholder tokens and the separator before adding
        # real words from the data.
        for i in range(FLAGS.wug_count):
            vocab.add(wug_template % i)
        vocab.add(sep)
        for utts in (train_utts, val_utts, test_utts):
            for inp, out in utts:
                for seq in (inp, out):
                    for word in seq:
                        vocab.add(word)

        aug_utts = [(tuple(d["inp"]), tuple(d["out"])) for d in aug_data]
        if FLAGS.dedup:
            # Deduplicate training pairs; sort for a deterministic order.
            train_utts = [(tuple(i), tuple(o)) for i, o in train_utts]
            train_utts = sorted(set(train_utts))
        hlog.value("train", len(train_utts))
        hlog.value("aug", len(aug_utts))

        if invert:
            train_utts = [(o, i) for i, o in train_utts]
            aug_utts = [(o, i) for i, o in aug_utts]
            val_utts = [(o, i) for i, o in val_utts]
            test_utts = [(o, i) for i, o in test_utts]

        self.vocab = vocab
        self.sep = sep
        self.train_utts = train_utts
        self.aug_utts = aug_utts
        self.val_utts = val_utts
        self.test_utts = test_utts
        if FLAGS.compute_adjacency:
            self._compute_adjacency(train_utts)
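
Assuming this __init__ belongs to the project's dataset class (called WugDataset here purely for illustration), construction might look like:

import json

with open("aug.json") as fh:   # path is illustrative
    aug = json.load(fh)        # list of {"inp": [...], "out": [...]} records
dataset = WugDataset(train_utts, val_utts, test_utts,
                     aug_data=aug, invert=False)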
Example #6
def visualize(datum, vocab, model):
    # Log batch inputs, target outputs, and greedy model samples side by side.
    for (inp, out) in zip(*datum.inp):
        hlog.value("inp", " ".join(vocab.decode(inp + out)))
    for (inp, out) in zip(*datum.out):
        hlog.value("out", " ".join(vocab.decode(inp + out)))
    samples, _ = model.sample(datum.inp_data, greedy=True)
    for (inp, out) in zip(*samples):
        hlog.value("???", " ".join(vocab.decode(inp + out)))
    print()
Example #7
def evaluate(model, data, dataset, vis=False, beam=False):
    correct = 0
    total = 0
    for i in range(0, len(data), FLAGS.n_batch):
        batch = make_batch(data[i:i + FLAGS.n_batch],
                           model.vocab,
                           staged=False)
        preds, _ = model.sample(batch.inp_data, greedy=True, beam=beam)
        for j in range(len(preds)):
            # Per-example correctness score as judged by the dataset.
            score_here = dataset.score(preds[j], batch.out[j], batch.inp[j])
            if vis:
                with hlog.task(str(total)):
                    hlog.value("input",
                               " ".join(model.vocab.decode(batch.inp[j])))
                    hlog.value("pred", " ".join(model.vocab.decode(preds[j])))
                    hlog.value("gold",
                               " ".join(model.vocab.decode(batch.out[j])))
                    hlog.value("corr", score_here)
                    hlog.log("")
            total += 1
            correct += score_here
    acc = 1. * correct / total
    hlog.value("acc", acc)
    return acc
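
A hypothetical call site, reusing the dataset object from Example #5 (the project's actual driver may differ):

val_acc = evaluate(model, dataset.val_utts, dataset)
test_acc = evaluate(model, dataset.test_utts, dataset, vis=True, beam=True)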
Example #8
def main(argv):
    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    hlog.flags()

    dataset = get_dataset()
    model = pick_model(dataset)

    model.prepare(dataset)
    if isinstance(model, nn.Module):
        path = os.path.join(FLAGS.model_dir, FLAGS.model)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint)

    realized = set()
    examples = pick_examples(dataset)
    while len(realized) < FLAGS.n_sample:
        try:
            templ, names = next(examples)
        except StopIteration:
            break
        # Batch ten copies of the template so one sample() call draws ten
        # candidate realizations.
        datum = make_batch([(templ, templ) for _ in range(10)],
                           dataset.vocab,
                           staged=True)
        (inps, outs), scores = model.sample(datum.inp_data, datum.out_data)

        keep = []
        for inp, out, score in zip(inps, outs, scores):
            inp_realized, inp_used = dataset.realize(inp, names)
            out_realized, out_used = dataset.realize(out, names)
            # Require at least one variable on each side (input side only
            # when not in output-only mode).
            if ((not FLAGS.output_only)
                    and len(inp_used) == 0) or len(out_used) == 0:
                continue
            # Every sampled name must be used somewhere.
            if len(inp_used | out_used) != len(names):
                continue
            # Keep only pairs unseen in the original data.
            if not ((FLAGS.output_only or dataset.novel(inp=inp_realized))
                    and dataset.novel(out=out_realized)):
                continue
            if (inp_realized, out_realized) in realized:
                continue
            keep.append(((inp_realized, out_realized), score))
        for (inp_realized, out_realized), score in keep:
            with hlog.task(str(len(realized))):
                hlog.value("inp", " ".join(dataset.vocab.decode(templ[0])))
                hlog.value("out", " ".join(dataset.vocab.decode(templ[1])))
                hlog.value("var", names)
                hlog.value("score", score)
                with hlog.task("realized"):
                    hlog.value("inp", " ".join(inp_realized))
                    hlog.value("out", " ".join(out_realized))
            realized.add((inp_realized, out_realized))

    data = [{"inp": inp, "out": out} for inp, out in realized]
    with open(FLAGS.write, "w") as fh:
        json.dump(data, fh, indent=2)
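
The records written here match the aug_data format read by the dataset constructor in Example #5, one object per augmented pair, e.g. (tokens invented):

{"inp": ["jump", "twice"], "out": ["JUMP", "JUMP"]}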
Example #9
    def _compute_adjacency(self, utts):
        # Count every token span of length 1..wug_size across inputs and
        # outputs (the [1:-1] drops the boundary tokens added by encode).
        counts = Counter()
        for utt in utts:
            inp, out = utt
            for seq in (inp, out):
                enc = self.vocab.encode(seq)[1:-1]
                for span in range(1, FLAGS.wug_size + 1):
                    for i in range(len(enc) + 1 - span):
                        counts[tuple(enc[i:i+span])] += 1
        if FLAGS.wug_limit is None:
            keep_args = set(counts.keys())
        else:
            # Only rare spans (count <= wug_limit) may be abstracted away.
            keep_args = {c for c, n in counts.items() if n <= FLAGS.wug_limit}

        def make_store(initializer):
            if FLAGS.use_trie:
                return DefaultTrie(initializer)
            else:
                return defaultdict(initializer)

        def compute_templ_sim(templates):
            # Two templates count as similar when their wug slots appear in
            # identical token windows of radius sim_window_size.
            wugs = {self.vocab[w] for w in _wugs()}
            size = FLAGS.sim_window_size

            def wug_indices(templ):
                return tuple(i for i, t in enumerate(templ) if t in wugs)

            templ_to_sig = make_store(set)
            sig_to_templ = make_store(set)

            for templ in templates:
                indices = wug_indices(templ)
                # Signature: the window of tokens around each wug slot.
                sig = tuple(templ[i-size:i+size+1] for i in indices)
                templ_to_sig[templ].add(sig)
                sig_to_templ[sig].add(templ)

            templ_sim = make_store(set)
            for templ1 in templates:
                for sig in templ_to_sig[templ1]:
                    for templ2 in sig_to_templ[sig]:
                        templ_sim[templ1].add(templ2)
            return templ_sim

        def enumerate_templates():
            for i, utt in enumerate(utts):
                inp, out = utt
                seq = inp + (sep,) + out
                if FLAGS.max_comp_len is not None and len(seq) >= FLAGS.max_comp_len:
                    continue
                if i % 1000 == 0:
                    hlog.value("template_utt", "%d/%d" % (i, len(utts)))
                for generic in self._make_generic(seq, keep_args):
                    yield generic, utt

        arg_to_templ = make_store(set)
        templ_to_arg = make_store(set)
        templ_to_templ = make_store(set)
        for (templ, args), orig in enumerate_templates():
            arg_to_templ[args].add(templ)
            templ_to_arg[templ].add(args)

        if FLAGS.template_sim == "window":
            templ_sim = compute_templ_sim(templ_to_arg.keys())
        else:
            # No similarity expansion: each template is similar only to itself.
            templ_sim = {t: {t} for t in templ_to_arg.keys()}

        multiplicity = make_store(lambda: 0)
        for i_arg, args1 in enumerate(arg_to_templ.keys()):
            if i_arg % 10000 == 0:
                hlog.value("template_arg", "%d/%d" % (i_arg, len(arg_to_templ)))
            for templ1 in arg_to_templ[args1]:
                multiplicity[templ1] += 1
                # Link templ1 to every similar template sharing this argument,
                # capped at max_adjacencies neighbors.
                c = 0
                done = False
                for templ2_pre in arg_to_templ[args1]:
                    for templ2 in templ_sim[templ2_pre]:
                        if templ1 == templ2:
                            continue
                        templ_to_templ[templ2].add(templ1)
                        c += 1
                        if (FLAGS.max_adjacencies is not None
                                and c >= FLAGS.max_adjacencies):
                            done = True
                            break
                    if done:
                        break

        self.templ_to_arg = templ_to_arg
        self.templ_to_templ = templ_to_templ
        self.multiplicity = multiplicity

        comp_pairs = []
        for templ1 in self.templ_to_templ:
            # Skip templates that occurred with only one argument tuple.
            if self.multiplicity[templ1] <= 1:
                continue
            for templ2 in self.templ_to_templ[templ1]:
                comp_pairs.append((templ1, templ2))
        self.comp_pairs = sorted(comp_pairs)
        self.templates = sorted(self.templ_to_arg.keys())
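
A toy, self-contained illustration (invented data, ignoring the similarity expansion and the max_adjacencies cap) of the core adjacency idea: templates observed with the same argument span become neighbors in templ_to_templ:

from collections import defaultdict

# Two templates observed with the same argument span ("bird",).
arg_to_templ = defaultdict(set)
arg_to_templ[("bird",)] = {("the", "WUG0", "sang"), ("a", "WUG0", "sang")}

templ_to_templ = defaultdict(set)
for args, templs in arg_to_templ.items():
    for t1 in templs:
        for t2 in templs:
            if t1 != t2:
                templ_to_templ[t1].add(t2)

print(templ_to_templ[("the", "WUG0", "sang")])
# {('a', 'WUG0', 'sang')}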