Example #1
def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
            predict_batch(projective, multiroot),
            evaluate_batch(),
            get_n_items(),
        ],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
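
A minimal usage sketch (hypothetical names: a trained model, its vocab, and dev_samples as dicts with a "words" field; "counts" is aggregated by the SumReducer attached above):

state = run_eval(model, vocab, dev_samples, device="cuda", batch_size=16)
accs = state["counts"].accs  # hypothetical read; e.g. accs["las_nopunct"], as in Example #5
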
Example #2
def report_log_ntrees_stats(
    samples: Sequence[dict],
    aa_mask_field: str,
    batch_size: int = 1,
    projective: bool = False,
    multiroot: bool = False,
) -> None:
    log_ntrees: list = []
    pbar = tqdm(total=sum(len(s["words"]) for s in samples), leave=False)
    for batch in BucketIterator(samples, lambda s: len(s["words"]),
                                batch_size):
        arr = batch.to_array()
        aaet_mask = torch.from_numpy(arr[aa_mask_field]).bool()
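        # arcs kept in the chart score 0 and pruned arcs -1e9 (~ -inf), so the
        # log-partition of this CRF is the log of the number of trees in the chart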
        cnt_scores = torch.zeros_like(aaet_mask).float().masked_fill(
            ~aaet_mask, -1e9)
        log_ntrees.extend(
            DepTreeCRF(cnt_scores, projective=projective,
                       multiroot=multiroot).log_partitions().tolist())
        pbar.update(arr["words"].size)
    pbar.close()
    logger.info(
        "Log number of trees: min %.2f | q1 %.2f | q2 %.2f | q3 %.2f | max %.2f",
        np.min(log_ntrees),
        np.quantile(log_ntrees, 0.25),
        np.quantile(log_ntrees, 0.5),
        np.quantile(log_ntrees, 0.75),
        np.max(log_ntrees),
    )
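
As called in Example #5 below, once each sample carries the boolean mask field:

report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size, projective, multiroot)
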
Example #3
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        ppt_loss = compute_aatrn_loss(
            state["total_arc_type_scores"],
            state["batch"]["ppt_mask"].bool(),
            projective=projective,
            multiroot=multiroot,
        )
        state["ppt_loss"] = ppt_loss.item()
        state["size"] = state["batch"]["words"].size(0)

    runner.on(Event.BATCH, [
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items()
    ])

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ppt_loss", value="ppt_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
Example #4
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
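            # strip the corpus root prefix to recover the document-relative path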
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )
Example #5
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train", "dev"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective,
                                                   multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, ppt_mask in zip(runner.state["_ids"],
                               runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size,
                                projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
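    # snapshot the pre-finetuning parameters; compute_l2_loss penalizes drift
    # from this snapshot, scaled by l2_coef in compute_loss below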
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]

        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective,
                                      multiroot)
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])

        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model,
                              vocab,
                              samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # group sentences into length buckets of width 10
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
Example #6
def run_eval(
    model,
    id2label,
    samples,
    corpus,
    _log,
    device="cpu",
    batch_size=32,
    gold_path="",
    compute_loss=False,
    confusion=False,
):
    if not gold_path and not compute_loss:
        _log.info(
            "Skipping evaluation since gold data isn't provided and loss isn't required"
        )
        return None, None

    runner = Runner()
    runner.state.update({"preds": [], "_ids": []})

    @runner.on(Event.BATCH)
    def maybe_compute_prediction(state):
        if not gold_path:
            return

        arr = state["batch"].to_array()
        state["arr"] = arr
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)

        model.eval()
        scores = model(words)
        preds = LinearCRF(scores).argmax()

        state["preds"].extend(preds.tolist())
        state["_ids"].extend(arr["_id"].tolist())
        if compute_loss:
            state["scores"] = scores

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        arr = state["arr"] if "arr" in state else state["batch"].to_array()
        state["arr"] = arr
        if "scores" in state:
            scores = state["scores"]
        else:
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            model.eval()
            scores = model(words)

        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        # PTST loss: negative log of the probability mass that falls inside the
        # pruned chart, i.e. log Z(full) - log Z(masked)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        crf = LinearCRF(masked_scores)
        crf_z = LinearCRF(scores)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        state["ptst_loss"] = ptst_loss.item()
        state["size"] = mask.size(0)

    @runner.on(Event.BATCH)
    def set_n_items(state):
        state["n_items"] = int(state["arr"]["mask"].sum())

    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ptst_loss", value="ptst_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["word_ids"]), batch_size))

    if runner.state["preds"]:
        assert len(runner.state["preds"]) == len(samples)
        assert len(runner.state["_ids"]) == len(samples)
        for i, preds in zip(runner.state["_ids"], runner.state["preds"]):
            samples[i]["preds"] = preds

    if gold_path:
        group = defaultdict(list)
        for s in samples:
            group[str(s["path"])].append(s)

        with tempfile.TemporaryDirectory() as dirname:
            dirname = Path(dirname)
            for doc_path, doc_samples in group.items():
                spans = [x for s in doc_samples for x in s["spans"]]
                labels = [id2label[x] for s in doc_samples for x in s["preds"]]
                doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
                data = make_anafora(spans, labels, doc_path.name)
                (dirname / doc_path.parent).mkdir(parents=True, exist_ok=True)
                data.to_file(f"{str(dirname / doc_path)}.xml")
            return (
                score_time(gold_path, str(dirname), confusion),
                runner.state.get("mean_ptst_loss"),
            )
    return None, runner.state.get("mean_ptst_loss")
Example #7
def finetune(
    _log,
    _run,
    _rnd,
    corpus,
    artifacts_dir="artifacts",
    overwrite=False,
    temperature=1.0,
    freeze_embeddings=True,
    freeze_encoder_up_to=1,
    device="cpu",
    thresh=0.95,
    batch_size=16,
    lr=1e-5,
    max_epoch=5,
    predict_on_finished=False,
):
    """Finetune/train the source model on unlabeled target data."""
    artifacts_dir = Path(artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = read_samples_()
    eval_samples = read_samples_(max_length=None)
    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
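    # BIO constraints: an I-X tag may not start a sequence and may only follow
    # B-X or I-X with the same type X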
    for lid, label in config.id2label.items():
        if not label.startswith("I-"):
            continue

        with torch.no_grad():
            model.start_transition[lid] = -1e9
        for plid, plabel in config.id2label.items():
            if plabel == "O" or plabel[2:] != label[2:]:
                with torch.no_grad():
                    model.transition[plid, lid] = -1e9

    for name, p in model.named_parameters():
        freeze = False
        if freeze_embeddings and ".embeddings." in name:
            freeze = True
        if freeze_encoder_up_to >= 0:
            for i in range(freeze_encoder_up_to + 1):
                if f".encoder.layer.{i}." in name:
                    freeze = True
        if freeze:
            _log.info("Freezing %s", name)
            p.requires_grad_(False)

    model.to(device)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), unit="tok")
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        with torch.no_grad():
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
        ptst_masks.extend(ptst_mask.tolist())
        _ids.extend(arr["_id"].tolist())
        pbar.update(int(arr["mask"].sum()))
    pbar.close()

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Report number of sequences")
    log_total_nseqs, log_nseqs = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), leave=False)
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        cnt_scores = torch.zeros_like(ptst_mask).float()
        cnt_scores_masked = cnt_scores.masked_fill(~ptst_mask, -1e9)
        log_total_nseqs.extend(LinearCRF(cnt_scores).log_partitions().tolist())
        log_nseqs.extend(
            LinearCRF(cnt_scores_masked).log_partitions().tolist())
        pbar.update(arr["word_ids"].size)
    pbar.close()
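    # coverage = (#sequences in the pruned chart) / (#sequences in the full chart)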
    cov = [math.exp(x - x_) for x, x_ in zip(log_nseqs, log_total_nseqs)]
    _log.info(
        "Number of seqs: min {:.2} ({:.2}%) | med {:.2} ({:.2}%) | max {:.2} ({:.2}%)"
        .format(
            math.exp(min(log_nseqs)),
            100 * min(cov),
            math.exp(median(log_nseqs)),
            100 * median(cov),
            math.exp(max(log_nseqs)),
            100 * max(cov),
        ))

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)

        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]

        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)

        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def evaluate(state):
        _log.info("Evaluating on train")
        eval_score, loss = run_eval(model,
                                    config.id2label,
                                    samples,
                                    compute_loss=True)
        if eval_score is not None:
            print_accs(eval_score, on="train", run=_run, step=state["n_iters"])
        _log.info("train_ptst_loss: %.4f", loss)
        _run.log_scalar("train_ptst_loss", loss, step=state["n_iters"])

        _log.info("Evaluating on eval")
        eval_score, _ = run_eval(model, config.id2label, eval_samples)
        if eval_score is not None:
            print_accs(eval_score, on="eval", run=_run, step=state["n_iters"])

        state["eval_f1"] = None if eval_score is None else eval_score["f1"]

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    @finetuner.on(Event.FINISHED)
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["word_ids"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples,
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state.get("eval_f1")
Example #8
def report_coverage(corpus,
                    _log,
                    temperature=1.0,
                    device="cpu",
                    batch_size=16,
                    thresh=0.95,
                    gold_path=""):
    """Report coverage of gold tags in the chart."""
    samples = read_samples_()
    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
    for lid, label in config.id2label.items():
        if not label.startswith("I-"):
            continue

        with torch.no_grad():
            model.start_transition[lid] = -1e9
        for plid, plabel in config.id2label.items():
            if plabel == "O" or plabel[2:] != label[2:]:
                with torch.no_grad():
                    model.transition[plid, lid] = -1e9

    model.to(device)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), unit="tok")
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        with torch.no_grad():
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
        ptst_masks.extend(ptst_mask.tolist())
        _ids.extend(arr["_id"].tolist())
        pbar.update(int(arr["mask"].sum()))
    pbar.close()

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Reporting coverage of gold labels")

    group = defaultdict(list)
    for s in samples:
        k = str(s["path"])[len(f"{corpus['path']}/"):]
        group[k].append(s)

    n_cov_tp, n_total_tp, n_cov_ts, n_total_ts = 0, 0, 0, 0
    for dirpath, _, filenames in os.walk(gold_path):
        if not filenames:
            continue
        if len(filenames) > 1:
            raise ValueError(f"more than 1 file is found in {dirpath}")
        if not filenames[0].endswith(".TimeNorm.gold.completed.xml"):
            raise ValueError(
                f"{filenames[0]} doesn't have the expected suffix")

        doc_path = os.path.join(dirpath, filenames[0])
        data = AnaforaData.from_file(doc_path)
        prefix, suffix = f"{gold_path}/", ".TimeNorm.gold.completed.xml"
        doc_path = doc_path[len(prefix):-len(suffix)]
        tok_spans = [p for s in group[doc_path] for p in s["spans"]]
        tok_spans.sort()

        labeling = {}
        for ann in data.annotations:
            if len(ann.spans) != 1:
                raise ValueError("found annotation with >1 span")
            span = ann.spans[0]
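            # align the annotation's character span to token boundaries, then
            # label the matched tokens with B-/I- tags of the annotation type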
            beg = 0
            while beg < len(tok_spans) and tok_spans[beg][0] < span[0]:
                beg += 1
            end = beg
            while end < len(tok_spans) and tok_spans[end][1] < span[1]:
                end += 1
            if (beg < len(tok_spans) and end < len(tok_spans)
                    and tok_spans[beg][0] == span[0]
                    and tok_spans[end][1] == span[1] and beg not in labeling):
                labeling[beg] = f"B-{ann.type}"
                for i in range(beg + 1, end + 1):
                    if i not in labeling:
                        labeling[i] = f"I-{ann.type}"

        labels = ["O"] * len(tok_spans)
        for k, v in labeling.items():
            labels[k] = v

        offset = 0
        for s in group[doc_path]:
            ts_covd = True
            for i in range(1, len(s["spans"])):
                plab = labels[offset + i - 1]
                lab = labels[offset + i]
                if s["ptst_mask"][i - 1][config.label2id[plab]][
                        config.label2id[lab]]:
                    n_cov_tp += 1
                else:
                    ts_covd = False
                n_total_tp += 1
            if ts_covd:
                n_cov_ts += 1
            n_total_ts += 1
            offset += len(s["spans"])

    _log.info(
        "Number of covered tag pairs: %d out of %d (%.1f%%)",
        n_cov_tp,
        n_total_tp,
        100.0 * n_cov_tp / n_total_tp,
    )
    _log.info(
        "Number of covered tag sequences: %d out of %d (%.1f%%)",
        n_cov_ts,
        n_total_ts,
        100.0 * n_cov_ts / n_total_ts,
    )
Example #9
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           mode="max",
                                                           factor=0.5)

    trainer = Runner()
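    # best dev counts so far; start at -1 so the first evaluation counts as better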
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH,
               [batch2tensors(device, vocab),
                set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags = bat["words"], bat["tags"]
        heads, types = bat["heads"], bat["types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss

        state["loss"] = loss
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        state["n_items"] = bat["mask"].long().sum().item()

    trainer.on(Event.BATCH,
               [update_params(opt),
                log_grads(_run, model),
                log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        scheduler.step(accs["las_nopunct"])

        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False

        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            _log.info("Not better, the best so far is epoch %d:",
                      state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if not state["better"]:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"],
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir,
                            when="better"),
        ],
    )

    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]
Example #10
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"],
                                         runner.state["st_heads"],
                                         runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags = bat["words"], bat["tags"]
        heads, types = bat["st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores,
                                                     heads,
                                                     reduction="none")

        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores,
                                                      types,
                                                      reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs,
                   on="test",
                   run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"],
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]