示例#1
0
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: FITBCharCNNDataEncoder,
                              max_nodes_per_graph: int = None):
        """Turn one fill-in-the-blank instance into a ``DataPoint``.

        Args:
            graph: Source graph; the working subgraph is extracted from it.
            instance: Pair ``(var_use, other_uses)``. ``other_uses`` is
                concatenated onto ``(var_use,)``, so it has to be a tuple.
            data_encoder: Supplies the flags, the ``DataPoint`` constructor
                and the encoder hash.
            max_nodes_per_graph: Forwarded unchanged to
                ``graph.get_containing_subgraph``.

        Returns:
            ``data_encoder.DataPoint`` built from the pruned subgraph, the
            per-node type lists, the per-node names, the ids of the nodes
            flagged as other uses, ``graph.origin_file`` and the encoder
            hash.
        """
        target, alternatives = instance

        blank_flag = data_encoder.fill_in_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph((target, ) + alternatives,
                                                 max_nodes_per_graph)

        # Hide the identity of the variable that has to be filled in.
        subgraph.nodes[target]['identifier'] = blank_flag

        # Edges of the "too useful" types touching the target get removed.
        # Entries are indexed as (source, target, <unused here>, attributes).
        edges_to_prune = subgraph.all_adjacent_edges(target,
                                                     too_useful_edge_types)
        typed_edges = [(edge[0], edge[1], edge[3]['type'])
                       for edge in edges_to_prune]

        # For every rewirable edge type, bridge each predecessor of the
        # target directly to each successor so the path survives the prune.
        # Self-loops on the target are ignored.
        for edge_type in edge_types_to_rewire:
            predecessors = [
                src for src, dst, kind in typed_edges
                if kind == edge_type and src != dst and dst == target
            ]
            successors = [
                dst for src, dst, kind in typed_edges
                if kind == edge_type and src != dst and src == target
            ]
            for src, dst in itertools.product(predecessors, successors):
                subgraph.add_edge(src, dst, type=edge_type)

        subgraph._graph.remove_edges_from(edges_to_prune)

        # Mark the candidate answers so they can be collected as the label.
        for candidate in alternatives:
            subgraph.nodes[candidate]['other_use'] = True

        # Presumably relabels node ids to integers starting at 0 (going by
        # the helper's name), which is what makes the sorted walk meaningful.
        subgraph.node_ids_to_ints_from_0()

        node_types = []
        node_names = []
        label = []
        for node_id, attrs in sorted(subgraph.nodes):
            if attrs.get('other_use') is True:
                label.append(node_id)

            if not subgraph.is_variable_node(node_id):
                node_types.append([attrs['type']])
                node_names.append(internal_flag)
                continue

            identifier = attrs['identifier']
            if identifier == blank_flag:
                types = [blank_flag]
            else:
                # The reference string is split on ',' and '.', then
                # de-duplicated and sorted.
                types = sorted(set(re.split(r'[,.]', attrs['reference'])))
            node_types.append(types)
            node_names.append(identifier)

        return data_encoder.DataPoint(subgraph, node_types, node_names, label,
                                      graph.origin_file,
                                      data_encoder.encoder_hash)
示例#2
0
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: FITBGSCVocabDataEncoder,
                              max_nodes_per_graph: int = None):
        """Turn one fill-in-the-blank instance into a ``DataPoint``.

        Args:
            graph: Source graph; the working subgraph is extracted from it.
            instance: Pair ``(var_use, other_uses)``. ``other_uses`` is
                concatenated onto ``(var_use,)``, so it has to be a tuple.
            data_encoder: Supplies the flags (fill-in, internal-node,
                subtoken), the ``DataPoint`` constructor and the encoder
                hash.
            max_nodes_per_graph: Forwarded unchanged to
                ``graph.get_containing_subgraph``.

        Returns:
            ``data_encoder.DataPoint`` built from the pruned subgraph, the
            per-node type lists, the per-node names, the ids of the nodes
            flagged as other uses, ``graph.origin_file`` and the encoder
            hash.
        """
        target, alternatives = instance

        blank_flag = data_encoder.fill_in_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph((target, ) + alternatives,
                                                 max_nodes_per_graph)

        # Hide the identity of the variable to fill in and cut the edges of
        # the "too useful" types that touch it.
        subgraph.nodes[target]['identifier'] = blank_flag
        subgraph._graph.remove_edges_from(
            subgraph.all_adjacent_edges(target, too_useful_edge_types))

        # Mark the candidate answers so they can be collected as the label.
        for candidate in alternatives:
            subgraph.nodes[candidate]['other_use'] = True

        # Subtoken nodes left without any edge could be unfair hints, so
        # they are dropped. Collecting first is safe: removing a node of
        # degree 0 cannot change the degree of any other node.
        isolated_subtokens = [
            node_id for node_id, attrs in subgraph.nodes
            if attrs['type'] == data_encoder.subtoken_flag
            and subgraph._graph.degree(node_id) == 0
        ]
        for node_id in isolated_subtokens:
            subgraph._graph.remove_node(node_id)

        # Presumably relabels node ids to integers starting at 0 (going by
        # the helper's name), which is what makes the sorted walk meaningful.
        subgraph.node_ids_to_ints_from_0()

        node_types = []
        node_names = []
        label = []
        for node_id, attrs in sorted(subgraph.nodes):
            if attrs.get('other_use') is True:
                label.append(node_id)

            if subgraph.is_variable_node(node_id):
                identifier = attrs['identifier']
                if identifier == blank_flag:
                    types = [blank_flag]
                else:
                    # The reference string is split on ',' and '.', then
                    # de-duplicated and sorted.
                    types = sorted(set(re.split(r'[,.]',
                                                attrs['reference'])))
                node_types.append(types)
                node_names.append(identifier)
                continue

            # Non-variable node: subtoken nodes keep their identifier, every
            # other node is named with the generic internal-node flag.
            node_type = attrs['type']
            node_types.append([node_type])
            node_names.append(attrs['identifier']
                              if node_type == data_encoder.subtoken_flag
                              else internal_flag)

        return data_encoder.DataPoint(subgraph, node_types, node_names, label,
                                      graph.origin_file,
                                      data_encoder.encoder_hash)
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: VarNamingClosedVocabDataEncoder,
                              max_nodes_per_graph: int = None):
        """Turn one variable-naming instance into a ``DataPoint``.

        Args:
            graph: Source graph; the working subgraph is extracted from it.
            instance: Pair ``(var_name, locs)`` where ``locs`` holds the node
                ids whose identifier is to be masked.
            data_encoder: Supplies the flags (name-me, internal-node,
                subtoken), ``name_to_subtokens``, the ``DataPoint``
                constructor and the encoder hash.
            max_nodes_per_graph: Forwarded unchanged to
                ``graph.get_containing_subgraph``.

        Returns:
            ``data_encoder.DataPoint`` built from the pruned subgraph, the
            per-node type lists, the per-node names, the original variable
            name, its subtokens (the label), ``graph.origin_file`` and the
            encoder hash.
        """
        var_name, locs = instance

        masked_flag = data_encoder.name_me_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph(locs, max_nodes_per_graph)

        # Mask every occurrence of the variable and cut the edges of the
        # "too useful" types touching it, one location at a time.
        for location in locs:
            subgraph.nodes[location]['identifier'] = masked_flag
            subgraph._graph.remove_edges_from(
                subgraph.all_adjacent_edges(location, too_useful_edge_types))

        # Subtoken nodes left without any edge may stem from subtokens that
        # only occur in the masked name and would be unfair hints, so they
        # are dropped. Collecting first is safe: removing a node of degree 0
        # cannot change the degree of any other node.
        isolated_subtokens = [
            node_id for node_id, attrs in subgraph.nodes
            if attrs['type'] == data_encoder.subtoken_flag
            and subgraph._graph.degree(node_id) == 0
        ]
        for node_id in isolated_subtokens:
            subgraph._graph.remove_node(node_id)

        # Presumably relabels node ids to integers starting at 0 (going by
        # the helper's name), which is what makes the sorted walk meaningful.
        subgraph.node_ids_to_ints_from_0()

        node_types = []
        node_names = []
        for node_id, attrs in sorted(subgraph.nodes):
            if subgraph.is_variable_node(node_id):
                # The reference string is split on ',' and '.', then
                # de-duplicated and sorted.
                node_types.append(
                    sorted(set(re.split(r'[,.]', attrs['reference']))))
                node_names.append(attrs['identifier'])
                continue

            # Non-variable node: subtoken nodes keep their identifier, every
            # other node is named with the generic internal-node flag.
            node_type = attrs['type']
            node_types.append([node_type])
            node_names.append(attrs['identifier']
                              if node_type == data_encoder.subtoken_flag
                              else internal_flag)

        label = data_encoder.name_to_subtokens(var_name)

        return data_encoder.DataPoint(subgraph, node_types, node_names,
                                      var_name, label, graph.origin_file,
                                      data_encoder.encoder_hash)