Example #1
def compute_ambiguous_tag_pairs_mask(scores: Tensor,
                                     threshold: float = 0.95) -> BoolTensor:
    bsz, slen, n_next_tags, n_tags = scores.shape

    crf = LinearCRF(scores)
    margs = crf.marginals()

    # select high-probability tag pairs until their cumulative probability exceeds the threshold
    margs = rearrange(margs,
                      "bsz slen nntags ntags -> bsz slen (nntags ntags)")
    margs, orig_indices = margs.sort(dim=2, descending=True)
    tp_mask = margs.cumsum(dim=2) < threshold

    # also include the tag pair that pushes the cumulative sum over the threshold
    last_idx = tp_mask.long().sum(
        dim=2, keepdim=True).clamp(max=n_next_tags * n_tags - 1)
    tp_mask = tp_mask.scatter(2, last_idx, True)

    # restore the order and shape
    _, restore_indices = orig_indices.sort(dim=2)
    tp_mask = tp_mask.gather(2, restore_indices)

    # ensure best tag sequence is selected
    best_tags = crf.argmax()
    assert best_tags.shape == (bsz, slen + 1)
    best_idx = best_tags[:, 1:] * n_tags + best_tags[:, :-1]
    assert best_idx.shape == (bsz, slen)
    tp_mask = tp_mask.scatter(2, best_idx.unsqueeze(2), True)

    tp_mask = rearrange(tp_mask,
                        "bsz slen (nntags ntags) -> bsz slen nntags ntags",
                        nntags=n_next_tags)
    return tp_mask  # type: ignore
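
A minimal usage sketch for the function above. It assumes the same LinearCRF interface used in the body (scores shaped (bsz, slen, n_next_tags, n_tags), with marginals() and argmax()) and that LinearCRF and einops' rearrange are in scope; the shapes and threshold below are illustrative only.

import torch

bsz, slen, n_tags = 2, 10, 5
scores = torch.randn(bsz, slen, n_tags, n_tags)  # toy tag-pair scores
tp_mask = compute_ambiguous_tag_pairs_mask(scores, threshold=0.95)
assert tp_mask.shape == scores.shape
# every position keeps at least the tag pair that crosses the threshold
# plus the Viterbi pair, so no position is left fully masked
assert tp_mask.reshape(bsz, slen, -1).any(dim=2).all()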
Example #2
def test_file(model_path, test_file_path):
    """Test model

    test file format
    今   B
    晚   E
    月   B
    色   E
    真   S
    美   S
    。   S
    <sentences MUST be SEPARATED by a BLANK LINE>
    我   S

    output file format
    今   B   preTag
    晚   E   preTag
    """
    if not os.path.isfile(model_path) or not os.path.isfile(test_file_path):
        print("File don't exist!")
    model = LinearCRF()
    model.load(model_path)

    f = codecs.open(test_file_path, 'r', encoding='utf-8')
    lines = f.readlines()
    f.close()

    sentences = []
    labels = []
    sentence = []
    label = []
    for line in lines:
        if len(line) < 2:
            # sentence end
            sentences.append(sentence)
            labels.append(label)
            sentence = []
            label = []
        else:
            char, tag = line.split()
            sentence.append(char)
            label.append(tag)
    # flush the last sentence in case the file doesn't end with a blank line
    if sentence:
        sentences.append(sentence)
        labels.append(label)

    pre_tags = [model.inference_viterbi(sen) for sen in sentences]

    with open('test_result.txt', 'w+') as f:
        for sen, sen_tag, sen_pre in zip(sentences, labels, pre_tags):
            for i in range(len(sen)):
                f.write('{}\t{}\t{}\n'.format(sen[i], sen_tag[i], sen_pre[i]))
            f.write('\n')

    print('Test finished!')
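
A hedged usage sketch; both paths are placeholders, and the test file must follow the char/tag format described in the docstring, with a blank line between sentences.

test_file('model/linear_crf.model', 'data/pku_test.data')
# writes test_result.txt with one "char<TAB>gold_tag<TAB>predicted_tag"
# line per character and a blank line between sentences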
Example #3
class Segmentation(object):
    def __init__(self, model_path='model/linear_crf.model'):
        self.model = LinearCRF()
        self.model.load(model_path)

    def seg(self, sentence):
        sentence = sentence.strip()
        tags = self.model.inference_viterbi(sentence)

        str_seg = ""
        for word, tag in zip(sentence, tags):
            str_seg += word
            if tag == 'S' or tag == 'E':
                str_seg += ' '
        result = str_seg.split()
        return result
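
Under the B/E/S tagging scheme implied by example #2, seg() appends a space after every 'S' or 'E' tag and then splits on whitespace. A hedged usage sketch (the model path is a placeholder, and the tags in the comment assume the model predicts the sequence shown in example #2's docstring):

seg = Segmentation('model/linear_crf.model')
print(seg.seg('今晚月色真美。'))
# with predicted tags B E B E S S S this yields
# ['今晚', '月色', '真', '美', '。']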
Example #4
    def maybe_compute_prediction(state):
        if not gold_path:
            return

        arr = state["batch"].to_array()
        state["arr"] = arr
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)

        model.eval()
        scores = model(words)
        preds = LinearCRF(scores).argmax()

        state["preds"].extend(preds.tolist())
        state["_ids"].extend(arr["_id"].tolist())
        if compute_loss:
            state["scores"] = scores
Example #5
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )
Example #6
    def compute_loss(state):
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)

        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]

        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)

        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()
Example #7
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        arr = state["arr"] if "arr" in state else state["batch"].to_array()
        state["arr"] = arr
        if "scores" in state:
            scores = state["scores"]
        else:
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            model.eval()
            scores = model(words)

        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        crf = LinearCRF(masked_scores)
        crf_z = LinearCRF(scores)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        state["ptst_loss"] = ptst_loss.item()
        state["size"] = mask.size(0)
Example #8
    def __init__(self, model_path='model/linear_crf.model'):
        self.model = LinearCRF()
        self.model.load(model_path)
Example #9
def finetune(
    _log,
    _run,
    _rnd,
    corpus,
    artifacts_dir="artifacts",
    overwrite=False,
    temperature=1.0,
    freeze_embeddings=True,
    freeze_encoder_up_to=1,
    device="cpu",
    thresh=0.95,
    batch_size=16,
    lr=1e-5,
    max_epoch=5,
    predict_on_finished=False,
):
    """Finetune/train the source model on unlabeled target data."""
    artifacts_dir = Path(artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = read_samples_()
    eval_samples = read_samples_(max_length=None)
    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
    for lid, label in config.id2label.items():
        if not label.startswith("I-"):
            continue

        with torch.no_grad():
            model.start_transition[lid] = -1e9
        for plid, plabel in config.id2label.items():
            if plabel == "O" or plabel[2:] != label[2:]:
                with torch.no_grad():
                    model.transition[plid, lid] = -1e9

    for name, p in model.named_parameters():
        freeze = False
        if freeze_embeddings and ".embeddings." in name:
            freeze = True
        if freeze_encoder_up_to >= 0:
            for i in range(freeze_encoder_up_to + 1):
                if f".encoder.layer.{i}." in name:
                    freeze = True
        if freeze:
            _log.info("Freezing %s", name)
            p.requires_grad_(False)

    model.to(device)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), unit="tok")
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        with torch.no_grad():
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
        ptst_masks.extend(ptst_mask.tolist())
        _ids.extend(arr["_id"].tolist())
        pbar.update(int(arr["mask"].sum()))
    pbar.close()

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Report number of sequences")
    log_total_nseqs, log_nseqs = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), leave=False)
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        cnt_scores = torch.zeros_like(ptst_mask).float()
        cnt_scores_masked = cnt_scores.masked_fill(~ptst_mask, -1e9)
        log_total_nseqs.extend(LinearCRF(cnt_scores).log_partitions().tolist())
        log_nseqs.extend(
            LinearCRF(cnt_scores_masked).log_partitions().tolist())
        pbar.update(arr["word_ids"].size)
    pbar.close()
    cov = [math.exp(x - x_) for x, x_ in zip(log_nseqs, log_total_nseqs)]
    _log.info(
        "Number of seqs: min {:.2} ({:.2}%) | med {:.2} ({:.2}%) | max {:.2} ({:.2}%)"
        .format(
            math.exp(min(log_nseqs)),
            100 * min(cov),
            math.exp(median(log_nseqs)),
            100 * median(cov),
            math.exp(max(log_nseqs)),
            100 * max(cov),
        ))

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)

        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]

        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)

        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def evaluate(state):
        _log.info("Evaluating on train")
        eval_score, loss = run_eval(model,
                                    config.id2label,
                                    samples,
                                    compute_loss=True)
        if eval_score is not None:
            print_accs(eval_score, on="train", run=_run, step=state["n_iters"])
        _log.info("train_ptst_loss: %.4f", loss)
        _run.log_scalar("train_ptst_loss", loss, step=state["n_iters"])

        _log.info("Evaluating on eval")
        eval_score, _ = run_eval(model, config.id2label, eval_samples)
        if eval_score is not None:
            print_accs(eval_score, on="eval", run=_run, step=state["n_iters"])

        state["eval_f1"] = None if eval_score is None else eval_score["f1"]

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    @finetuner.on(Event.FINISHED)
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["word_ids"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples,
                       bucket_key,
                       batch_size,
                       shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state.get("eval_f1")
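
The "Report number of sequences" block works because, with all-zero scores, every tag sequence gets weight exp(0) = 1, so log_partitions() is simply the log-count of sequences; masking with -1e9 then (approximately) counts only the sequences allowed by ptst_mask, and cov is the surviving fraction. A toy illustration, assuming the LinearCRF interface implied by example #1 (slen tag-pair positions scoring sequences of slen + 1 tags):

import torch

n_tags, slen = 3, 4
cnt_scores = torch.zeros(1, slen, n_tags, n_tags)  # every sequence has weight 1
log_total = LinearCRF(cnt_scores).log_partitions()
print(torch.exp(log_total))  # ~ n_tags ** (slen + 1) = 243 unconstrained sequences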
Example #10
                                             smoothness_gold.std())

        y_all_gold = np.concatenate(list(zip(*clf_results))[0])
        y_all_predict = np.concatenate(list(zip(*clf_results))[1])

        # print(classification_report(y_all_gold, y_all_predict, target_names=labels))
        print(confusion_matrix_report(y_all_gold, y_all_predict, labels))
        print(confusion_matrix(y_all_gold, y_all_predict))

    crf_classifiers = {
        "CRF": {
            'clf':
            LinearCRF(feature_names=feature_names,
                      label_names=labels,
                      addone=True,
                      regularization="l2",
                      lmbd=0.01,
                      sigma=100,
                      transition_weighting=False),
            'structured':
            True
        },
        "CRF transition weights": {
            'clf':
            LinearCRF(feature_names=feature_names,
                      label_names=labels,
                      addone=True,
                      regularization="l2",
                      lmbd=0.01,
                      sigma=100,
                      transition_weighting=True),
Example #11
def main():
    print("main start")
    model = LinearCRF()
    model.train('../data/pku_training.data')
Example #12
def main():
    print("main start")
    model = LinearCRF()
    model.train('../data/train.data')