Example #1
# Assumed imports for these evaluate() snippets: Trainer comes from
# pytorch_lightning; BertLabeling is the Lightning module from the MRC-NER project.
from pytorch_lightning import Trainer


def evaluate(ckpt, hparams_file):
    """main"""

    trainer = Trainer(gpus=0)

    model = BertLabeling.load_from_checkpoint(checkpoint_path=ckpt,
                                              hparams_file=hparams_file,
                                              map_location=None,
                                              batch_size=1,
                                              max_length=128,
                                              workers=0)
    trainer.test(model=model)
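
For reference, a minimal command-line wrapper around Example #1; the flag names are assumptions for illustration, not part of the original snippet:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True, help="path to the model .ckpt file")
    parser.add_argument("--hparams_file", required=True, help="path to the hparams.yaml file")
    args = parser.parse_args()
    evaluate(args.ckpt, args.hparams_file)
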
Example #2
def evaluate(ckpt, hparams_file):
    """main"""

    # distributed_backend="ddp" (the pre-1.5 PyTorch Lightning API) runs the
    # test loop with one process per listed GPU
    trainer = Trainer(gpus=[0, 1, 2, 3], distributed_backend="ddp")

    model = BertLabeling.load_from_checkpoint(checkpoint_path=ckpt,
                                              hparams_file=hparams_file,
                                              map_location=None,
                                              batch_size=8,
                                              max_length=128,
                                              workers=40)
    trainer.test(model=model)
Example #3
def evaluate(ckpt, hparams_file):
    """main"""

    trainer = Trainer(gpus=[1])

    model = BertLabeling.load_from_checkpoint(checkpoint_path=ckpt,
                                              hparams_file=hparams_file,
                                              map_location=None,
                                              batch_size=16,
                                              max_length=128,
                                              workers=0)
    dataset_seen, dataset_unseen = get_dataloader_test(
        model.args.tgt_domain, tokenizer=model.tokenizer)
    # test on the unseen-domain split first, then on the seen-domain split
    model.dataset_test = dataset_unseen
    trainer.test(model=model)
    model.dataset_test = dataset_seen
    trainer.test(model=model)
Example #4
# os and torch are needed below; get_parser, BertLabeling, get_dataloader,
# get_query_index_to_label_cate, extract_flat_spans and extract_nested_spans
# come from the MRC-NER project.
import os
import torch


def main():
    parser = get_parser()
    args = parser.parse_args()
    trained_mrc_ner_model = BertLabeling.load_from_checkpoint(
        checkpoint_path=args.model_ckpt,
        hparams_file=args.hparams_file,
        map_location=None,
        batch_size=1,
        max_length=args.max_length,
        workers=0)

    data_loader, data_tokenizer = get_dataloader(args)
    # load the BERT vocab and build an id -> subtoken lookup table
    vocab_path = os.path.join(args.bert_dir, "vocab.txt")
    with open(vocab_path, "r") as f:
        subtokens = [token.strip() for token in f.readlines()]
    idx2tokens = {token_idx: token for token_idx, token in enumerate(subtokens)}

    query2label_dict = get_query_index_to_label_cate(args.dataset_sign)
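    # query2label_dict maps a query/label index to its entity category name,
    # e.g. {0: "ORG", 1: "PER"} (illustrative values only)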

    for batch in data_loader:
        tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch
        attention_mask = (tokens != 0).long()  # token id 0 is BERT's [PAD]

        start_logits, end_logits, span_logits = trained_mrc_ner_model.model(tokens, attention_mask=attention_mask, token_type_ids=token_type_ids)
        start_preds, end_preds, span_preds = start_logits > 0, end_logits > 0, span_logits > 0
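        # note: logit > 0 is equivalent to sigmoid(logit) > 0.5, so these
        # boolean masks are 0.5-probability decisions without the sigmoid call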

        subtokens_idx_lst = tokens.numpy().tolist()[0]  # batch_size is 1, so take the single example
        subtokens_lst = [idx2tokens[item] for item in subtokens_idx_lst]
        label_cate = query2label_dict[label_idx.item()]
        readable_input_str = data_tokenizer.decode(subtokens_idx_lst, skip_special_tokens=True)

        if args.flat_ner:
            entities_info = extract_flat_spans(torch.squeeze(start_preds), torch.squeeze(end_preds),
                                               torch.squeeze(span_preds), torch.squeeze(attention_mask), pseudo_tag=label_cate)
            entity_lst = []

            if len(entities_info) != 0:
                for entity_info in entities_info:
                    start, end = entity_info[0], entity_info[1]
                    entity_string = " ".join(subtokens_lst[start: end])
                    entity_string = entity_string.replace(" ##", "")
                    entity_lst.append((start, end, entity_string, entity_info[2]))

        else:
            match_preds = span_logits > 0  # same thresholding as span_preds above
            entities_info = extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask, pseudo_tag=label_cate)

            entity_lst = []

            if len(entities_info) != 0:
                for entity_info in entities_info:
                    start, end = entity_info[0], entity_info[1]
                    entity_string = " ".join(subtokens_lst[start: end+1 ])
                    entity_string = entity_string.replace(" ##", "")
                    entity_lst.append((start, end+1, entity_string, entity_info[2]))

        print("*="*10)
        print(f"Given input: {readable_input_str}")
        print(f"Model predict: {entity_lst}")
Example #5
    )

    return dataloader



if __name__ == '__main__':

    # checkpoint and hparams for the company-entity model
    CHECKPOINTS = "/root/mao/249/mrc-ner-company/train_logs/zh_msra_company/zh_msra_bertlarge_lr8e-620200913_dropout0.2_bsz16_maxlen128/epoch=19_v0.ckpt"
    HPARAMS = "/root/mao/249/mrc-ner-company/train_logs/zh_msra_company/zh_msra_bertlarge_lr8e-620200913_dropout0.2_bsz16_maxlen128/lightning_logs/version_0/hparams.yaml"

    model = BertLabeling.load_from_checkpoint(
        checkpoint_path=CHECKPOINTS,
        hparams_file=HPARAMS,
        map_location=None,
        batch_size=1,
        max_length=128,
        workers=0
    )


    dataloader = get_dataloader('demo')

    vocab_file = os.path.join(bert_dir, "vocab.txt")  # bert_dir: the BERT model directory, defined in the truncated part of this snippet
    tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file)

    query = '公司指企业的组织形式,社会经济组织'  # MRC query for COMPANY: "a company is the organizational form of an enterprise, a socio-economic organization"

    with torch.no_grad():
        for batch in dataloader:          
            tokens, token_type_ids = batch
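
For reference, a minimal self-contained sketch of how BertWordPieceTokenizer pairs an MRC query with a context passage; the context sentence and vocab path are assumptions for illustration:

from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer(vocab_file="vocab.txt")  # assumed path
query = '公司指企业的组织形式,社会经济组织'
context = '阿里巴巴是一家中国公司'  # hypothetical context sentence
encoding = tokenizer.encode(query, context)
# type_ids are 0 over the query segment and 1 over the context segment,
# matching the token_type_ids unpacked from each batch above
print(encoding.tokens)
print(encoding.type_ids)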