Example #1
def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
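    """Evaluate a trained parser on the given samples and return the runner state."""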
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
            predict_batch(projective, multiroot),
            evaluate_batch(),
            get_n_items(),
        ],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
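
A minimal usage sketch (hypothetical call: model, vocab, and samples are assumed to come from the surrounding training code; state["counts"].accs follows its use in Example #3):

state = run_eval(model, vocab, samples, device="cuda", batch_size=64)
print(state["counts"].accs)  # accuracies aggregated by the SumReducer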
Example #2
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
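    """Evaluate a trained parser, optionally also computing the mean PPT loss."""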
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        ppt_loss = compute_aatrn_loss(
            state["total_arc_type_scores"],
            state["batch"]["ppt_mask"].bool(),
            projective=projective,
            multiroot=multiroot,
        )
        state["ppt_loss"] = ppt_loss.item()
        state["size"] = state["batch"]["words"].size(0)

    runner.on(
        Event.BATCH,
        [
            predict_batch(projective, multiroot),
            evaluate_batch(),
            get_n_items(),
        ],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ppt_loss", value="ppt_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
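
A sketch of the loss-reporting variant (hypothetical call; when compute_loss=True, the MeanReducer attached above exposes the averaged PPT loss under "mean_ppt_loss"):

state = run_eval(model, vocab, samples, compute_loss=True)
print(state["counts"].accs, state["mean_ppt_loss"])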
Example #3
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, map_location="cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Optionally freeze the (extended) word and tag embeddings.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

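    # Self-training: predict trees on the training set and store them on each
    # sample as "st_heads"/"st_types" to serve as targets for finetuning below.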
    for wh in ["train"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

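        # Collect the predicted trees and the sample ids they belong to so
        # they can be written back to `samples` after the run.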
        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"],
                                         runner.state["st_heads"],
                                         runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
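    # Snapshot the pretrained parameters so compute_l2_loss can penalize
    # drift away from them during finetuning.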
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

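    # Loss: cross-entropy of arcs and types against the self-training trees,
    # plus the L2 penalty toward the pretrained parameters.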
    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags = bat["words"], bat["tags"]
        heads, types = bat["st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket sentences by length (minus the root token) in bins of 10 tokens.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
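
The _log, _run, and _rnd parameters suggest finetune is a Sacred captured function; a hypothetical wiring sketch, assuming Sacred (the experiment name and command line below are illustrative only):

from sacred import Experiment

ex = Experiment("finetune")
ex.command(finetune)  # Sacred injects _log/_run/_rnd and config overrides

if __name__ == "__main__":
    ex.run_commandline()  # e.g. python script.py finetune with lr=2e-5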