Example #1
0
def max_f1_no_ambiguous(gold_sm: Graph, pred_sm: Graph,
                        is_blurring_label: bool,
                        gold_triples: Set[Tuple[int, bytes, Union[bytes,
                                                                  int]]]):
    """Label the links of ``pred_sm`` when its alignment to ``gold_sm`` is unique.

    The two graphs are aligned with ``align_graph``; data-node labels are
    ignored during alignment when ``is_blurring_label`` is true.  If the
    alignment does not report exactly one bijection the result is
    ``(None, None, None)``.

    Otherwise every outgoing link of every predicted class node is turned
    into a ``(source, link label, destination)`` triple, with node ids
    translated through the bijection's ``prime2x`` mapping, and marked
    ``True`` when that triple is in ``gold_triples``.

    Returns:
        ``(link2label, prime2x, f1)`` where ``link2label`` maps link id to
        bool, ``prime2x`` is the mapping taken from the bijection and ``f1``
        is ``alignment['f1']``.
    """
    if is_blurring_label:
        mode = DataNodeMode.IGNORE_LABEL_DATA_NODE
    else:
        mode = DataNodeMode.NO_TOUCH
    alignment = align_graph(gold_sm, pred_sm, mode)

    candidates = alignment['_bijections']
    if len(candidates) != 1:
        # zero or several bijections: the alignment is ambiguous
        return None, None, None

    pred2gold = candidates[0].prime2x
    link2label = {}

    # turn every outgoing link of the predicted model into a labelled example
    for class_node in pred_sm.iter_class_nodes():
        links = list(class_node.iter_outgoing_links())
        link_numbers = numbering_link_labels(links)

        for link in links:
            target = link.get_target_node()
            if target.is_class_node():
                obj = pred2gold[link.target_id]
            elif is_blurring_label:
                # data-node label is replaced by a numbered placeholder
                obj = get_numbered_link_label("DATA_NODE",
                                              link_numbers[link.id])
            else:
                obj = target.label

            link2label[link.id] = (pred2gold[link.source_id], link.label,
                                   obj) in gold_triples

    return link2label, pred2gold, alignment['f1']
Example #2
0
def preserved_structure(
    gold_sm: Graph, pred_sm: Graph, gold_triples: Set[Tuple[int, bytes,
                                                            Union[bytes, int]]]
) -> Tuple[Dict[int, bool], Dict[int, Optional[int]]]:
    """Label the links of ``pred_sm`` using the best-matching bijection.

    ``align_graph`` is run with data nodes ignored.  For each bijection it
    reports, every outgoing link of every predicted class node becomes a
    ``(source, link label, destination)`` triple (class-node ids translated
    through ``prime2x``, data nodes represented by their label) and is marked
    ``True`` when the triple is in ``gold_triples``.  The bijection with the
    most ``True`` links is kept; on a tie the earliest one wins.

    Returns:
        ``(link2label, prime2x)`` for the selected bijection.
    """
    def label_links(pred2gold):
        # link id -> whether the translated triple exists in the gold model
        labels = {}
        for class_node in pred_sm.iter_class_nodes():
            for link in list(class_node.iter_outgoing_links()):
                if link.get_target_node().is_class_node():
                    obj = pred2gold[link.target_id]
                else:
                    obj = link.get_target_node().label
                labels[link.id] = (pred2gold[link.source_id], link.label,
                                   obj) in gold_triples
        return labels

    alignment = align_graph(gold_sm, pred_sm, DataNodeMode.IGNORE_DATA_NODE)

    best_score = -1
    best_bijection = None
    best_link2label = None
    for candidate in alignment['_bijections']:
        candidate_labels = label_links(candidate.prime2x)
        n_correct = sum(candidate_labels.values())
        # strict comparison keeps the first bijection among equal scores
        if n_correct > best_score:
            best_score, best_bijection, best_link2label = (
                n_correct, candidate, candidate_labels)

    return best_link2label, best_bijection.prime2x
Example #3
0
    def feature_extraction(self, graph: Graph,
                           stype_score: Dict[int, Optional[float]]):
        """Compute two features for every class node of ``graph``.

        * ``prob_data_nodes``: the ``stype_score`` values of the data nodes
          the class node links to, accumulated from 0; a falsy score
          (``None`` or ``0``) contributes 0.
        * ``minimum_merged_cost``: the smallest ``get_merged_cost`` between
          the class node and any node returned by
          ``graph.iter_nodes_by_label`` for the same label.

        Returns:
            Dict mapping class-node id to a list of
            ``(feature name, value)`` pairs.
        """
        def add_score(total, data_node):
            return total + (stype_score[data_node.id] or 0)

        features = {}
        for class_node in graph.iter_class_nodes():
            # NOTE(review): `_` is the file's fluent wrapper; presumably
            # imap/ifilter/reduce behave like map/filter/functools.reduce.
            targets = _(class_node.iter_outgoing_links()).imap(
                lambda link: link.get_target_node())
            data_targets = targets.ifilter(
                lambda target: target.is_data_node())
            prob_data_nodes = data_targets.reduce(add_score, 0)

            merged_costs = (
                get_merged_cost(class_node, other, self.multival_predicate)
                for other in graph.iter_nodes_by_label(class_node.label))
            minimum_merged_cost = min(merged_costs)

            features[class_node.id] = [
                ('prob_data_nodes', prob_data_nodes),
                ('minimum_merged_cost', minimum_merged_cost),
            ]
        return features
Example #4
0
def get_gold_triples(
        gold_sm: Graph,
        is_blurring_label: bool) -> Set[Tuple[int, bytes, Union[bytes, int]]]:
    """Collect the ``(source id, link label, destination)`` triples of ``gold_sm``.

    One triple is produced per outgoing link of each class node.  The
    destination is the target node id when the target is a class node.  For
    any other target it is the target's label, or, when
    ``is_blurring_label`` is true, a numbered ``"DATA_NODE"`` placeholder
    built by ``get_numbered_link_label``.
    """
    triples = set()
    for class_node in gold_sm.iter_class_nodes():
        links: List[GraphLink] = list(class_node.iter_outgoing_links())
        link_numbers = numbering_link_labels(links)

        for link in links:
            target = link.get_target_node()
            if target.is_class_node():
                obj = link.target_id
            elif is_blurring_label:
                obj = get_numbered_link_label("DATA_NODE",
                                              link_numbers[link.id])
            else:
                obj = target.label

            triples.add((link.source_id, link.label, obj))
    return triples
Example #5
0
def preserved_structure_with_heuristic(gold_sm: Graph, pred_sm: Graph, gold_triples: Set[Tuple[int, bytes, Union[bytes, int]]],
                         gold_semantic_types: Dict[bytes, Tuple[bytes, bytes]]) -> Tuple[Dict[int, bool], Dict[int, Optional[int]]]:
    """Label the links of ``pred_sm``, then force-accept correct semantic types.

    First, for each bijection reported by ``align_graph`` (data nodes
    ignored), every outgoing link of every predicted class node is marked
    ``True`` when its ``(source, link label, destination)`` triple,
    translated through ``prime2x``, is in ``gold_triples``.  The bijection
    with the most ``True`` links is kept; the earliest wins a tie.

    Second, for every data node of ``pred_sm`` whose
    ``(parent label, incoming link label)`` equals
    ``gold_semantic_types[data node label]``, the incoming link is marked
    ``True`` regardless of the first step.

    Returns:
        ``(link2label, prime2x)`` for the selected bijection.
    """
    def label_links(pred2gold):
        # link id -> whether the translated triple exists in the gold model
        labels = {}
        for class_node in pred_sm.iter_class_nodes():
            for link in list(class_node.iter_outgoing_links()):
                if link.get_target_node().is_class_node():
                    obj = pred2gold[link.target_id]
                else:
                    obj = link.get_target_node().label
                labels[link.id] = (pred2gold[link.source_id], link.label, obj) in gold_triples
        return labels

    alignment = align_graph(gold_sm, pred_sm, DataNodeMode.IGNORE_DATA_NODE)

    best_score = -1
    best_bijection = None
    best_link2label = None
    for candidate in alignment['_bijections']:
        candidate_labels = label_links(candidate.prime2x)
        n_correct = sum(candidate_labels.values())
        # strict comparison keeps the first bijection among equal scores
        if n_correct > best_score:
            best_score, best_bijection, best_link2label = n_correct, candidate, candidate_labels

    # heuristic: a data node attached with its gold semantic type is correct
    # NOTE(review): the original comment mentioned labelling splitting-node
    # issues as incorrect, but no such handling exists in the code.
    for data_node in pred_sm.iter_data_nodes():
        incoming = data_node.get_first_incoming_link()
        predicted_stype = (incoming.get_source_node().label, incoming.label)
        if predicted_stype == gold_semantic_types[data_node.label]:
            best_link2label[incoming.id] = True

    return best_link2label, best_bijection.prime2x
Example #6
0
    def compute_prob(self, sm_id: str, g: Graph) -> Dict[int, float]:
        """Score how well each data-bearing class node fits its current mount.

        A "mount" is the pair ``(label of the parent's source node, label of
        the parent's incoming link)``.  For every class node of ``g`` that
        has at least one data-node child and an incoming link, the candidate
        mounts are taken from ``self.possible_mounts`` for the node's label,
        keeping only those whose source label occurs among the class nodes of
        ``g`` and that are not already used by a different parent in ``g``.

        Each candidate is scored as the mean, over the node's data columns,
        of the best ``self.similarity_matrix`` entry between the column and
        the reference columns listed in ``self.column_stype_index`` for that
        candidate.

        Args:
            sm_id: key into ``self.name2cols`` selecting the map from
                data-node label to column index.
            g: graph to score; every data-node label must be a key of that
                map.

        Returns:
            Dict from the id of each parent's incoming link to the score of
            its current mount minus the best score among the other
            candidates.  The value is ``None`` when there is nothing to
            compare against: fewer than two candidate mounts, the current
            mount is not a candidate, or no distinct alternative remains.
        """
        name2col_idx = self.name2cols[sm_id]

        # parent node id -> (parent node, its incoming link, its mount,
        #                    column indices of its data-node children)
        parent_nodes = {}
        graph_observed_mounts = set()
        for dnode in g.iter_data_nodes():
            dlink = dnode.get_first_incoming_link()
            col_idx = name2col_idx[dnode.label]

            entry = parent_nodes.get(dlink.source_id)
            if entry is not None:
                entry[-1].append(col_idx)
                continue

            pnode = dlink.get_source_node()
            plink = pnode.get_first_incoming_link()
            if plink is None:
                # the parent has no incoming link, so it has no mount to score
                continue

            pstype = (plink.get_source_node().label, plink.label)
            graph_observed_mounts.add(pstype)
            parent_nodes[dlink.source_id] = (pnode, plink, pstype, [col_idx])

        graph_observed_class_lbls = {
            pnode.label for pnode in g.iter_class_nodes()
        }

        link2features = {}
        for pnode, plink, pstype, col_idxs in parent_nodes.values():
            # keep mounts whose domain is in the graph and that are not taken
            # by another parent (the current mount itself is always allowed)
            candidate_mounts = [
                mount for mount in self.possible_mounts.get(pnode.label, [])
                if mount[0] in graph_observed_class_lbls
                and (mount == pstype or mount not in graph_observed_mounts)
            ]

            # mount -> mean of the per-column best similarity
            aggregation_score = {}
            if len(candidate_mounts) > 1:
                # the number only makes sense if there is another place to
                # mount this object to
                stype_index = self.column_stype_index[pnode.label]
                col2idx = self.stype_db.col2idx
                for mount in candidate_mounts:
                    # reference columns depend on the mount only, so they are
                    # resolved once rather than once per data column
                    refcols = stype_index[(mount[0], mount[1], pnode.label)]
                    ref_idxs = [col2idx[refcol.id] for refcol in refcols]
                    scores = [
                        max(self.similarity_matrix[col_idx, ref_idx]
                            for ref_idx in ref_idxs)
                        for col_idx in col_idxs
                    ]
                    aggregation_score[mount] = sum(scores) / len(scores)

            current_score = aggregation_score.pop(pstype, None)
            if current_score is None or not aggregation_score:
                # no current score, or no distinct alternative to compare
                # with (e.g. duplicated entries in possible_mounts)
                link2features[plink.id] = None
            else:
                link2features[plink.id] = (
                    current_score - max(aggregation_score.values()))

        return link2features
Example #7
0
def prepare_args(gold_sm: Graph, pred_sm: Graph,
                 data_node_mode: DataNodeMode) -> List[PairLabelGroup]:
    """Prepare data for evaluation.

    Both graphs are converted into ``Node``/``Link`` structures and their
    class nodes are grouped by label.  ``data_node_mode`` controls how data
    nodes are treated (the numbers are those used by the original notes):

        + ``DataNodeMode.IGNORE_DATA_NODE`` (2): links that point to a data
          node are skipped and data nodes are removed from the result
        + ``DataNodeMode.IGNORE_LABEL_DATA_NODE`` (1): the label of each data
          node is replaced by ``DATA_NODE<k>``, where ``k`` counts, per
          parent, the data-node links sharing the same link label
          (``DATA_NODE1``, ``DATA_NODE2``, ...)
        + any other mode (0): nothing is touched (the original notes require
          data-node labels to be unique in this case)

    Returns:
        One ``PairLabelGroup`` per class-node label found in either graph,
        holding the gold nodes and the predicted nodes with that label.
    """
    # Compare against the enum member in both places; a bare ``== 2`` is
    # silently false when DataNodeMode is not an int-based enum.
    ignore_data_nodes = data_node_mode == DataNodeMode.IGNORE_DATA_NODE

    def convert_graph(graph: Graph):
        node_index: Dict[int, Node] = {}

        for v in graph.iter_nodes():
            node_type = Node.DATA_NODE if v.is_data_node() else Node.CLASS_NODE
            node_index[v.id] = Node(v.id, node_type, v.label)

        for edge in graph.iter_links():
            if (ignore_data_nodes
                    and node_index[edge.target_id].type == Node.DATA_NODE):
                # ignore data node
                continue

            link = Link(edge.id, edge.label, edge.source_id, edge.target_id)
            Node.add_outgoing_link(node_index[edge.source_id], link)
            Node.add_incoming_link(node_index[edge.target_id], link)

        if ignore_data_nodes:
            # materialise the list first: the dict is mutated while removing
            data_nodes = [
                v for v in node_index.values() if v.type == Node.DATA_NODE
            ]
            for data_node in data_nodes:
                del node_index[data_node.id]

        if data_node_mode == DataNodeMode.IGNORE_LABEL_DATA_NODE:
            # we convert label of node to DATA_NODE<k>
            leaf_source_nodes: Set[Node] = set()
            for v in node_index.values():
                if v.type != Node.DATA_NODE:
                    continue
                assert len(v.incoming_links) == 1
                leaf_source_nodes.add(node_index[v.incoming_links[0].source_id])

            for node in leaf_source_nodes:
                # link label -> number of data-node links seen so far
                link_label_count = {}
                for link in node.outgoing_links:
                    target = node_index[link.target_id]
                    if target.type != Node.DATA_NODE:
                        continue
                    count = link_label_count.get(link.label, 0) + 1
                    link_label_count[link.label] = count
                    target.label = 'DATA_NODE' + str(count)

        return node_index

    gold_nodes = convert_graph(gold_sm)
    pred_nodes = convert_graph(pred_sm)

    # class label -> (gold nodes, predicted nodes)
    label2nodes = {}
    for v in gold_sm.iter_class_nodes():
        label2nodes.setdefault(v.label, ([], []))[0].append(gold_nodes[v.id])
    for v in pred_sm.iter_class_nodes():
        label2nodes.setdefault(v.label, ([], []))[1].append(pred_nodes[v.id])

    return [
        PairLabelGroup(label, LabelGroup(gold), LabelGroup(pred))
        for label, (gold, pred) in label2nodes.items()
    ]