Example #1
from typing import List


def from_dict(tree_dict: dict):
    '''
    Build a QueryTree from its dictionary representation
    ({'root': ..., 'tokens': [...]}).
    '''
    tokens: List[str] = tree_dict['tokens']
    # Pre-create one TOKEN node per token, so token ids in the dict map
    # directly onto node instances.
    token_nodes = [
        QueryTree.Node(NodeType.TOKEN, id) for id in range(len(tokens))
    ]
    used_nodes = set()

    def node_from_dict(tree_dict):
        node_type = enum_for_str(tree_dict['type'])

        if node_type == NodeType.TOKEN:
            node = token_nodes[tree_dict['id']]
            used_nodes.add(node)
        else:
            node = QueryTree.Node(node_type)

        if 'children' in tree_dict:
            for child in tree_dict['children']:
                node.children.append(node_from_dict(child))

        return node

    root = node_from_dict(tree_dict['root'])

    # Aggregate unused tokens
    if len(used_nodes) < len(token_nodes):
        unused_container_node = QueryTree.Node(NodeType.UNUSED)
        root.children.append(unused_container_node)

        for node in token_nodes:
            if node not in used_nodes:
                unused_container_node.children.append(node)

    tree = QueryTree(root, tokens)
    return tree
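A brief usage sketch (assuming this project's QueryTree and NodeType definitions; the input dict below is made up):

tree_dict = {
    'root': {
        'type': 'ROOT',
        'children': [{
            'type': 'ENTITY',
            'children': [{'type': 'TOKEN', 'id': 0}]
        }]
    },
    'tokens': ['Paris', 'is', 'big']
}
tree = QueryTree.from_dict(tree_dict)
# Tokens 1 and 2 ('is', 'big') are never referenced by the annotation, so
# the root gains an extra UNUSED child grouping their TOKEN nodes.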
Example #2
def normalize_treeqald(source_path):
    '''
    Convert a tree-annotated question dataset to a relation extraction dataset (from its relation nodes)
    '''

    file_name = os.path.basename(source_path)
    file_name = file_name.split('.')[0]
    output_path = os.path.join(r'datasets\relation_extraction\tree-qald\data',
                               file_name + '_normalized.json')
    if os.path.exists(output_path): return output_path

    with open(source_path, 'r', encoding='utf-8') as input_file:
        trees = json.load(input_file)
        trees = [
            question for question in trees
            if 'root' in question and question['root']
        ]
        trees = list(map(lambda question: QueryTree.from_dict(question),
                         trees))
        trees = list(filter(lambda x: SYNTAX_CHECKER.validate(x), trees))

    sequences = []
    for tree in trees:
        relation_nodes = tree.root.collect(RELATION_NODE_TYPES)
        for node in relation_nodes:
            if node.kb_resources:
                sequences.append(
                    tree.generate_relation_extraction_sequence(node))

    with open(output_path, 'w', encoding='utf-8') as output_file:
        json.dump(sequences, output_file)

    return output_path
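A minimal usage sketch (the source path below is hypothetical):

# Converts the relation nodes of each valid tree into relation extraction
# sequences, writes <name>_normalized.json, and returns its path; if the
# output file already exists, its path is returned immediately.
output_path = normalize_treeqald(r'datasets\tree-qald\train.json')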
Example #3
    def test_correct_tree_1(self):
        checker = SyntaxChecker(GRAMMAR_FILE_PATH)
        tree_dict = {
            'root': {
                'type': 'ROOT',
                'children': [{
                    'type': 'COUNT',
                    'children': [{
                        'type': 'PROPERTY',
                        'children': [{
                            'type': 'ENTITY',
                            'children': [{
                                'type': 'TOKEN',
                                'id': 1
                            }, {
                                'type': 'TOKEN',
                                'id': 2
                            }]
                        }]
                    }]
                }]
            },
            'tokens': ['a', 'b', 'c', 'd', 'e', 'f']
        }

        query_tree = QueryTree.from_dict(tree_dict)
        self.assertTrue(checker.validate(query_tree))
Example #4
def validate(tree):
    return SYNTAX_CHECKER.validate(
        QueryTree.from_dict({
            'root': tree,
            'id': state['examples'][state['current_example_index']][0],
            'tokens': state['tokens']
        }))
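This helper reads a module-level state dict (apparently from an annotation tool). A hypothetical minimal shape that would satisfy it:

# Assumed structure; only the fields the helper actually reads are shown.
state = {
    'examples': [('question-42', None)],  # element [0] is the question id
    'current_example_index': 0,
    'tokens': ['how', 'many', 'cities'],
}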
Example #6
    def __decode_labels(self, tokens):
        '''
        Process the output of the NCRFPP algorithm.
        '''
        parenthesized_trees = []
        with open(MODEL_FILE_PATH + ".output",
                  encoding="utf-8") as ncrfpp_output_file:
            lines = ncrfpp_output_file.readlines()
            lines = [line.strip().split(' ') for line in lines[1:]]
            for prediction_index in range(0, TREE_CANDIDATES_N_BEST):
                try:
                    sentence = []
                    pred = []
                    for token_index, line_tokens in enumerate(lines):
                        # The tree2labels algorithm is not aware of 'UNUSED'
                        # tokens. We first have to eliminate them so that the
                        # sequence_to_parenthesis algorithm reconstructs the
                        # tree correctly.
                        if len(line_tokens) >= 2 and line_tokens[1] != 'UNUSED':
                            sentence.append(
                                (str(token_index - 1),
                                 'TOKEN'))  # -1 because of the -BOS- line
                            pred.append(line_tokens[1 + prediction_index])

                    # The main decoding process of tree2labels repo
                    parenthesized_trees.extend(
                        sequence_to_parenthesis([sentence], [pred]))
                except Exception:
                    print("Error decoding tree")

        candidates = []

        for tree in parenthesized_trees:

            def tree2dict(tree: Tree):
                result = {}

                result['type'] = tree.label()
                children = [
                    tree2dict(t) if isinstance(t, Tree) else t for t in tree
                ]
                if tree.label() == 'TOKEN':
                    result['id'] = int(children[0])
                elif children:
                    result['children'] = children

                return result

            tree = tree.strip()
            nltk_tree: Tree = Tree.fromstring(tree,
                                              remove_empty_top_bracketing=True)
            root = tree2dict(nltk_tree)
            dict_tree = {'root': root, 'tokens': tokens}

            # If the first tokens are 'who', 'where', 'when' or 'how many',
            # add them as a TYPE child of the answer node.
            try:
                if tokens[0].lower() in {'who', 'where', 'when'}:
                    type_node = {
                        'type': NodeType.TYPE.value,
                        'children': [{
                            'type': NodeType.TOKEN.value,
                            'id': 0
                        }]
                    }
                    root['children'][0]['children'].append(type_node)
                if (tokens[0].lower(), tokens[1].lower()) == ('how', 'many') \
                        and root['children'][0]['type'] != NodeType.COUNT.value:
                    type_node = {
                        'type': NodeType.TYPE.value,
                        'children': [{
                            'type': NodeType.TOKEN.value,
                            'id': 0
                        }, {
                            'type': NodeType.TOKEN.value,
                            'id': 1
                        }]
                    }
                    root['children'][0]['children'].append(type_node)
            except Exception:
                pass  # Tree is probably invalid
            query_tree = QueryTree.from_dict(dict_tree)
            candidates.append(query_tree)

        return candidates
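The inner tree2dict conversion can be exercised on its own with nltk. A self-contained sketch, where the bracketed string is an invented example rather than real NCRFPP output:

from nltk.tree import Tree


def tree2dict(tree: Tree):
    result = {'type': tree.label()}
    children = [tree2dict(t) if isinstance(t, Tree) else t for t in tree]
    if tree.label() == 'TOKEN':
        result['id'] = int(children[0])
    elif children:
        result['children'] = children
    return result


nltk_tree = Tree.fromstring('(ROOT (PROPERTY (ENTITY (TOKEN 1) (TOKEN 2))))')
print(tree2dict(nltk_tree))
# -> {'type': 'ROOT', 'children': [{'type': 'PROPERTY', 'children':
#     [{'type': 'ENTITY', 'children': [{'type': 'TOKEN', 'id': 1},
#                                      {'type': 'TOKEN', 'id': 2}]}]}]}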
Example #7
            entities = list(
                filter(lambda x: x.type == NodeType.ENTITY, node.children))
            entity_set = list(
                filter(
                    lambda x: x.type != NodeType.ENTITY and
                    x.type in ENTITY_SET_TYPES, node.children))
            if len(entities) > 0 and len(entity_set) > 0:
                return True
        return False

    return validates_grammar(question)


with open(INPUT_FILE_PATH, 'r', encoding='utf-8') as input_file:
    questions = json.load(input_file)
    questions = [
        question for question in questions
        if 'root' in question and question['root']
    ]
    questions = list(
        map(lambda question: QueryTree.from_dict(question), questions))
    filtered_questions = list(filter(filter_predicate, questions))
    print('{}/{} passed the filter_predicate!'.format(len(filtered_questions),
                                                      len(questions)))

with open(OUTPUT_FILE_PATH, 'w', encoding='utf-8') as output_file:
    questions = [
        question.to_serializable(SerializationFormat.HIERARCHICAL_DICT)
        for question in filtered_questions
    ]
    json.dump(questions, output_file)
Example #8
def prepare_input():

    with open(QUESTION_SET_FILE_PATH, 'r', encoding='utf-8') as input_file:
        questions = json.load(input_file)
        questions = list(filter(lambda question: question['root'], questions))
        questions = list(
            map(lambda question: QueryTree.from_dict(question), questions))
        questions = list(
            filter(lambda question: not contains_bad_exists(question),
                   questions))
        for question in questions:
            question.remove_children_of_type(NodeType.VARIABLE)

        random.shuffle(questions)
        split_point = int(TRAIN_TEST_RATIO * len(questions))

        train_questions = questions[:split_point]
        test_questions = questions[split_point:]

        with open(INTERMEDIATE_TRAIN_FILE_PATH, 'w',
                  encoding='utf-8') as output_file:
            for question in train_questions:
                line = question.to_serializable(
                    SerializationFormat.PREFIX_PARANTHESES) + '|' + ' '.join(
                        question.tokens) + '\n'
                output_file.write(line)

        with open(INTERMEDIATE_TEST_FILE_PATH, 'w',
                  encoding='utf-8') as output_file:
            for question in test_questions:
                line = question.to_serializable(
                    SerializationFormat.PREFIX_PARANTHESES) + '|' + ' '.join(
                        question.tokens) + '\n'
                output_file.write(line)

    args_binarized = False
    args_os = True
    args_root_label = False
    args_encode_unaries = True
    args_abs_top = 3
    args_abs_neg_gap = 2
    args_join_char = '~'
    args_split_char = '@'

    train_sequences, train_leaf_unary_chains = transform_split(
        INTERMEDIATE_TRAIN_FILE_PATH, args_binarized, args_os, args_root_label,
        args_encode_unaries, args_abs_top, args_abs_neg_gap, args_join_char,
        args_split_char)

    # Note: the dev split is built from the same intermediate file as the
    # test split.
    dev_sequences, dev_leaf_unary_chains = transform_split(
        INTERMEDIATE_TEST_FILE_PATH, args_binarized, args_os, args_root_label,
        args_encode_unaries, args_abs_top, args_abs_neg_gap, args_join_char,
        args_split_char)
    ext = "seq_lu"
    feats_dict = {}
    write_linearized_trees(
        "/".join([
            INTERMEDIATE_OUTPUT_DIRECTORY_PATH, DATASET_NAME + "-train." + ext
        ]), train_sequences)

    write_linearized_trees(
        "/".join(
            [INTERMEDIATE_OUTPUT_DIRECTORY_PATH,
             DATASET_NAME + "-dev." + ext]), dev_sequences)

    test_sequences, test_leaf_unary_chains = transform_split(
        INTERMEDIATE_TEST_FILE_PATH, args_binarized, args_os, args_root_label,
        args_encode_unaries, args_abs_top, args_abs_neg_gap, args_join_char,
        args_split_char)

    write_linearized_trees(
        "/".join([
            INTERMEDIATE_OUTPUT_DIRECTORY_PATH, DATASET_NAME + "-test." + ext
        ]), test_sequences)
Example #9
def generate_query(query_tree_dict: dict):
    generator = QueryGenerator()
    return generator(QueryTree.from_dict(query_tree_dict))
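A usage sketch; the input dict is illustrative, mirroring the tree structure from the test above, and the generator is assumed to return the query built for the tree:

query_tree_dict = {
    'root': {
        'type': 'ROOT',
        'children': [{
            'type': 'COUNT',
            'children': [{
                'type': 'ENTITY',
                'children': [{'type': 'TOKEN', 'id': 0}]
            }]
        }]
    },
    'tokens': ['cities']
}
query = generate_query(query_tree_dict)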