コード例 #1
0
ファイル: eval_proof.py プロジェクト: swarnaHub/PRover
def get_node_edge_indices(proofs, sentence_scramble, nfact, nrule):
    """Map every alternative proof in ``proofs`` to context-position indices.

    ``proofs`` holds one or more proofs joined by the literal ``"OR"``. Each
    proof is parsed with ``get_proof_graph_with_fail`` (when it contains
    ``"FAIL"``) or ``get_proof_graph`` into ``(nodes, edges)``, and every
    component name (``"triple<k>"``, ``"rule<k>"`` or ``"NAF"``) is translated
    into its position in the scrambled context.

    Args:
        proofs: Proof string; alternative proofs are separated by ``"OR"``.
        sentence_scramble: For each context position, the original sentence
            id. Ids ``<= nfact`` name facts (``"triple<id>"``); larger ids
            name rules (``"rule<id - nfact>"``).
        nfact: Number of facts.
        nrule: Number of rules.

    Returns:
        ``(all_node_indices, all_edge_indices)`` with one entry per proof:
        a list of node positions and a list of ``(start, end)`` position
        pairs. ``"NAF"`` maps to position ``nfact + nrule``. Duplicate edges
        are dropped, keeping first-occurrence order.

    Raises:
        KeyError: If a proof mentions a component not present in the context.
    """
    # The name -> position map depends only on the context, not on the
    # individual proof, so build it once instead of once per proof.
    component_index_map = {}
    for position, sent_id in enumerate(sentence_scramble):
        if sent_id <= nfact:
            component = "triple" + str(sent_id)
        else:
            component = "rule" + str(sent_id - nfact)
        component_index_map[component] = position
    component_index_map["NAF"] = nfact + nrule

    all_node_indices, all_edge_indices = [], []
    for proof in proofs.split("OR"):
        if "FAIL" in proof:
            nodes, edges = get_proof_graph_with_fail(proof)
        else:
            nodes, edges = get_proof_graph(proof)

        node_indices = [component_index_map[node] for node in nodes]
        # dict.fromkeys deduplicates like set() did, but in a deterministic
        # first-occurrence order; iteration order of a set of string tuples
        # changes between interpreter runs because of hash randomization.
        edge_indices = [
            (component_index_map[edge[0]], component_index_map[edge[1]])
            for edge in dict.fromkeys(edges)
        ]

        all_node_indices.append(node_indices)
        all_edge_indices.append(edge_indices)

    return all_node_indices, all_edge_indices
コード例 #2
0
ファイル: eval_natlang.py プロジェクト: swarnaHub/PRover
def get_node_edge_indices(proofs, natlang_mapping):
    """Map every alternative proof in ``proofs`` to 0-based sentence indices.

    ``proofs`` holds one or more proofs joined by the literal ``"OR"``. Each
    proof is parsed with ``get_proof_graph_with_fail`` (when it contains
    ``"FAIL"``) or ``get_proof_graph`` into ``(nodes, edges)``, and every
    component name is translated to a sentence position via
    ``natlang_mapping``.

    Args:
        proofs: Proof string; alternative proofs are separated by ``"OR"``.
        natlang_mapping: Maps component names to sequences whose first item
            is a sentence id string that becomes an int once ``"sent"`` is
            removed (e.g. ``"sent3"`` -> position 2).

    Returns:
        ``(all_node_indices, all_edge_indices)`` with one entry per proof:
        a list of node positions and a list of ``(start, end)`` position
        pairs. Duplicate edges are dropped, keeping first-occurrence order.

    Raises:
        KeyError: If a proof mentions a component missing from the mapping.
        ValueError: If a sentence id does not parse as ``"sent<int>"``.
    """

    def sent_index(component):
        # natlang_mapping[component][0] is a sentence id such as "sent3";
        # remove "sent" and shift the 1-based number to a 0-based position.
        return int(natlang_mapping[component][0].replace("sent", "")) - 1

    all_node_indices, all_edge_indices = [], []
    for proof in proofs.split("OR"):
        if "FAIL" in proof:
            nodes, edges = get_proof_graph_with_fail(proof)
        else:
            nodes, edges = get_proof_graph(proof)

        node_indices = [sent_index(node) for node in nodes]
        # Deduplicate edges like set() would, but in a deterministic
        # first-occurrence order (set order of string tuples changes between
        # interpreter runs because of hash randomization).
        edge_indices = [
            (sent_index(edge[0]), sent_index(edge[1]))
            for edge in dict.fromkeys(edges)
        ]

        all_node_indices.append(node_indices)
        all_edge_indices.append(edge_indices)

    return all_node_indices, all_edge_indices
コード例 #3
0
ファイル: utils.py プロジェクト: swarnaHub/PRover
    def _get_node_edge_label_unconstrained(self, proofs, sentence_scramble, nfact, nrule):
        """Build unmasked node/edge labels for the first proof in ``proofs``.

        Only the text before the first ``"OR"`` is used. Positions follow
        ``sentence_scramble``; the extra last position ``nfact + nrule`` is
        the ``"NAF"`` slot.

        Args:
            proofs: Proof string; alternatives are separated by ``"OR"``.
            sentence_scramble: For each context position, the original
                sentence id. Ids ``<= nfact`` name facts (``"triple<id>"``),
                larger ids name rules (``"rule<id - nfact>"``).
            nfact: Number of facts.
            nrule: Number of rules.

        Returns:
            ``(node_label, edge_label)``: a list of ``nfact + nrule + 1``
            0/1 flags marking proof nodes, and the row-major flattening (as a
            list) of a square 0/1 int matrix marking proof edges.

        Raises:
            KeyError: If the proof mentions a component absent from the context.
        """
        naf_slot = nfact + nrule
        width = naf_slot + 1
        first_proof = proofs.split("OR")[0]

        node_flags = [0] * width
        adjacency = np.zeros((width, width), dtype=int)

        parse_graph = get_proof_graph_with_fail if "FAIL" in first_proof else get_proof_graph
        proof_nodes, proof_edges = parse_graph(first_proof)

        # Component name -> context position; NAF gets the reserved last slot.
        position_of = {
            ("triple" + str(sid) if sid <= nfact else "rule" + str(sid - nfact)): pos
            for pos, sid in enumerate(sentence_scramble)
        }
        position_of["NAF"] = naf_slot

        for name in proof_nodes:
            node_flags[position_of[name]] = 1

        for link in set(proof_edges):
            adjacency[position_of[link[0]], position_of[link[1]]] = 1

        return node_flags, list(adjacency.flatten())
コード例 #4
0
ファイル: utils.py プロジェクト: swarnaHub/PRover
    def _get_node_edge_label_constrained(self, proofs, sentence_scramble, nfact, nrule):
        """Build masked node/edge labels for the first proof in ``proofs``.

        Only the text before the first ``"OR"`` is used. Positions follow
        ``sentence_scramble``; the extra last position ``nfact + nrule`` is
        the ``"NAF"`` slot.

        Args:
            proofs: Proof string; alternatives are separated by ``"OR"``.
            sentence_scramble: For each context position, the original
                sentence id. Ids ``<= nfact`` name facts (``"triple<id>"``),
                larger ids name rules (``"rule<id - nfact>"``).
            nfact: Number of facts.
            nrule: Number of rules.

        Returns:
            ``(node_label, edge_label)``: a list of ``nfact + nrule + 1``
            0/1 flags marking proof nodes, and the row-major flattening (as a
            list) of a square int matrix holding 1 for proof edges, 0 for
            other permitted edges and -100 for impossible cells (-100 is
            presumably the loss ``ignore_index`` -- confirm in training code).

        Raises:
            KeyError: If the proof mentions a component absent from the context.
        """
        naf_index = nfact + nrule
        size = naf_index + 1
        proof = proofs.split("OR")[0]
        node_label = [0] * size
        edge_label = np.zeros((size, size), dtype=int)

        if "FAIL" in proof:
            nodes, edges = get_proof_graph_with_fail(proof)
        else:
            nodes, edges = get_proof_graph(proof)

        component_index_map = {}
        for position, sent_id in enumerate(sentence_scramble):
            if sent_id <= nfact:
                component = "triple" + str(sent_id)
            else:
                component = "rule" + str(sent_id - nfact)
            component_index_map[component] = position
        component_index_map["NAF"] = naf_index

        for node in nodes:
            node_label[component_index_map[node]] = 1

        # Writing 1 is idempotent, so duplicate edges need no deduplication.
        for edge in edges:
            edge_label[component_index_map[edge[0]], component_index_map[edge[1]]] = 1

        # Mask impossible edges. A cell survives only if it joins two distinct
        # proof nodes and ends at a rule. Everything else gets -100: the
        # diagonal, cells touching a non-node, and any edge ending at a fact
        # or at NAF. The last rule covers both "fact/NAF -> fact/NAF" and
        # "rule -> fact/NAF"; the start node's type never matters.
        is_node = np.asarray(node_label, dtype=bool)
        ends_at_rule = np.zeros(size, dtype=bool)
        for j, flag in enumerate(node_label):
            # Proof-node positions (other than NAF) come from
            # enumerate(sentence_scramble), so the lookup is always in range.
            if flag and j != naf_index and sentence_scramble[j] > nfact:
                ends_at_rule[j] = True

        allowed = is_node[:, None] & ends_at_rule[None, :]
        np.fill_diagonal(allowed, False)
        edge_label[~allowed] = -100

        return node_label, list(edge_label.flatten())
コード例 #5
0
ファイル: utils_natlang.py プロジェクト: swarnaHub/PRover
    def _get_node_edge_label_natlang(self, id, proofs, natlang_mappings):
        """Build masked node/edge labels for the first proof over natural-language sentences.

        Args:
            id: Example id; its second ``"-"``-separated field selects the
                mapping in ``natlang_mappings``. (NOTE: the name shadows the
                builtin ``id``; kept for caller compatibility.)
            proofs: Proof string; alternatives are separated by ``"OR"``.
                Only the first proof is used.
            natlang_mappings: Maps that field to a dict from component names
                (``"triple..."`` names are facts, anything else counts as a
                rule) to ``(sentence_id, original_sentences)`` pairs, where
                ``sentence_id`` looks like ``"sent<k>"`` (1-based ``k``).

        Returns:
            ``(node_label, edge_label)``: a list of ``S + 1`` 0/1 flags
            (``S`` = number of distinct sentence ids, last slot is NAF)
            marking proof nodes, and the row-major flattening (as a list) of
            an ``(S + 1) x (S + 1)`` int matrix holding 1 for proof edges, 0
            for other permitted edges and -100 for impossible cells (-100 is
            presumably the loss ``ignore_index`` -- confirm in training code).

        Raises:
            KeyError: If the proof mentions a component missing from the mapping.
            ValueError: If a sentence id does not parse as ``"sent<int>"``.
        """
        natlang_mapping = natlang_mappings[id.split("-")[1]]

        def sent_position(sid):
            # "sent<k>" -> 0-based position k - 1.
            return int(sid.replace("sent", "")) - 1

        # Kind ("fact"/"rule") of every distinct sentence id.
        new_sents = {}
        for rf_id, (sid, _orig_sents) in natlang_mapping.items():
            new_sents[sid] = "fact" if rf_id.startswith("triple") else "rule"

        proof = proofs.split("OR")[0]
        size = len(new_sents) + 1
        node_label = [0] * size
        edge_label = np.zeros((size, size), dtype=int)

        if "FAIL" in proof:
            nodes, edges = get_proof_graph_with_fail(proof)
        else:
            nodes, edges = get_proof_graph(proof)

        # Record each node's kind from its actual sentence id rather than
        # re-deriving the id as "sent" + str(position + 1) later, which breaks
        # for ids that are not in canonical form.
        ends_at_rule = np.zeros(size, dtype=bool)
        for node in nodes:
            sid = natlang_mapping[node][0]
            position = sent_position(sid)
            node_label[position] = 1
            ends_at_rule[position] = new_sents[sid] == "rule"
        # The last slot is NAF, which is always treated like a fact.
        ends_at_rule[size - 1] = False

        for edge in edges:
            start = sent_position(natlang_mapping[edge[0]][0])
            end = sent_position(natlang_mapping[edge[1]][0])
            edge_label[start, end] = 1

        # Mask impossible edges. A cell survives only if it joins two distinct
        # proof nodes and ends at a rule. Everything else gets -100: the
        # diagonal, cells touching a non-node, and any edge ending at a fact
        # or at NAF (covers fact/NAF -> fact/NAF and rule -> fact/NAF).
        is_node = np.asarray(node_label, dtype=bool)
        allowed = is_node[:, None] & ends_at_rule[None, :]
        np.fill_diagonal(allowed, False)
        edge_label[~allowed] = -100

        return node_label, list(edge_label.flatten())
コード例 #6
0
                meta_data = meta_record["questions"]["Q" + str(j + 1)]
                proofs = meta_data["proofs"]
                if "CWA" in proofs:
                    continue
                question = question["text"]
                all_samples += len(sentence_scramble)

                assert len(sentence_scramble) == len(context_sents)

                #print(proofs)
                proofs = proofs.split("OR")
                #print(len(proofs))

                all_proof_nodes = []
                for proof in proofs:
                    nodes, _ = get_proof_graph(proof)
                    all_proof_nodes.append(nodes)

                critical_count = 0
                irrelevant_count = 0
                for k in range(len(sentence_scramble)):
                    curr_sentence = context_sents[k]
                    new_context = context.replace(curr_sentence, "")
                    if is_node_in_all_proofs(all_proof_nodes,
                                             index_component_map[k]):
                        critical_count += 1
                        new_sample = {
                            "id":
                            record_id + "_" + qid + "_" + "c" +
                            str(critical_count),
                            "context":