Example #1
0
def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Run ``model`` over ``samples`` in evaluation mode and return the runner state.

    Each batch goes through, in this order: tensor conversion, switching the
    model out of training mode, arc/type scoring, prediction, evaluation and
    item counting. A ``SumReducer`` named ``"counts"`` and a token-based
    progress bar are attached before running. The whole pass executes under
    ``torch.no_grad()``.

    Args:
        model: Model handed to ``set_train_mode`` and
            ``compute_total_arc_type_scores``.
        vocab: Vocabulary handed to ``batch2tensors`` and
            ``compute_total_arc_type_scores``.
        samples: Iterable of mappings, each with a ``"words"`` entry supporting
            ``len``. It is iterated twice (token count, then batching).
        device: Device passed to ``batch2tensors``.
        projective: Forwarded to ``predict_batch``.
        multiroot: Forwarded to ``predict_batch``.
        batch_size: Batch size given to ``BucketIterator``.

    Returns:
        The ``state`` attribute of the runner after the run.
    """

    def sent_length(sample):
        return len(sample["words"])

    evaluator = Runner()

    # Handlers fire in list order for every batch.
    batch_handlers = [
        batch2tensors(device, vocab),
        set_train_mode(model, training=False),
        compute_total_arc_type_scores(model, vocab),
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items(),
    ]
    evaluator.on(Event.BATCH, batch_handlers)

    # Progress is measured in tokens rather than in sentences.
    total_tokens = sum(map(sent_length, samples))
    progress = ProgressBar(leave=False, total=total_tokens, unit="tok")
    progress.attach_on(evaluator)
    SumReducer("counts").attach_on(evaluator)

    with torch.no_grad():
        # Samples are bucketed by their exact number of words.
        evaluator.run(BucketIterator(samples, sent_length, batch_size))

    return evaluator.state
Example #2
0
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Run ``model`` over ``samples`` in evaluation mode, optionally with the PPT loss.

    Batch handlers run in registration order: tensor conversion, switching
    the model out of training mode, arc/type scoring, the optional loss
    computation, then prediction, evaluation and item counting. The whole
    pass executes under ``torch.no_grad()``.

    Args:
        model: Model handed to ``set_train_mode`` and
            ``compute_total_arc_type_scores``.
        vocab: Vocabulary handed to ``batch2tensors`` and
            ``compute_total_arc_type_scores``.
        samples: Iterable of mappings, each with a ``"words"`` entry supporting
            ``len``. It is iterated twice (token count, then batching).
        compute_loss: When true, every batch stores ``"ppt_loss"`` and
            ``"size"`` in the state and a ``MeanReducer`` named
            ``"mean_ppt_loss"`` is attached; batches must then carry a
            ``"ppt_mask"`` entry.
        device: Device passed to ``batch2tensors``.
        projective: Forwarded to ``compute_aatrn_loss`` and ``predict_batch``.
        multiroot: Forwarded to ``compute_aatrn_loss`` and ``predict_batch``.
        batch_size: Batch size given to ``BucketIterator``.

    Returns:
        The ``state`` attribute of the runner after the run.
    """

    def sent_length(sample):
        return len(sample["words"])

    evaluator = Runner()

    scoring_handlers = [
        batch2tensors(device, vocab),
        set_train_mode(model, training=False),
        compute_total_arc_type_scores(model, vocab),
    ]
    evaluator.on(Event.BATCH, scoring_handlers)

    # Registered unconditionally so the handler order never depends on
    # ``compute_loss``; it is simply a no-op when the loss is not wanted.
    @evaluator.on(Event.BATCH)
    def maybe_compute_loss(state):
        if compute_loss:
            scores = state["total_arc_type_scores"]
            ambiguous_arcs = state["batch"]["ppt_mask"].bool()
            loss = compute_aatrn_loss(
                scores,
                ambiguous_arcs,
                projective=projective,
                multiroot=multiroot,
            )
            state["ppt_loss"] = loss.item()
            state["size"] = state["batch"]["words"].size(0)

    prediction_handlers = [
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items(),
    ]
    evaluator.on(Event.BATCH, prediction_handlers)

    # Progress is measured in tokens rather than in sentences.
    total_tokens = sum(map(sent_length, samples))
    progress = ProgressBar(leave=False, total=total_tokens, unit="tok")
    progress.attach_on(evaluator)
    SumReducer("counts").attach_on(evaluator)
    if compute_loss:
        loss_reducer = MeanReducer("mean_ppt_loss", value="ppt_loss")
        loss_reducer.attach_on(evaluator)

    with torch.no_grad():
        # Samples are bucketed by their exact number of words.
        evaluator.run(BucketIterator(samples, sent_length, batch_size))

    return evaluator.state
Example #3
0
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT.

    The function proceeds in these steps:

    1. Create ``artifacts_dir`` and read the train/dev/test samples.
    2. Load the vocabulary, model metadata and parameters from ``load_from``,
       extend the word vocabulary with the words of all samples and rebuild
       the word embedding layer with ``extend_word_embedding``.
    3. With the loaded model, compute an ambiguous-arcs mask for every train
       and dev sample and store it under the sample's ``"ppt_mask"`` key.
    4. Finetune with Adam on the loss from ``compute_aatrn_loss`` plus
       ``l2_coef`` times the loss from ``compute_l2_loss`` (computed against
       a copy of the parameters taken before finetuning).
    5. After every epoch evaluate on dev and save the model state dict;
       after epoch ``max_epoch`` also evaluate on test.

    Args:
        _log: Logger used for progress messages (presumably injected by the
            experiment framework -- confirm).
        _run: Run object given to ``log_grads``, ``log_stats`` and
            ``print_accs`` and used for ``log_scalar``.
        _rnd: Random generator passed as ``rng`` to the training iterators.
        max_length: Mapping from split name to the ``max_length`` argument of
            ``read_samples``; missing splits get ``None``. ``None`` means an
            empty mapping.
        artifacts_dir: Directory where the vocabulary, model metadata and
            model state dicts are written.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory containing ``vocab.yml``, ``model.yml`` and the
            parameter file.
        load_params: File name of the parameters inside ``load_from``.
        device: Device the model is moved to and batches are converted for.
        word_emb_path: Path of word vectors in word2vec format.
        freeze: When true, the word and tag embeddings are not updated.
        thresh: Threshold passed to ``compute_ambiguous_arcs_mask``.
        projective: Forwarded to the mask, stats and loss helpers.
        multiroot: Forwarded to the mask, stats and loss helpers.
        batch_size: Batch size for mask computation and finetuning.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 loss term.
        max_epoch: Number of finetuning epochs.

    Returns:
        The ``"las_nopunct"`` entry of the most recent dev accuracies, or
        ``None`` when finetuning is interrupted with ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # An existing directory raises FileExistsError unless ``overwrite`` is true.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended, using the samples of all splits.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # From here on the samples are the output of vocab.stoi.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Second argument is map_location: tensors are loaded onto the CPU first.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    # The metadata is saved after the embedding layer has been replaced.
    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Embedding.from_pretrained freezes its weights by default, so
    # trainability is set explicitly for both embedding layers here.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train", "dev"]:
        # Tag each sample with its list position so that results collected
        # per batch can be written back to the right sample afterwards.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            # Every position of the batch mask must be set (presumably: no
            # padding, since batches are bucketed by exact length -- confirm).
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective,
                                                   multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        # Exactly one mask and one id must have been collected per sample.
        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, ppt_mask in zip(runner.state["_ids"],
                               runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size,
                                projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Detached copies of the parameters before finetuning; handed to
    # compute_l2_loss together with the model.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]

        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective,
                                      multiroot)
        # Divide by the size of the mask's first dimension (presumably the
        # batch size -- confirm).
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        # Number of set entries in the mask.
        state["n_items"] = mask.long().sum().item()

    # Registered after compute_loss so state["loss"] and the stats exist.
    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/projective/multiroot/batch_size are not passed
        # on; unless they are injected elsewhere, run_eval's own defaults
        # apply -- confirm.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])

        # Overwritten every epoch; the function's return value reads it.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Test evaluation happens only when the epoch number equals max_epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        # No "ppt_mask" is computed for the test samples above, so the loss
        # is switched off here.
        eval_state = run_eval(model,
                              vocab,
                              samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket by length in steps of ten words: lengths 1-10 share a key,
    # 11-20 the next one, and so on.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and implicitly returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
Example #4
0
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training.

    The function proceeds in these steps:

    1. Create ``artifacts_dir`` and read the train/dev/test samples.
    2. Load the vocabulary, model metadata and parameters from ``load_from``,
       extend the word vocabulary with the words of all samples and rebuild
       the word embedding layer with ``extend_word_embedding``.
    3. With the loaded model, predict heads and types for every train sample
       and store them under the ``"st_heads"`` and ``"st_types"`` keys.
    4. Finetune with Adam on the cross-entropy of the arc and type scores
       against those stored predictions, plus ``l2_coef`` times the loss
       from ``compute_l2_loss`` (computed against a copy of the parameters
       taken before finetuning).
    5. After every epoch evaluate on dev and save the model state dict;
       after epoch ``max_epoch`` also evaluate on test.

    Args:
        _log: Logger used for progress messages (presumably injected by the
            experiment framework -- confirm).
        _run: Run object given to ``log_grads``, ``log_stats`` and
            ``print_accs``.
        _rnd: Random generator passed as ``rng`` to the training iterators.
        max_length: Mapping from split name to the ``max_length`` argument of
            ``read_samples``; missing splits get ``None``. ``None`` means an
            empty mapping.
        artifacts_dir: Directory where the vocabulary, model metadata and
            model state dicts are written.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory containing ``vocab.yml``, ``model.yml`` and the
            parameter file.
        load_params: File name of the parameters inside ``load_from``.
        device: Device the model is moved to and batches are converted for.
        word_emb_path: Path of word vectors in word2vec format.
        freeze: When true, the word and tag embeddings are not updated.
        projective: Forwarded to ``predict_batch``.
        multiroot: Forwarded to ``predict_batch``.
        batch_size: Batch size for prediction and finetuning.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 loss term.
        max_epoch: Number of finetuning epochs.

    Returns:
        The ``"las_nopunct"`` entry of the most recent dev accuracies, or
        ``None`` when finetuning is interrupted with ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # An existing directory raises FileExistsError unless ``overwrite`` is true.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended, using the samples of all splits.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # From here on the samples are the output of vocab.stoi.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Second argument is map_location: tensors are loaded onto the CPU first.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    # The metadata is saved after the embedding layer has been replaced.
    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Embedding.from_pretrained freezes its weights by default, so
    # trainability is set explicitly for both embedding layers here.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Only the train split receives self-training trees.
    for wh in ["train"]:
        # Tag each sample with its list position so that results collected
        # per batch can be written back to the right sample afterwards.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            # Collect the predictions of this batch together with the ids of
            # the samples they belong to.
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        # Exactly one prediction and one id must have been collected per sample.
        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"],
                                         runner.state["st_heads"],
                                         runner.state["st_types"]):
            # Predictions must have one entry per word of the sample.
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Detached copies of the parameters before finetuning; handed to
    # compute_l2_loss together with the model.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        # The stored self-training predictions serve as the targets.
        words, tags, heads, types = bat["words"], bat["tags"], bat[
            "st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        # Give the padding type a very low score so it is effectively
        # never predicted.
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # Flatten so that each row holds the scores of one (sentence,
        # position) pair over the first sequence axis, matching the
        # flattened targets.
        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        # Average only over the positions where the mask is set.
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    # Registered after compute_loss so state["loss"] and the stats exist.
    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/projective/multiroot/batch_size are not passed
        # on; unless they are injected elsewhere, run_eval's own defaults
        # apply -- confirm.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        # Overwritten every epoch; the function's return value reads it.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Test evaluation happens only when the epoch number equals max_epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket by length in steps of ten words: lengths 1-10 share a key,
    # 11-20 the next one, and so on.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and implicitly returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
Example #5
0
def finetune(
    corpus,
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    load_samples_from=None,
    overwrite=False,
    load_src=None,
    src_key_as_lang=False,
    main_src=None,
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    save_samples=False,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPTX.

    The function proceeds in these steps:

    1. Create ``artifacts_dir`` and obtain the samples, either by unpickling
       ``load_samples_from`` or by reading the train/dev/test splits.
    2. For every source in turn (the main source last): load its vocabulary
       and model, extend the word vocabulary and embedding layer with the
       sample words, compute an ambiguous-arcs mask for every train and dev
       sample, and OR that mask into the sample's ``"pptx_mask"`` entry.
       When samples were loaded from a pickle only the main source is
       processed and the mask computation is skipped.
    3. Save the main source's vocabulary and model metadata, and optionally
       the samples with their combined masks.
    4. Finetune the main source's model with Adam on the loss from
       ``compute_aatrn_loss`` plus ``l2_coef`` times the loss from
       ``compute_l2_loss`` (computed against a copy of the parameters taken
       before finetuning).
    5. After every epoch evaluate on dev and save the model state dict;
       after epoch ``max_epoch`` also evaluate on test.

    Args:
        corpus: Mapping whose ``"lang"`` entry is consulted when
            ``src_key_as_lang`` is true.
        _log: Logger used for progress messages (presumably injected by the
            experiment framework -- confirm).
        _run: Run object given to ``log_grads``, ``log_stats`` and
            ``print_accs`` and used for ``log_scalar``.
        _rnd: Random generator passed as ``rng`` to the training iterators.
        max_length: Mapping from split name to the ``max_length`` argument of
            ``read_samples``; missing splits get ``None``. ``None`` means an
            empty mapping.
        artifacts_dir: Directory where the vocabulary, model metadata,
            optional samples and model state dicts are written.
        load_samples_from: Path of a pickle file holding the samples; its
            train and dev samples must already contain ``"pptx_mask"``.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_src: Mapping from source name to a ``(load_from, load_params)``
            pair. ``None`` means a single source ``"src"`` loaded from
            ``("artifacts", "model.pth")``, which also becomes the main source.
        src_key_as_lang: When true, a non-main source whose name equals
            ``corpus["lang"]`` is skipped.
        main_src: Name of the source whose model is finetuned; must be a key
            of ``load_src`` when ``load_src`` is given.
        device: Device the models are moved to and batches are converted for.
        word_emb_path: Path of word vectors in word2vec format.
        freeze: When true, the word and tag embeddings are not updated.
        thresh: Threshold passed to ``compute_ambiguous_arcs_mask``.
        projective: Forwarded to the mask, stats and loss helpers.
        multiroot: Forwarded to the mask, stats and loss helpers.
        batch_size: Batch size for mask computation and finetuning.
        save_samples: When true, pickle the samples to ``samples.pkl`` inside
            ``artifacts_dir``.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 loss term.
        max_epoch: Number of finetuning epochs.

    Returns:
        The ``"las_nopunct"`` entry of the most recent dev accuracies, or
        ``None`` when finetuning is interrupted with ``KeyboardInterrupt``.

    Raises:
        ValueError: If ``load_src`` is given and ``main_src`` is not one of
            its keys.
    """
    if max_length is None:
        max_length = {}
    if load_src is None:
        load_src = {"src": ("artifacts", "model.pth")}
        main_src = "src"
    elif main_src not in load_src:
        raise ValueError(f"{main_src} not found in load_src")

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # An existing directory raises FileExistsError unless ``overwrite`` is true.
    artifacts_dir.mkdir(exist_ok=overwrite)

    if load_samples_from:
        _log.info("Loading samples from %s", load_samples_from)
        # NOTE(security): unpickling can execute arbitrary code; only point
        # ``load_samples_from`` at files from a trusted source.
        with open(load_samples_from, "rb") as f:
            samples = pickle.load(f)
    else:
        samples = {
            wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
            for wh in ["train", "dev", "test"]
        }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    # The same word vectors are used to extend the embeddings of every source.
    kv = KeyedVectors.load_word2vec_format(word_emb_path)

    if load_samples_from:
        _log.info(
            "Skipping non-main src because samples are processed and loaded")
        srcs = []
    else:
        srcs = [src for src in load_src if src != main_src]
        if src_key_as_lang and corpus["lang"] in srcs:
            _log.info("Removing %s from src parsers because it's the tgt",
                      corpus["lang"])
            srcs.remove(corpus["lang"])
    # The main source goes last so that ``vocab`` and ``model`` still refer
    # to it once the loop below has finished.
    srcs.append(main_src)

    for src_i, src in enumerate(srcs):
        _log.info("Processing src %s [%d/%d]", src, src_i + 1, len(srcs))
        load_from, load_params = load_src[src]
        path = Path(load_from) / "vocab.yml"
        _log.info("Loading %s vocabulary from %s", src, path)
        vocab = load(path.read_text(encoding="utf8"))
        for name in vocab:
            _log.info("Found %d %s", len(vocab[name]), name)

        # Only the "words" field is extended, using the samples of all splits.
        _log.info("Extending %s vocabulary with target words", src)
        vocab.extend(chain(*samples.values()), ["words"])
        _log.info("Found %d words now", len(vocab["words"]))

        # Per-source converted samples; ``samples`` itself is left
        # unconverted so the next source can convert it with its own
        # vocabulary.
        samples_ = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

        path = Path(load_from) / "model.yml"
        _log.info("Loading %s model from metadata %s", src, path)
        model = load(path.read_text(encoding="utf8"))

        path = Path(load_from) / load_params
        _log.info("Loading %s model parameters from %s", src, path)
        # Second argument is map_location: tensors are loaded onto the CPU first.
        model.load_state_dict(torch.load(path, "cpu"))

        _log.info("Creating %s extended word embedding layer", src)
        assert model.word_emb.embedding_dim == kv.vector_size
        with torch.no_grad():
            model.word_emb = torch.nn.Embedding.from_pretrained(
                extend_word_embedding(model.word_emb.weight, vocab["words"],
                                      kv))
        model.to(device)

        for wh in ["train", "dev"]:
            if load_samples_from:
                # Loaded samples must already carry their combined masks.
                assert all("pptx_mask" in s for s in samples[wh])
                continue

            # Tag each sample with its list position so that results
            # collected per batch can be written back to the right sample
            # afterwards.
            for i, s in enumerate(samples_[wh]):
                s["_id"] = i

            runner = Runner()
            runner.state.update({"pptx_masks": [], "_ids": []})
            runner.on(
                Event.BATCH,
                [
                    batch2tensors(device, vocab),
                    set_train_mode(model, training=False),
                    compute_total_arc_type_scores(model, vocab),
                ],
            )

            @runner.on(Event.BATCH)
            def compute_pptx_ambiguous_arcs_mask(state):
                # Every position of the batch mask must be set (presumably:
                # no padding, since batches are bucketed by exact length --
                # confirm).
                assert state["batch"]["mask"].all()
                scores = state["total_arc_type_scores"]
                pptx_mask = compute_ambiguous_arcs_mask(
                    scores, thresh, projective, multiroot)
                # Extending with the batch result appends one entry per
                # element along its first dimension; each entry is turned
                # into a list after the run.
                state["pptx_masks"].extend(pptx_mask)
                state["_ids"].extend(state["batch"]["_id"].tolist())
                state["n_items"] = state["batch"]["words"].numel()

            n_toks = sum(len(s["words"]) for s in samples_[wh])
            ProgressBar(total=n_toks, unit="tok").attach_on(runner)

            _log.info(
                "Computing PPTX ambiguous arcs mask for %s set with source %s",
                wh, src)
            with torch.no_grad():
                runner.run(
                    BucketIterator(samples_[wh], lambda s: len(s["words"]),
                                   batch_size))

            # Exactly one mask and one id must have been collected per sample.
            assert len(runner.state["pptx_masks"]) == len(samples_[wh])
            assert len(runner.state["_ids"]) == len(samples_[wh])
            for i, pptx_mask in zip(runner.state["_ids"],
                                    runner.state["pptx_masks"]):
                samples_[wh][i]["pptx_mask"] = pptx_mask.tolist()

            _log.info("Computing (log) number of trees stats on %s set", wh)
            report_log_ntrees_stats(samples_[wh], "pptx_mask", batch_size,
                                    projective, multiroot)

            _log.info("Combining the ambiguous arcs mask")
            assert len(samples_[wh]) == len(samples[wh])
            for i in range(len(samples_[wh])):
                pptx_mask = torch.tensor(samples_[wh][i]["pptx_mask"])
                assert pptx_mask.dim() == 3
                if "pptx_mask" in samples[wh][i]:
                    old_mask = torch.tensor(samples[wh][i]["pptx_mask"])
                else:
                    # First source for this sample: an all-false 1x1x1 mask
                    # broadcasts against the new mask in the OR below.
                    old_mask = torch.zeros(1, 1, 1).bool()
                # Union of the masks from the sources processed so far.
                samples[wh][i]["pptx_mask"] = (old_mask | pptx_mask).tolist()

    # ``src``, ``vocab`` and ``model`` now belong to the last processed
    # source, which must be the main one.
    assert src == main_src
    _log.info("Main source is %s", src)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    if save_samples:
        # Saved before the vocab.stoi conversion below.
        path = artifacts_dir / "samples.pkl"
        _log.info("Saving samples to %s", path)
        with open(path, "wb") as f:
            pickle.dump(samples, f)

    # Convert with the main source's vocabulary for finetuning and evaluation.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    for wh in ["train", "dev"]:
        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "pptx_mask", batch_size,
                                projective, multiroot)

    # Embedding.from_pretrained freezes its weights by default, so
    # trainability is set explicitly for both embedding layers here.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Detached copies of the parameters before finetuning; handed to
    # compute_l2_loss together with the model.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        pptx_mask = state["batch"]["pptx_mask"].bool()
        scores = state["total_arc_type_scores"]

        pptx_loss = compute_aatrn_loss(scores, pptx_mask, mask, projective,
                                       multiroot)
        # Divide by the size of the mask's first dimension (presumably the
        # batch size -- confirm).
        pptx_loss /= mask.size(0)
        loss = pptx_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "pptx_loss": pptx_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        # Number of set entries in the mask.
        state["n_items"] = mask.long().sum().item()

    # Registered after compute_loss so state["loss"] and the stats exist.
    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/projective/multiroot/batch_size are not passed
        # on; unless they are injected elsewhere, run_eval's own defaults
        # apply -- confirm.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        # Expects run_eval to have produced a "mean_pptx_loss" entry.
        pptx_loss = eval_state["mean_pptx_loss"]
        _log.info("dev_pptx_loss: %.4f", pptx_loss)
        _run.log_scalar("dev_pptx_loss", pptx_loss, step=state["n_iters"])

        # Overwritten every epoch; the function's return value reads it.
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Test evaluation happens only when the epoch number equals max_epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        # No "pptx_mask" is computed for the test samples above, so the loss
        # is switched off here.
        eval_state = run_eval(model,
                              vocab,
                              samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket by length in steps of ten words: lengths 1-10 share a key,
    # 11-20 the next one, and so on.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        # Falls through and implicitly returns None.
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]