Code Example #1
def run_test(test_path):
    test_data = ConllxDataLoader().load(test_path, return_dataset=True)

    with open("model_pp_0117.pkl", "rb") as f:
        save_dict = torch.load(f)
    tag_vocab = save_dict["tag_vocab"]
    pipeline = save_dict["pipeline"]
    # prepend an indexer that maps the gold "tag" field to ids under "truth"
    index_tag = IndexerProcessor(vocab=tag_vocab,
                                 field_name="tag",
                                 new_added_field_name="truth",
                                 is_input=False)
    pipeline.pipeline = [index_tag] + pipeline.pipeline

    pipeline(test_data)
    test_data.set_target("truth")
    prediction = test_data.field_arrays["predict"].content
    truth = test_data.field_arrays["truth"].content
    seq_len = test_data.field_arrays["word_seq_origin_len"].content

    # pad predictions and gold tags by hand so they stack into equal-length tensors
    max_length = max([len(seq) for seq in prediction])
    for idx in range(len(prediction)):
        prediction[idx] = list(
            prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
        truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
    evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab,
                                  pred="predict",
                                  target="truth",
                                  seq_lens="word_seq_origin_len")
    # the metric is called with a prediction dict and a target dict
    evaluator(
        {
            "predict": torch.Tensor(prediction),
            "word_seq_origin_len": torch.Tensor(seq_len)
        }, {"truth": torch.Tensor(truth)})
    test_result = evaluator.get_metric()
    f1 = round(test_result['f'] * 100, 2)
    pre = round(test_result['pre'] * 100, 2)
    rec = round(test_result['rec'] * 100, 2)

    return {"F1": f1, "precision": pre, "recall": rec}
Code Example #2
File: run.py  Project: wlhgtc/fastNLP
def load(path):
    data = ConllxDataLoader().load(path)
    return convert(data)
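
A sketch of plausible usage, assuming convert() (defined elsewhere in run.py)
turns the raw loader output into a fastNLP DataSet; the path is hypothetical:

# Hypothetical usage of load(); the training-data path is an assumption.
train_set = load("data/train_sample.conllx")
print(len(train_set))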
Code Example #3
File: test_dataset_loader.py  Project: wlhgtc/fastNLP
def test_ConllxDataLoader(self):
    dataset = ConllxDataLoader().load(
        "test/data_for_tests/zh_sample.conllx")
Code Example #4
def train(train_data_path, dev_data_path, checkpoint=None, save=None):
    # load config; cfgfile is a module-level config path defined elsewhere
    # in the original script
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {
        "train": train_param,
        "model": model_param
    })
    print("config loaded")

    # Data Loader
    print("loading training set...")
    dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
    print("loading dev set...")
    dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
    print(dataset)
    print("================= dataset ready =====================")

    dataset.rename_field("tag", "truth")
    dev_data.rename_field("tag", "truth")

    vocab_proc = VocabIndexerProcessor("words",
                                       new_added_filed_name="word_seq")
    tag_proc = VocabIndexerProcessor("truth", is_input=True)
    seq_len_proc = SeqLenProcessor(field_name="word_seq",
                                   new_added_field_name="word_seq_origin_len",
                                   is_input=True)
    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    # index dev set
    word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab
    dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]],
                   new_field_name="word_seq")
    dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]],
                   new_field_name="truth")
    dev_data.apply(lambda ins: len(ins["word_seq"]),
                   new_field_name="word_seq_origin_len")

    # set input & target
    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dev_data.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")
    dev_data.set_target("truth", "word_seq_origin_len")

    # dataset.set_is_target(tag_ids=True)
    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={}  num_classes={}".format(model_param["vocab_size"],
                                                 model_param["num_classes"]))

    # define a model
    if checkpoint is None:
        # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx)
        pre_trained = None
        model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained)
        print(model)
    else:
        model = torch.load(checkpoint)

    # call trainer to train
    trainer = Trainer(dataset,
                      model,
                      loss=None,
                      metrics=SpanFPreRecMetric(
                          tag_proc.vocab,
                          pred="predict",
                          target="truth",
                          seq_lens="word_seq_origin_len"),
                      dev_data=dev_data,
                      metric_key="f",
                      use_tqdm=True,
                      use_cuda=True,
                      print_every=10,
                      n_epochs=20,
                      save_path=save)
    trainer.train(load_best_model=True)

    # save model & pipeline
    model_proc = ModelProcessor(model,
                                seq_len_field_name="word_seq_origin_len")
    id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

    pp = Pipeline(
        [vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    torch.save(save_dict, os.path.join(save, "model_pp.pkl"))
    print("pipeline saved")
Code Example #5
parser.add_argument('--pipe', type=str, default='')
parser.add_argument('--gold_data', type=str, default='')
parser.add_argument('--new_data', type=str)
args = parser.parse_args()

# redirect the saved pipeline's input fields to the gold segmentation fields
pipe = torch.load(args.pipe)['pipeline']
for p in pipe:
    if p.field_name == 'word_list':
        print(p.field_name)
        p.field_name = 'gold_words'
    elif p.field_name == 'pos_list':
        print(p.field_name)
        p.field_name = 'gold_pos'


data = ConllxDataLoader().load(args.gold_data)
ds = DataSet()
for ins1, ins2 in zip(add_seg_tag(data), data):
    ds.append(Instance(words=ins1[0], tag=ins1[1],
                       gold_words=ins2[0], gold_pos=ins2[1],
                       gold_heads=ins2[2], gold_head_tags=ins2[3]))

ds = pipe(ds)

seg_threshold = 0.
pos_threshold = 0.
parse_threshold = 0.74


def get_heads(ins, head_f, word_f):
    head_pred = []