Ejemplo n.º 1
0
def judge_data_quality(opt):
    """Score the stored relation arrays of the train and dev splits.

    Loads ``train/relations.npy`` and ``dev/relations.npy`` from under
    ``opt.npy_data_root``, converts each one with
    ``utils.get_text_spolist`` using the JSON data of the matching split,
    and prints the three values returned by ``eval_file`` for each split
    (labelled ``p``, ``r`` and ``f1`` in the output).

    Args:
        opt: configuration object providing ``npy_data_root``,
            ``train_data_dir`` and ``dev_data_dir``.
    """
    train_relations = np.load(opt.npy_data_root + 'train/relations.npy')
    dev_relations = np.load(opt.npy_data_root + 'dev/relations.npy')

    # Turn each array into a plain list of its first-axis entries.
    train_rows = list(train_relations)
    dev_rows = list(dev_relations)

    train_json = load_data(opt.train_data_dir)
    train_spo = utils.get_text_spolist(opt, train_rows, train_json)
    dev_json = load_data(opt.dev_data_dir)
    dev_spo = utils.get_text_spolist(opt, dev_rows, dev_json)

    print("judge train data..")
    precision, recall, f1 = eval_file(train_spo, opt.train_data_dir)
    print("train_data p:{};r:{};f1:{}".format(precision, recall, f1))

    print("judge dev data..")
    precision, recall, f1 = eval_file(dev_spo, opt.dev_data_dir)
    print("dev_data p:{};r:{};f1:{}".format(precision, recall, f1))
Ejemplo n.º 2
0
def tofile(**kwargs):
    """Run a saved model over one data split and write the results to disk.

    The keyword arguments are passed to ``opt.parse`` to override the global
    configuration. The checkpoint at ``opt.ckpt_path`` is loaded, every batch
    of the selected split is run through the model without gradients, and the
    output of ``utils.get_text_spolist`` for the predictions is written as
    JSON lines to ``out/pred_out``.

    ``opt.case`` selects the JSON file the predictions are aligned with:
    ``0`` uses ``opt.dev_data_dir``, ``1`` uses ``opt.test1_data_dir`` and
    any other value uses ``opt.test2_data_dir``. Only for case ``0`` are
    the predictions scored with ``eval_file``, and the relations shipped
    with the batches are scored as well and written to ``./out/true_out``,
    with the tag sequences going to ``./out/tag_out`` via
    ``utils.write_tags``.

    Args:
        **kwargs: configuration overrides forwarded to ``opt.parse``.
    """
    opt.parse(kwargs)
    if opt.use_gpu:
        torch.cuda.set_device(opt.gpu_id)

    # Build the model named by opt.model and restore its weights.
    model = getattr(models, opt.model)(opt)
    if opt.use_gpu:
        model.cuda()
    print("{} load ckpt from: {}".format(now(), opt.ckpt_path))
    model.load(opt.ckpt_path)
    model.eval()

    # NOTE(review): the dataset is built with ``opt.case + 1``; the meaning
    # of that offset is defined by ``Data`` - confirm against its definition.
    dataset = Data(opt, case=opt.case + 1)
    data_loader = DataLoader(dataset,
                             batch_size=opt.batch_size,
                             shuffle=False,
                             num_workers=opt.num_workers,
                             collate_fn=collate_fn)
    print("predict case:{},data num:{}".format(opt.case, len(dataset)))

    # Only the first line of the tag file is parsed; close the handle promptly.
    with open(opt.tag2id_dir, 'r') as tag_file:
        tag2id = json.loads(tag_file.readline())
    id2tag = {tag_id: tag for tag, tag_id in tag2id.items()}

    # Non-CRF models return per-tag scores that are reduced with an argmax
    # over dim 2; CRF output is collected as returned by the model.
    uses_crf = 'crf' in opt.model.lower()

    steps = (len(dataset) + opt.batch_size - 1) // opt.batch_size
    batches = iter(data_loader)
    p_entRel_t = []
    dev_entRel_t = []
    pred_tags = []
    true_tags = []
    with torch.no_grad():
        for _ in trange(steps):
            batch = next(batches)
            sens, true_tag = (torch.LongTensor(x) for x in batch[:2])
            dev_entRel = batch[-1]
            if opt.use_gpu:
                sens = sens.cuda()
            p_tags, all_out = model(sens, None, None)
            if not uses_crf:
                p_tags = torch.max(p_tags, 2)[1]
                if opt.use_gpu:
                    p_tags = p_tags.cpu()
                p_tags = p_tags.tolist()
            p_entRel_t.extend(all_out)
            dev_entRel_t.extend(dev_entRel)
            true_tags.extend(true_tag.tolist())
            pred_tags.extend(p_tags)

    if opt.case == 0:
        data_path = opt.dev_data_dir
    elif opt.case == 1:
        data_path = opt.test1_data_dir
    else:
        data_path = opt.test2_data_dir
    # Keep only as many JSON records as samples were actually predicted.
    json_data = load_data(data_path)[:len(true_tags)]
    predict_data = utils.get_text_spolist(opt, p_entRel_t, json_data)
    if opt.case == 0:
        p, r, f1 = eval_file(predict_data, opt.dev_data_dir)
        print("predict res: pre:{};rel:{};f1:{}".format(p, r, f1))
    _write_json_lines('out/pred_out', predict_data)

    if opt.case == 0:
        dev_data = utils.get_text_spolist(opt, dev_entRel_t, json_data)
        p, r, f1 = eval_file(dev_data, opt.dev_data_dir)
        print("origin dev res: pre:{};rel:{};f1:{}".format(p, r, f1))
        _write_json_lines('./out/true_out', dev_data)
        utils.write_tags(opt, true_tags, pred_tags, json_data, './out/tag_out',
                         id2tag)


def _write_json_lines(path, records):
    """Write each record to ``path`` as one JSON document per line."""
    import os

    # Create the target directory so a long prediction run does not fail
    # at the very end just because the output folder is missing.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Records are dumped with ensure_ascii=False, so force UTF-8 rather than
    # relying on the platform's default encoding.
    with open(path, 'w', encoding='utf-8') as out_file:
        for record in records:
            out_file.write(json.dumps(record, ensure_ascii=False) + '\n')
Ejemplo n.º 3
0
            g_entRel_t.extend(g_entRel)
            p_entRel_t.extend(all_out)
        # 测试单纯的位置对应准确率
    assert len(g_entRel_t) == len(p_entRel_t)
    p_t, r_t, f_t = f1_score_ent_rel(g_entRel_t, p_entRel_t)
    logging.info("epoch {}; POS: pre: {}; rel: {}; f1: {}".format(
        epoch, p_t, r_t, f_t))

    # 测试实际转换为文字的准确率
    if case == 'dev':
        data_path = opt.dev_data_dir
    else:
        data_path = opt.train_data_dir
    json_data = load_data(data_path)
    # assert len(json_data) == len(p_entRel_t)
    predict_data = utils.get_text_spolist(opt, p_entRel_t, json_data)
    p, r, f = eval_file(predict_data, data_path)
    logging.info("epoch {}; REL: pre:{};rel:{};f1:{}".format(epoch, p, r, f))

    assert len(g_tags) == len(p_tags)
    p, r, f = f1_score(goldens, predicts)
    logging.info("epoch {}; NER: pre: {}; rel: {}; f1: {}".format(
        epoch, p, r, f))


def tofile(**kwargs):
    opt.parse(kwargs)
    if opt.use_gpu:
        torch.cuda.set_device(opt.gpu_id)
    # 2 model
    model = getattr(models, opt.model)(opt)