# Example 1
 def cut_sent(self, text, sid=None):
     """Split *text* into Sentence objects at end-of-sentence characters.

     Args:
         text: raw string to segment.
         sid: unused here; kept for interface compatibility with callers.

     Returns:
         list[Sentence]: one Sentence per segment, each wrapping a TEXT node.
     """
     last_cut = 0
     sentences = []
     # Scan every character; any character in self._eos closes a sentence.
     # (The original looped over range(len(text) - 1), skipping the final
     # character; scanning the full string is equivalent for inner splits
     # and handles a terminal EOS character directly.)
     for i in range(len(text)):
         if text[i] in self._eos:
             sentences.append(Sentence([TEXT(text[last_cut:i + 1])]))
             last_cut = i + 1
     # Fix: the original tail condition (last_cut < len(text) - 1) silently
     # dropped a single trailing character (e.g. "a.b" lost "b"); include
     # any remaining text as the final sentence.
     if last_cut < len(text):
         sentences.append(Sentence([TEXT(text[last_cut:])]))
     return sentences
    def cut_edu(self, sent):
        """Segment *sent* into EDUs with the GCN-based segmentation model.

        Lazily fills in ``sent.words`` / ``sent.tags`` (from a constituency
        parse) and ``sent.dependency`` (from the dependency parser) when the
        caller has not provided them, then labels each token and cuts the
        sentence after every token labelled "B".

        Returns:
            list[EDU]: the EDUs, each carrying its own words and tags.
        """
        if (not hasattr(sent, "words")) or (not hasattr(sent, "tags")):
            # Prefer a pre-attached parse; otherwise parse the raw text.
            if hasattr(sent, "parse"):
                parse = getattr(sent, "parse")
            else:
                parse = self.parser.parse(sent.text)
            # Height-2 subtrees are the (POS word) preterminals; '-NONE-'
            # (empty elements) are skipped.
            children = list(
                parse.subtrees(
                    lambda t: t.height() == 2 and t.label() != '-NONE-'))
            setattr(sent, "words", [child[0] for child in children])
            setattr(sent, "tags", [child.label() for child in children])

        if not hasattr(sent, "dependency"):
            dep = self.dep_parser.parse(sent.words)
            setattr(sent, "dependency", dep)

        # Encode tokens and POS tags as 1 x seq_len id tensors.
        word_ids = [self.model.word_vocab[word] for word in sent.words]
        pos_ids = [self.model.pos_vocab[pos] for pos in sent.tags]
        word_ids = torch.tensor([word_ids]).long()
        pos_ids = torch.tensor([pos_ids]).long()
        # One-hot dependency graph: (batch, token, token, edge-type).
        graph = torch.zeros(
            (1, word_ids.size(1), word_ids.size(1), len(self.model.gcn_vocab)),
            dtype=torch.uint8)
        for i, token in enumerate(sent.dependency):
            graph[0, i, i, self.model.gcn_vocab['self']] = 1
            # NOTE(review): token.head looks 1-based; when head == 0 (root)
            # the index -1 wraps to the LAST token — confirm this is intended.
            graph[0, i, token.head - 1, self.model.gcn_vocab['head']] = 1
            graph[0, token.head - 1, i, self.model.gcn_vocab['dep']] = 1
        if self.model.use_gpu:
            word_ids = word_ids.cuda()
            pos_ids = pos_ids.cuda()
            graph = graph.cuda()
        # Per-token label scores; squeeze away the batch dimension.
        pred = self.model(word_ids, pos_ids, graph).squeeze(0)
        labels = [self.model.tag_label.id2label[t] for t in pred.argmax(-1)]

        # Accumulate tokens and cut after every "B"-labelled token.
        edus = []
        last_edu_words = []
        last_edu_tags = []
        for word, pos, label in zip(sent.words, sent.tags, labels):
            last_edu_words.append(word)
            last_edu_tags.append(pos)
            if label == "B":
                text = "".join(last_edu_words)
                edu = EDU([TEXT(text)])
                setattr(edu, "words", last_edu_words)
                setattr(edu, "tags", last_edu_tags)
                edus.append(edu)
                last_edu_words = []
                last_edu_tags = []
        if last_edu_words:
            # Flush any trailing tokens as a final EDU.
            text = "".join(last_edu_words)
            edu = EDU([TEXT(text)])
            setattr(edu, "words", last_edu_words)
            setattr(edu, "tags", last_edu_tags)
            edus.append(edu)
        return edus
    def parse(self, para: Paragraph) -> Paragraph:
        """Build a discourse tree for *para* via shift-reduce transitions.

        Args:
            para: paragraph whose EDUs are already segmented.

        Returns:
            Paragraph: the parsed tree, or *para* unchanged when it holds
            fewer than two EDUs (nothing to combine).
        """
        # Work on detached EDU copies so the caller's paragraph is never
        # mutated by parsing.
        edus = []
        for edu in para.edus():
            edu_copy = EDU([TEXT(edu.text)])
            setattr(edu_copy, "words", edu.words)
            setattr(edu_copy, "tags", edu.tags)
            edus.append(edu_copy)
        if len(edus) < 2:
            return para

        trans_probs = []
        state = self.init_state(edus)
        while not self.terminate(state):
            logits = self.model(state)
            valid = self.valid_trans(state)
            # Mask transitions that are illegal in the current state before
            # taking the softmax, so they can never be selected.
            for i, (trans, _, _) in enumerate(self.model.trans_label.id2label):
                if trans not in valid:
                    logits[i] = -INF
            probs = logits.softmax(dim=0)
            trans_probs.append(probs)
            # Greedily pick the highest-probability transition.
            next_trans, _, _ = self.model.trans_label.id2label[probs.argmax(
                dim=0)]
            if next_trans == SHIFT:
                state = self.model.shift(state)
            elif next_trans == REDUCE:
                state = self.model.reduce(state)
            else:
                # Fix: corrected the misspelled error message ("occured").
                raise ValueError("unexpected transition occurred")
        parsed = self.build_tree(edus, trans_probs)
        return parsed
# Example 4
    def cut_edu(self, sent):
        """Segment *sent* into EDUs with the neural segmentation model.

        If the sentence lacks ``words`` / ``tags``, they are recovered from
        an attached (or freshly computed) constituency parse. Each token is
        then labelled and the sentence is cut after every "B" token.

        Returns:
            list[EDU]: the EDUs, each carrying its own words and tags.
        """
        if (not hasattr(sent, "words")) or (not hasattr(sent, "tags")):
            # Use the pre-attached parse when available, else parse the text.
            if hasattr(sent, "parse"):
                parse = getattr(sent, "parse")
            else:
                parse = self.parser.parse(sent.text)
            # Preterminal (POS word) pairs, skipping '-NONE-' empty elements.
            leaves = [t for t in parse.subtrees(
                lambda t: t.height() == 2 and t.label() != '-NONE-')]
            setattr(sent, "words", [leaf[0] for leaf in leaves])
            setattr(sent, "tags", [leaf.label() for leaf in leaves])

        # Encode tokens and POS tags as 1 x seq_len long tensors.
        word_ids = torch.tensor(
            [[self.model.word_vocab[w] for w in sent.words]]).long()
        pos_ids = torch.tensor(
            [[self.model.pos_vocab[p] for p in sent.tags]]).long()
        if self.model.use_gpu:
            word_ids = word_ids.cuda()
            pos_ids = pos_ids.cuda()
        scores = self.model(word_ids, pos_ids).squeeze(0)
        labels = [self.model.tag_label.id2label[t] for t in scores.argmax(-1)]

        def make_edu(words, tags):
            # Wrap a run of tokens as an EDU carrying its words and tags.
            edu = EDU([TEXT("".join(words))])
            setattr(edu, "words", words)
            setattr(edu, "tags", tags)
            return edu

        edus = []
        buf_words, buf_tags = [], []
        for word, pos, label in zip(sent.words, sent.tags, labels):
            buf_words.append(word)
            buf_tags.append(pos)
            if label == "B":
                edus.append(make_edu(buf_words, buf_tags))
                buf_words, buf_tags = [], []
        if buf_words:
            # Flush any trailing tokens as a final EDU.
            edus.append(make_edu(buf_words, buf_tags))
        return edus
# Example 5
def parse_and_eval(dataset, model):
    """Parse gold paragraphs with a ShiftReduceParser and score the result.

    Gold trees are first stripped down to bare EDU sequences so the parser
    must rebuild the discourse structure from scratch.

    Returns:
        tuple: (number of gold instances, parse_eval scores).
    """
    parser = ShiftReduceParser(model)
    # Only paragraphs that have a rooted relation are evaluable.
    golds = [doc for doc in chain(*dataset) if doc.root_relation()]
    num_instances = len(golds)

    strips = []
    for paragraph in golds:
        bare_edus = []
        for edu in paragraph.edus():
            clone = EDU([TEXT(edu.text)])
            setattr(clone, "words", edu.words)
            setattr(clone, "tags", edu.tags)
            bare_edus.append(clone)
        strips.append(Paragraph(bare_edus))

    parses = [parser.parse(strip) for strip in strips]
    return num_instances, parse_eval(parses, golds)
def evaluate(args):
    """Evaluate the discourse-parsing pipeline on the CDTB test set.

    Depending on ``args.use_gold_edu`` the parser is fed either the gold
    EDU segmentation or the pipeline's own automatic segmentation.  EDU,
    macro parse, nuclear, relation and height scores are logged.
    """
    pipeline = build_pipeline(schema=args.schema,
                              segmenter_name=args.segmenter_name,
                              use_gpu=args.use_gpu)
    cdtb = CDTB(args.data,
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir=args.ctb_dir,
                preprocess=True,
                cache_dir=args.cache_dir)
    # Only paragraphs with a rooted discourse tree are evaluable.
    golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
    parses = []

    if args.use_gold_edu:
        logger.info("evaluation with gold edu segmentation")
    else:
        logger.info("evaluation with auto edu segmentation")

    for para in tqdm(golds, desc="parsing", unit=" para"):
        if args.use_gold_edu:
            # Copy the gold EDUs so parsing never mutates the gold trees.
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
        else:
            # Rebuild bare sentences and let the pipeline re-segment them.
            sentences = []
            for sentence in para.sentences():
                # Keep only sentences that actually contain EDUs.
                if list(sentence.iterfind(node_type_filter(EDU))):
                    # NOTE(review): TEXT([sentence.text]) wraps the text in a
                    # list, unlike TEXT(edu.text) elsewhere in this file —
                    # confirm TEXT's expected argument type.
                    copy_sentence = Sentence([TEXT([sentence.text])])
                    if hasattr(sentence, "words"):
                        setattr(copy_sentence, "words", sentence.words)
                    if hasattr(sentence, "tags"):
                        setattr(copy_sentence, "tags", sentence.tags)
                    # Attach the gold CTB constituency parse for the segmenter.
                    setattr(copy_sentence, "parse", cdtb.ctb[sentence.sid])
                    sentences.append(copy_sentence)
            para = pipeline.cut_edu(Paragraph(sentences))
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
        parse = pipeline.parse(Paragraph(edus))
        parses.append(parse)

    # edu score
    scores = edu_eval(golds, parses)
    logger.info("EDU segmentation scores:")
    logger.info(gen_edu_report(scores))

    # parser score
    cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
    logger.info("CDTB macro (strict) scores:")
    logger.info(eval.gen_parse_report(*cdtb_macro_scores))

    # nuclear scores
    nuclear_scores = eval.nuclear_eval(parses, golds)
    logger.info("nuclear scores:")
    logger.info(eval.gen_category_report(nuclear_scores))

    # relation scores
    ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
    logger.info("coarse relation scores:")
    logger.info(eval.gen_category_report(ctype_scores))
    logger.info("fine relation scores:")
    logger.info(eval.gen_category_report(ftype_scores))

    # height eval
    height_scores = eval.height_eval(parses, golds)
    logger.info("structure precision by node height:")
    logger.info(eval.gen_height_report(height_scores))
# Example 7
def evaluate(args):
    """Evaluate SVM segmentation + PartPtr parsing on the CDTB test set.

    Loads pretrained models from hard-coded paths under ``pub/models``,
    parses every test paragraph (with gold or automatic EDUs depending on
    ``args.use_gold_edu``) and logs EDU, macro parse, nuclear and relation
    scores.
    """
    # NOTE(review): pickle.load / torch.load can execute arbitrary code from
    # the model files — only load trusted artifacts.
    with open("pub/models/segmenter.svm.model", "rb") as segmenter_fd:
        segmenter_model = pickle.load(segmenter_fd)
    with open("pub/models/treebuilder.partptr.model", "rb") as parser_fd:
        parser_model = torch.load(parser_fd, map_location="cpu")
        # Force CPU inference and disable training-only behavior (dropout
        # etc.) before wrapping the model in a parser.
        parser_model.use_gpu = False
        parser_model.eval()
    segmenter = SVMSegmenter(segmenter_model)
    parser = PartPtrParser(parser_model)

    cdtb = CDTB(args.data,
                "TRAIN",
                "VALIDATE",
                "TEST",
                ctb_dir=args.ctb_dir,
                preprocess=True,
                cache_dir=args.cache_dir)
    # Only paragraphs with a rooted discourse tree are evaluable.
    golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
    parses = []

    if args.use_gold_edu:
        logger.info("evaluation with gold edu segmentation")
    else:
        logger.info("evaluation with auto edu segmentation")

    for para in tqdm(golds, desc="parsing", unit=" para"):
        if args.use_gold_edu:
            # Copy the gold EDUs so parsing never mutates the gold trees.
            edus = []
            for edu in para.edus():
                edu_copy = EDU([TEXT(edu.text)])
                setattr(edu_copy, "words", edu.words)
                setattr(edu_copy, "tags", edu.tags)
                edus.append(edu_copy)
            parse = parser.parse(Paragraph(edus))
            parses.append(parse)
        else:
            # Re-segment each EDU-bearing sentence with the SVM segmenter,
            # using the gold CTB constituency parse as input features.
            edus = []
            for sentence in para.sentences():
                if list(sentence.iterfind(node_type_filter(EDU))):
                    setattr(sentence, "parse", cdtb.ctb[sentence.sid])
                    edus.extend(segmenter.cut_edu(sentence))
            parse = parser.parse(Paragraph(edus))
            parses.append(parse)

    # edu score
    scores = edu_eval(golds, parses)
    logger.info("EDU segmentation scores:")
    logger.info(gen_edu_report(scores))

    # parser score
    cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
    logger.info("CDTB macro (strict) scores:")
    logger.info(eval.gen_parse_report(*cdtb_macro_scores))

    # nuclear scores
    nuclear_scores = eval.nuclear_eval(parses, golds)
    logger.info("nuclear scores:")
    logger.info(eval.gen_category_report(nuclear_scores))

    # relation scores
    ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
    logger.info("coarse relation scores:")
    logger.info(eval.gen_category_report(ctype_scores))
    logger.info("fine relation scores:")
    logger.info(eval.gen_category_report(ftype_scores))