Example #1
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train", "dev"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

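        # For each sentence, build a boolean mask over candidate arcs that the
        # source model considers plausible (selection governed by `thresh`,
        # `projective`, and `multiroot`); these masks constrain the trees used
        # as PPT supervision during finetuning.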
        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective,
                                                   multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, ppt_mask in zip(runner.state["_ids"],
                               runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size,
                                projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
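    # Keep a detached copy of the source model's parameters so compute_l2_loss
    # can penalize the finetuned weights for drifting away from them.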
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]

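        # PPT loss over the set of trees allowed by the ambiguous-arc mask
        # (via compute_aatrn_loss), averaged over sentences in the batch, plus
        # an L2 pull toward the original (source) parameters.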
        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective,
                                      multiroot)
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])

        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model,
                              vocab,
                              samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

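    # Bucket sentences into width-10 length bins so each batch contains
    # sentences of similar length; shuffling randomizes the training order.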
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
Example #2
def finetune(
    _log,
    _run,
    _rnd,
    corpus,
    artifacts_dir="artifacts",
    overwrite=False,
    temperature=1.0,
    freeze_embeddings=True,
    freeze_encoder_up_to=1,
    device="cpu",
    thresh=0.95,
    batch_size=16,
    lr=1e-5,
    max_epoch=5,
    predict_on_finished=False,
):
    """Finetune/train the source model on unlabeled target data."""
    artifacts_dir = Path(artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = read_samples_()
    eval_samples = read_samples_(max_length=None)
    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
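    # Forbid invalid BIO transitions: an I-X tag may not start a sequence and
    # may not follow O or a tag of a different type, so those transitions get
    # a very large negative score.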
    for lid, label in config.id2label.items():
        if not label.startswith("I-"):
            continue

        with torch.no_grad():
            model.start_transition[lid] = -1e9
        for plid, plabel in config.id2label.items():
            if plabel == "O" or plabel[2:] != label[2:]:
                with torch.no_grad():
                    model.transition[plid, lid] = -1e9

    for name, p in model.named_parameters():
        freeze = False
        if freeze_embeddings and ".embeddings." in name:
            freeze = True
        if freeze_encoder_up_to >= 0:
            for i in range(freeze_encoder_up_to + 1):
                if f".encoder.layer.{i}." in name:
                    freeze = True
        if freeze:
            _log.info("Freezing %s", name)
            p.requires_grad_(False)

    model.to(device)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
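    # For every sentence, keep the tag-pair transitions to which the source
    # model assigns enough probability (controlled by `thresh`); these form
    # the PTST mask that constrains the label sequences used as supervision.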
    ptst_masks, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), unit="tok")
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        with torch.no_grad():
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
        ptst_masks.extend(ptst_mask.tolist())
        _ids.extend(arr["_id"].tolist())
        pbar.update(int(arr["mask"].sum()))
    pbar.close()

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Report number of sequences")
    log_total_nseqs, log_nseqs = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), leave=False)
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        cnt_scores = torch.zeros_like(ptst_mask).float()
        cnt_scores_masked = cnt_scores.masked_fill(~ptst_mask, -1e9)
        log_total_nseqs.extend(LinearCRF(cnt_scores).log_partitions().tolist())
        log_nseqs.extend(
            LinearCRF(cnt_scores_masked).log_partitions().tolist())
        pbar.update(arr["word_ids"].size)
    pbar.close()
    cov = [math.exp(x - x_) for x, x_ in zip(log_nseqs, log_total_nseqs)]
    _log.info(
        "Number of seqs: min {:.2} ({:.2}%) | med {:.2} ({:.2}%) | max {:.2} ({:.2}%)"
        .format(
            math.exp(min(log_nseqs)),
            100 * min(cov),
            math.exp(median(log_nseqs)),
            100 * median(cov),
            math.exp(max(log_nseqs)),
            100 * max(cov),
        ))

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)

        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]

        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
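        # Negative log of the probability mass assigned to label sequences
        # allowed by the PTST mask (masked partition minus full partition),
        # averaged over the sentences in the batch.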
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)

        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def evaluate(state):
        _log.info("Evaluating on train")
        eval_score, loss = run_eval(model,
                                    config.id2label,
                                    samples,
                                    compute_loss=True)
        if eval_score is not None:
            print_accs(eval_score, on="train", run=_run, step=state["n_iters"])
        _log.info("train_ptst_loss: %.4f", loss)
        _run.log_scalar("train_ptst_loss", loss, step=state["n_iters"])

        _log.info("Evaluating on eval")
        eval_score, _ = run_eval(model, config.id2label, eval_samples)
        if eval_score is not None:
            print_accs(eval_score, on="eval", run=_run, step=state["n_iters"])

        state["eval_f1"] = None if eval_score is None else eval_score["f1"]

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    @finetuner.on(Event.FINISHED)
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

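        # Group predictions by their source document so that one Anafora XML
        # file is written per document below.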
        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["word_ids"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples,
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state.get("eval_f1")
Example #3
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

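        # Record the model's own predicted heads and dependency types; they
        # serve as pseudo-labels (self-training targets) for the training set.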
        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"],
                                         runner.state["st_heads"],
                                         runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat[
            "st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
Example #4
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
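    # Halve the learning rate whenever dev LAS (excluding punctuation)
    # plateaus; eval_on_dev below calls scheduler.step(accs["las_nopunct"]).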
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           mode="max",
                                                           factor=0.5)

    trainer = Runner()
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH,
               [batch2tensors(device, vocab),
                set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat[
            "heads"], bat["types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

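        # arc_scores[b, h, d] scores head h for dependent d; flatten so each
        # dependent becomes a row and the candidate heads are the classes for
        # cross-entropy against the gold head indices.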
        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss

        state["loss"] = loss
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        state["n_items"] = bat["mask"].long().sum().item()

    trainer.on(Event.BATCH,
               [update_params(opt),
                log_grads(_run, model),
                log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        scheduler.step(accs["las_nopunct"])

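        # Model selection: compare labeled correct arcs (excluding punctuation)
        # on dev first, breaking ties with unlabeled correct arcs.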
        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False

        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            _log.info("Not better, the best so far is epoch %d:",
                      state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if not state["better"]:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"],
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir,
                            when="better"),
        ],
    )

    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]
Example #5
def finetune(
    corpus,
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    load_samples_from=None,
    overwrite=False,
    load_src=None,
    src_key_as_lang=False,
    main_src=None,
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    save_samples=False,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPTX."""
    if max_length is None:
        max_length = {}
    if load_src is None:
        load_src = {"src": ("artifacts", "model.pth")}
        main_src = "src"
    elif main_src not in load_src:
        raise ValueError(f"{main_src} not found in load_src")

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    if load_samples_from:
        _log.info("Loading samples from %s", load_samples_from)
        with open(load_samples_from, "rb") as f:
            samples = pickle.load(f)
    else:
        samples = {
            wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
            for wh in ["train", "dev", "test"]
        }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    kv = KeyedVectors.load_word2vec_format(word_emb_path)

    if load_samples_from:
        _log.info(
            "Skipping non-main sources because processed samples were loaded")
        srcs = []
    else:
        srcs = [src for src in load_src if src != main_src]
        if src_key_as_lang and corpus["lang"] in srcs:
            _log.info("Removing %s from src parsers because it's the tgt",
                      corpus["lang"])
            srcs.remove(corpus["lang"])
    srcs.append(main_src)
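    # The main source is processed last so that the vocabulary and model left
    # over from the loop (and saved below) belong to the main source.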

    for src_i, src in enumerate(srcs):
        _log.info("Processing src %s [%d/%d]", src, src_i + 1, len(srcs))
        load_from, load_params = load_src[src]
        path = Path(load_from) / "vocab.yml"
        _log.info("Loading %s vocabulary from %s", src, path)
        vocab = load(path.read_text(encoding="utf8"))
        for name in vocab:
            _log.info("Found %d %s", len(vocab[name]), name)

        _log.info("Extending %s vocabulary with target words", src)
        vocab.extend(chain(*samples.values()), ["words"])
        _log.info("Found %d words now", len(vocab["words"]))

        samples_ = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

        path = Path(load_from) / "model.yml"
        _log.info("Loading %s model from metadata %s", src, path)
        model = load(path.read_text(encoding="utf8"))

        path = Path(load_from) / load_params
        _log.info("Loading %s model parameters from %s", src, path)
        model.load_state_dict(torch.load(path, "cpu"))

        _log.info("Creating %s extended word embedding layer", src)
        assert model.word_emb.embedding_dim == kv.vector_size
        with torch.no_grad():
            model.word_emb = torch.nn.Embedding.from_pretrained(
                extend_word_embedding(model.word_emb.weight, vocab["words"],
                                      kv))
        model.to(device)

        for wh in ["train", "dev"]:
            if load_samples_from:
                assert all("pptx_mask" in s for s in samples[wh])
                continue

            for i, s in enumerate(samples_[wh]):
                s["_id"] = i

            runner = Runner()
            runner.state.update({"pptx_masks": [], "_ids": []})
            runner.on(
                Event.BATCH,
                [
                    batch2tensors(device, vocab),
                    set_train_mode(model, training=False),
                    compute_total_arc_type_scores(model, vocab),
                ],
            )

            @runner.on(Event.BATCH)
            def compute_pptx_ambiguous_arcs_mask(state):
                assert state["batch"]["mask"].all()
                scores = state["total_arc_type_scores"]
                pptx_mask = compute_ambiguous_arcs_mask(
                    scores, thresh, projective, multiroot)
                state["pptx_masks"].extend(pptx_mask)
                state["_ids"].extend(state["batch"]["_id"].tolist())
                state["n_items"] = state["batch"]["words"].numel()

            n_toks = sum(len(s["words"]) for s in samples_[wh])
            ProgressBar(total=n_toks, unit="tok").attach_on(runner)

            _log.info(
                "Computing PPTX ambiguous arcs mask for %s set with source %s",
                wh, src)
            with torch.no_grad():
                runner.run(
                    BucketIterator(samples_[wh], lambda s: len(s["words"]),
                                   batch_size))

            assert len(runner.state["pptx_masks"]) == len(samples_[wh])
            assert len(runner.state["_ids"]) == len(samples_[wh])
            for i, pptx_mask in zip(runner.state["_ids"],
                                    runner.state["pptx_masks"]):
                samples_[wh][i]["pptx_mask"] = pptx_mask.tolist()

            _log.info("Computing (log) number of trees stats on %s set", wh)
            report_log_ntrees_stats(samples_[wh], "pptx_mask", batch_size,
                                    projective, multiroot)

            _log.info("Combining the ambiguous arcs mask")
            assert len(samples_[wh]) == len(samples[wh])
            for i in range(len(samples_[wh])):
                pptx_mask = torch.tensor(samples_[wh][i]["pptx_mask"])
                assert pptx_mask.dim() == 3
                if "pptx_mask" in samples[wh][i]:
                    old_mask = torch.tensor(samples[wh][i]["pptx_mask"])
                else:
                    old_mask = torch.zeros(1, 1, 1).bool()
                samples[wh][i]["pptx_mask"] = (old_mask | pptx_mask).tolist()

    assert src == main_src
    _log.info("Main source is %s", src)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    if save_samples:
        path = artifacts_dir / "samples.pkl"
        _log.info("Saving samples to %s", path)
        with open(path, "wb") as f:
            pickle.dump(samples, f)

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    for wh in ["train", "dev"]:
        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "pptx_mask", batch_size,
                                projective, multiroot)

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        pptx_mask = state["batch"]["pptx_mask"].bool()
        scores = state["total_arc_type_scores"]

        pptx_loss = compute_aatrn_loss(scores, pptx_mask, mask, projective,
                                       multiroot)
        pptx_loss /= mask.size(0)
        loss = pptx_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "pptx_loss": pptx_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        pptx_loss = eval_state["mean_pptx_loss"]
        _log.info("dev_pptx_loss: %.4f", pptx_loss)
        _run.log_scalar("dev_pptx_loss", pptx_loss, step=state["n_iters"])

        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model,
                              vocab,
                              samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]