import logging
import random
from collections import Counter
from itertools import chain

import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm

# Project-specific names used below (CDTB, EDU, TEXT, Sentence, Paragraph,
# node_type_filter, build_pipeline, build_vocab, gen_train_instances, Label,
# numericalize, RNNSegmenterModel, gen_batch_iter, get_lr, edu_eval,
# gen_edu_report, and the project's `eval` module) come from the surrounding
# repository and are not shown in this excerpt.

logger = logging.getLogger(__name__)


def evaluate(args):
    pipeline = build_pipeline(schema=args.schema,
                              segmenter_name=args.segmenter_name,
                              use_gpu=args.use_gpu)
    cdtb = CDTB(args.data,
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir=args.ctb_dir,
                preprocess=True,
                cache_dir=args.cache_dir)
    golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
    parses = []

    if args.use_gold_edu:
        logger.info("evaluation with gold edu segmentation")
    else:
        logger.info("evaluation with auto edu segmentation")

    for para in tqdm(golds, desc="parsing", unit=" para"):
        if args.use_gold_edu:
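            # Gold segmentation: rebuild bare EDUs from the gold tree,
            # carrying over their tokenized words and POS tags.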
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
        else:
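            # Auto segmentation: copy each sentence that contains gold EDUs,
            # attach its CTB syntactic parse, and let the pipeline re-cut EDUs.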
            sentences = []
            for sentence in para.sentences():
                if list(sentence.iterfind(node_type_filter(EDU))):
                    copy_sentence = Sentence([TEXT([sentence.text])])
                    if hasattr(sentence, "words"):
                        setattr(copy_sentence, "words", sentence.words)
                    if hasattr(sentence, "tags"):
                        setattr(copy_sentence, "tags", sentence.tags)
                    setattr(copy_sentence, "parse", cdtb.ctb[sentence.sid])
                    sentences.append(copy_sentence)
            para = pipeline.cut_edu(Paragraph(sentences))
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
        parse = pipeline.parse(Paragraph(edus))
        parses.append(parse)

    # edu score
    scores = edu_eval(golds, parses)
    logger.info("EDU segmentation scores:")
    logger.info(gen_edu_report(scores))

    # parser score
    cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
    logger.info("CDTB macro (strict) scores:")
    logger.info(eval.gen_parse_report(*cdtb_macro_scores))

    # nuclear scores
    nuclear_scores = eval.nuclear_eval(parses, golds)
    logger.info("nuclear scores:")
    logger.info(eval.gen_category_report(nuclear_scores))

    # relation scores
    ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
    logger.info("coarse relation scores:")
    logger.info(eval.gen_category_report(ctype_scores))
    logger.info("fine relation scores:")
    logger.info(eval.gen_category_report(ftype_scores))

    # height eval
    height_scores = eval.height_eval(parses, golds)
    logger.info("structure precision by node height:")
    logger.info(eval.gen_height_report(height_scores))
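
The EDU-copying loop above appears twice, once per branch. A small helper would
remove the duplication; copy_edus is hypothetical, not a function from the
repository:

def copy_edus(paragraph):
    """Rebuild bare EDU nodes from a segmented paragraph, keeping words and POS tags."""
    edus = []
    for edu in paragraph.edus():
        edu_copy = EDU([TEXT(edu.text)])
        edu_copy.words = edu.words
        edu_copy.tags = edu.tags
        edus.append(edu_copy)
    return edus

With it, both branches reduce to edus = copy_edus(para).
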
def main(args):
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    logger.info("args:" + str(args))
    # load dataset
    cdtb = CDTB(args.data,
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir=args.ctb_dir,
                preprocess=True,
                cache_dir=args.cache_dir)
    word_vocab, pos_vocab = build_vocab(cdtb.train)
    instances, tags = gen_train_instances(cdtb.train)
    tag_label = Label("tag", Counter(chain(*tags)))
    trainset = numericalize(instances, tags, word_vocab, pos_vocab, tag_label)

    # build model
    model = RNNSegmenterModel(hidden_size=args.hidden_size,
                              dropout=args.dropout,
                              rnn_layers=args.rnn_layers,
                              word_vocab=word_vocab,
                              pos_vocab=pos_vocab,
                              tag_label=tag_label,
                              pos_size=args.pos_size,
                              pretrained=args.pretrained,
                              w2v_freeze=args.w2v_freeze,
                              use_gpu=args.use_gpu)
    if args.use_gpu:
        model.cuda()
    logger.info(model)

    # train
    step = 0
    best_model_f1 = 0
    wait_count = 0
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.l2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.1,
                                                     patience=3)
    for nepoch in range(1, args.epoch + 1):
        batch_iter = gen_batch_iter(trainset,
                                    args.batch_size,
                                    use_gpu=args.use_gpu)
        for nbatch, (inputs, target) in enumerate(batch_iter, start=1):
            step += 1
            model.train()
            optimizer.zero_grad()
            loss = model.loss(inputs, target)
            loss.backward()
            optimizer.step()
            if nbatch % args.log_every == 0:  # nbatch starts at 1, so no extra guard is needed
                logger.info(
                    "step %d, patient %d, lr %f, epoch %d, batch %d, train loss %.4f"
                    % (step, wait_count, get_lr(optimizer), nepoch, nbatch,
                       loss.item()))
        # model selection
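        # NOTE: this evaluate(dataset, model) is the segmenter-level scorer,
        # not the pipeline-level evaluate(args) shown earlier on this page.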
        score = evaluate(cdtb.validate, model)
        f1 = score[-1]
        scheduler.step(f1, nepoch)
        logger.info("evaluation score:")
        logger.info("\n" + gen_edu_report(score))
        if f1 > best_model_f1:
            wait_count = 0
            best_model_f1 = f1
            logger.info("save new best model to %s" % args.model_save)
            with open(args.model_save, "wb+") as model_fd:
                torch.save(model, model_fd)
            logger.info("test on new best model...")
            test_score = evaluate(cdtb.test, model)
            logger.info("test score:")
            logger.info("\n" + gen_edu_report(test_score))
        else:
            wait_count += 1
            if wait_count > args.patient:
                logger.info("early stopping...")
                break

    with open(args.model_save, "rb") as model_fd:
        best_model = torch.load(model_fd)
    test_score = evaluate(cdtb.test, best_model)
    logger.info("test score on final best model:")
    logger.info("\n" + gen_edu_report(test_score))
Example #3
import logging
from itertools import chain

import torch
import tqdm

# CDTB, EDU, Sentence, Paragraph, node_type_filter, RNNSegmenter, edu_eval and
# gen_edu_report come from the surrounding repository (not shown here).

logger = logging.getLogger("test rnn segmenter")

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    with open("data/models/segmenter.rnn.model", "rb") as model_fd:
        model = torch.load(model_fd, map_location='cpu')
        model.use_gpu = False
        model.eval()
    segmenter = RNNSegmenter(model)
    cdtb = CDTB("data/CDTB",
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir="data/CTB",
                preprocess=True,
                cache_dir="data/cache")

    golds = []
    segs = []
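    # Keep golds and segs aligned: only paragraphs that yield at least one
    # segmentable sentence are scored.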
    for paragraph in tqdm.tqdm(chain(*cdtb.test), desc="segmenting"):
        seged_sents = []
        for sentence in paragraph.sentences():
            # make sure sentence has edus
            if list(sentence.iterfind(node_type_filter(EDU))):
                seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
        if seged_sents:
            segs.append(Paragraph(seged_sents))
            golds.append(paragraph)
    scores = edu_eval(segs, golds)
    logger.info(gen_edu_report(scores))
Example #4
import logging
import pickle
from itertools import chain

import torch
from tqdm import tqdm

# CDTB, EDU, TEXT, Paragraph, node_type_filter, SVMSegmenter, PartPtrParser,
# edu_eval, gen_edu_report and the project's `eval` module come from the
# surrounding repository (not shown here).

logger = logging.getLogger(__name__)


def evaluate(args):
    with open("pub/models/segmenter.svm.model", "rb") as segmenter_fd:
        segmenter_model = pickle.load(segmenter_fd)
    with open("pub/models/treebuilder.partptr.model", "rb") as parser_fd:
        parser_model = torch.load(parser_fd, map_location="cpu")
        parser_model.use_gpu = False
        parser_model.eval()
    segmenter = SVMSegmenter(segmenter_model)
    parser = PartPtrParser(parser_model)

    cdtb = CDTB(args.data,
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir=args.ctb_dir,
                preprocess=True,
                cache_dir=args.cache_dir)
    golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
    parses = []

    if args.use_gold_edu:
        logger.info("evaluation with gold edu segmentation")
    else:
        logger.info("evaluation with auto edu segmentation")

    for para in tqdm(golds, desc="parsing", unit=" para"):
        if args.use_gold_edu:
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
            parse = parser.parse(Paragraph(edus))
            parses.append(parse)
        else:
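            # Auto segmentation: attach each sentence's CTB parse, then let
            # the SVM segmenter cut EDUs before parsing.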
            edus = []
            for sentence in para.sentences():
                if list(sentence.iterfind(node_type_filter(EDU))):
                    setattr(sentence, "parse", cdtb.ctb[sentence.sid])
                    edus.extend(segmenter.cut_edu(sentence))
            parse = parser.parse(Paragraph(edus))
            parses.append(parse)

    # edu score
    scores = edu_eval(golds, parses)
    logger.info("EDU segmentation scores:")
    logger.info(gen_edu_report(scores))

    # parser score
    cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
    logger.info("CDTB macro (strict) scores:")
    logger.info(eval.gen_parse_report(*cdtb_macro_scores))

    # nuclear scores
    nuclear_scores = eval.nuclear_eval(parses, golds)
    logger.info("nuclear scores:")
    logger.info(eval.gen_category_report(nuclear_scores))

    # relation scores
    ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
    logger.info("coarse relation scores:")
    logger.info(eval.gen_category_report(ctype_scores))
    logger.info("fine relation scores:")
    logger.info(eval.gen_category_report(ftype_scores))
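
Both evaluate(args) functions on this page expect an argparse-style namespace.
A minimal sketch of the wiring for this example; the flag names follow the
attribute accesses above, and the defaults reuse the paths from Example #3
(both are assumptions, not the project's actual CLI):

import argparse

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="data/CDTB", help="CDTB corpus directory")
    parser.add_argument("--ctb_dir", default="data/CTB", help="CTB treebank directory")
    parser.add_argument("--cache_dir", default="data/cache")
    parser.add_argument("--use_gold_edu", action="store_true",
                        help="evaluate on gold EDU segmentation instead of segmenter output")
    evaluate(parser.parse_args())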