def test_all_adjacent(self):
     """In a fully connected variable graph every node is a usage of every other.

     Builds a 10-node complete graph wrapped in an ``AugmentedAST`` and sets
     node attributes ``type='SimpleName'`` and ``parentType='test_type'``.
     It also sets edge attribute ``type='LAST_READ'`` (networkx>=2 argument
     order: values, then attribute name). It then checks that
     ``get_all_variable_usages`` returns exactly the graph's nodes for each node.
     """
     complete = nx.MultiDiGraph(nx.complete_graph(10))
     ast = AugmentedAST(complete,
                        parent_types_of_variable_nodes=frozenset(['test_type']))
     graph = ast._graph
     nx.set_node_attributes(graph, 'SimpleName', 'type')
     nx.set_node_attributes(graph, 'test_type', 'parentType')
     nx.set_edge_attributes(graph, 'LAST_READ', 'type')
     for entry in ast.nodes:
         # Entries of ``ast.nodes`` appear to be (node_id, attrs) pairs;
         # only the id is passed on.
         usages = ast.get_all_variable_usages(entry[0])
         self.assertCountEqual(usages, graph.nodes)
 def test_get_adjacency_matrix(self, g):
     """Typed adjacency matrix must equal networkx's adjacency matrix of ``g``.

     Every edge of the wrapped ``MultiDiGraph(g)`` is given ``type='test_type'``.
     The matrix returned by ``get_adjacency_matrix('test_type')`` should then
     match the plain int8 adjacency matrix networkx builds for ``g``.
     """
     # ``to_scipy_sparse_matrix`` was deprecated in networkx 2.7 and removed
     # in 3.0 in favour of ``to_scipy_sparse_array``. Prefer the new name and
     # fall back to the old one on networkx releases that predate it.
     to_sparse = getattr(nx, 'to_scipy_sparse_array', None)
     if to_sparse is None:
         to_sparse = nx.to_scipy_sparse_matrix
     pg = AugmentedAST(nx.MultiDiGraph(g),
                       parent_types_of_variable_nodes=frozenset(
                           ['test_type']))
     nx.set_edge_attributes(pg._graph, 'test_type', 'type')
     adj_mat = pg.get_adjacency_matrix('test_type')
     expected = to_sparse(g, format='coo', dtype='int8')
     np.testing.assert_equal(adj_mat.todense(), expected.todense())
 def setUp(self):
     """Load an ``AugmentedAST`` for each ``.gml`` file in the test repositories dir.

     The directory is ``<test_s3shared_path>/test_dataset/repositories``.
     Every file there ending in ``.gml`` is parsed with
     ``AugmentedAST.from_gml`` using the module-level
     ``parent_types_of_variable_nodes``, in ``os.listdir`` order.
     """
     self.test_gml_dir = os.path.join(test_s3shared_path, 'test_dataset', 'repositories')
     self.augmented_asts = []
     for entry in os.listdir(self.test_gml_dir):
         if not entry.endswith('.gml'):
             continue
         gml_path = os.path.abspath(os.path.join(self.test_gml_dir, entry))
         self.augmented_asts.append(
             AugmentedAST.from_gml(gml_path, parent_types_of_variable_nodes))
예제 #4
0
 def from_gml_files(cls, gml_files: List[str]):
     """Build a new ``cls`` task populated from a list of GML files.

     The caller's list is stored on the task as ``origin_files``. Each file is
     parsed with ``AugmentedAST.from_gml`` using the task's
     ``parent_types_of_variable_nodes``, then registered via
     ``add_AugmentedAST``. Progress is shown with ``tqdm``.

     Returns:
         The populated task instance.
     """
     new_task = cls()
     new_task.origin_files = gml_files
     logger.info('Creating {} from gml files'.format(cls.__name__))
     for path in tqdm(gml_files):
         new_task.add_AugmentedAST(
             AugmentedAST.from_gml(path, new_task.parent_types_of_variable_nodes))
     return new_task
예제 #5
0
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: FITBCharCNNDataEncoder,
                              max_nodes_per_graph: int = None):
        """Convert one fill-in-the-blank instance into a ``DataPoint``.

        Args:
            graph: the full augmented AST the instance comes from.
            instance: ``(var_use, other_uses)``, the usage node to hide and a
                tuple of the other usage nodes of the same variable.
            data_encoder: supplies ``fill_in_flag``, ``internal_node_flag``,
                ``DataPoint`` and ``encoder_hash``.
            max_nodes_per_graph: forwarded to ``get_containing_subgraph``.

        The hidden node's identifier is replaced by the fill-in flag. The
        edges returned by ``all_adjacent_edges(var_use, too_useful_edge_types)``
        are removed, after first linking, for each type in
        ``edge_types_to_rewire``, every node pointing into ``var_use`` directly
        to every node ``var_use`` points at. The label is the list of node ids
        (after ``node_ids_to_ints_from_0``) flagged as ``other_use``.
        """
        var_use, other_uses = instance
        blank = data_encoder.fill_in_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph((var_use, ) + other_uses,
                                                 max_nodes_per_graph)

        # Hide the target's name and collect the edges that would give it away.
        subgraph.nodes[var_use]['identifier'] = blank
        pruned = subgraph.all_adjacent_edges(var_use, too_useful_edge_types)
        # Edge tuples carry their attribute dict at index 3 (presumably
        # (u, v, key, data)); reduce each to (source, target, edge type).
        pruned_triples = [(edge[0], edge[1], edge[3]['type']) for edge in pruned]

        # Keep connectivity through the hidden node for the rewired edge types.
        # Self-loops on var_use are ignored.
        for rewire_type in edge_types_to_rewire:
            same_type = [t for t in pruned_triples
                         if t[2] == rewire_type and t[0] != t[1]]
            out_of_var = [t for t in same_type if t[0] == var_use]
            into_var = [t for t in same_type
                        if t[0] != var_use and t[1] == var_use]
            for (pred, _, _), (_, succ, _) in itertools.product(into_var,
                                                                out_of_var):
                subgraph.add_edge(pred, succ, type=rewire_type)
        subgraph._graph.remove_edges_from(pruned)
        for other in other_uses:
            subgraph.nodes[other]['other_use'] = True

        # Renumber nodes, then collect per-node types, names and the label.
        subgraph.node_ids_to_ints_from_0()
        node_types, node_names, label = [], [], []
        for node_id, attrs in sorted(subgraph.nodes):
            if attrs.get('other_use') is True:
                label.append(node_id)
            if not subgraph.is_variable_node(node_id):
                node_types.append([attrs['type']])
                node_names.append(internal_flag)
                continue
            identifier = attrs['identifier']
            if identifier == blank:
                node_types.append([blank])
            else:
                # Split the reference string on ',' and '.', dedupe, and sort.
                node_types.append(
                    sorted(set(re.split(r'[,.]', attrs['reference']))))
            node_names.append(identifier)

        return data_encoder.DataPoint(subgraph, node_types, node_names, label,
                                      graph.origin_file,
                                      data_encoder.encoder_hash)
예제 #6
0
 def add_AugmentedAST(self, ast: AugmentedAST) -> None:
     """Register ``ast`` with one ``(identifier, usages)`` instance per variable.

     Variable nodes are grouped with ``ast.get_all_variable_usages``. Each
     group is emitted once, named by the ``identifier`` of the first node of
     that group met while iterating ``ast.nodes_that_represent_variables``.
     Every usage in the group is asserted to carry that same identifier.
     The result is appended to ``self.graphs_and_instances`` as
     ``(ast, tuple_of_instances)``.
     """
     seen = set()
     collected = []
     for var_node, attrs in ast.nodes_that_represent_variables:
         if var_node in seen:
             continue
         name = attrs['identifier']
         usages = ast.get_all_variable_usages(var_node)
         # Sanity check: all usages of one variable share the same name.
         assert all(ast[usage]['identifier'] == name for usage in usages)
         seen.update(usages)
         collected.append((name, usages))
     self.graphs_and_instances.append((ast, tuple(collected)))
예제 #7
0
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: FITBGSCVocabDataEncoder,
                              max_nodes_per_graph: int = None):
        """Convert one fill-in-the-blank instance into a ``DataPoint``.

        Args:
            graph: the full augmented AST the instance comes from.
            instance: ``(var_use, other_uses)``, the usage node to hide and a
                tuple of the other usage nodes of the same variable.
            data_encoder: supplies ``fill_in_flag``, ``internal_node_flag``,
                ``subtoken_flag``, ``DataPoint`` and ``encoder_hash``.
            max_nodes_per_graph: forwarded to ``get_containing_subgraph``.

        The hidden node's identifier is replaced by the fill-in flag, and the
        edges from ``all_adjacent_edges(var_use, too_useful_edge_types)`` are
        removed. Subtoken nodes left with degree 0 are then deleted. Subtoken
        nodes keep their identifier as their name. The label is the list of
        node ids (after ``node_ids_to_ints_from_0``) flagged ``other_use``.
        """
        var_use, other_uses = instance
        blank = data_encoder.fill_in_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph((var_use, ) + other_uses,
                                                 max_nodes_per_graph)

        # Hide the target's name and cut the edges that would give it away.
        subgraph.nodes[var_use]['identifier'] = blank
        subgraph._graph.remove_edges_from(
            subgraph.all_adjacent_edges(var_use, too_useful_edge_types))
        for other in other_uses:
            subgraph.nodes[other]['other_use'] = True

        # Drop subtoken nodes left without any edges (they could be unfair
        # hints). Iterate a snapshot because nodes are removed on the way.
        for node_id, attrs in list(subgraph.nodes):
            is_subtoken = attrs['type'] == data_encoder.subtoken_flag
            if is_subtoken and subgraph._graph.degree(node_id) == 0:
                subgraph._graph.remove_node(node_id)

        # Renumber nodes, then collect per-node types, names and the label.
        subgraph.node_ids_to_ints_from_0()
        node_types, node_names, label = [], [], []
        for node_id, attrs in sorted(subgraph.nodes):
            if attrs.get('other_use') is True:
                label.append(node_id)
            if not subgraph.is_variable_node(node_id):
                node_types.append([attrs['type']])
                is_subtoken = attrs['type'] == data_encoder.subtoken_flag
                node_names.append(attrs['identifier'] if is_subtoken
                                  else internal_flag)
                continue
            identifier = attrs['identifier']
            if identifier == blank:
                node_types.append([blank])
            else:
                # Split the reference string on ',' and '.', dedupe, and sort.
                node_types.append(
                    sorted(set(re.split(r'[,.]', attrs['reference']))))
            node_names.append(identifier)

        return data_encoder.DataPoint(subgraph, node_types, node_names, label,
                                      graph.origin_file,
                                      data_encoder.encoder_hash)
예제 #8
0
 def add_AugmentedAST(self, ast: AugmentedAST) -> None:
     """Register ``ast`` with one fill-in-the-blank instance per variable usage.

     Variable nodes are grouped with ``ast.get_all_variable_usages``. For each
     group with more than one usage, every usage ``var_use`` yields an
     instance ``(var_use, other_uses)``, where ``other_uses`` is a tuple of the
     remaining usages in their original order. Single-usage variables produce
     no instances. The result is appended to ``self.graphs_and_instances`` as
     ``(ast, tuple_of_instances)``.
     """
     instances = []
     visited_nodes = set()
     for node, _ in ast.nodes_that_represent_variables:
         if node in visited_nodes:
             continue
         location_list = ast.get_all_variable_usages(node)
         if len(location_list) > 1:
             # Use the loop position directly. The previous
             # ``location_list.index(var_use)`` rescanned the list on every
             # iteration (O(n^2)) and would pick the wrong slot if a location
             # ever appeared twice.
             for idx, var_use in enumerate(location_list):
                 other_uses = tuple(location_list[:idx] +
                                    location_list[idx + 1:])
                 instances.append((var_use, other_uses))
         visited_nodes.update(location_list)
     self.graphs_and_instances.append((ast, tuple(instances)))
    def instance_to_datapoint(graph: AugmentedAST,
                              instance,
                              data_encoder: VarNamingClosedVocabDataEncoder,
                              max_nodes_per_graph: int = None):
        """Convert one variable-naming instance into a ``DataPoint``.

        Args:
            graph: the full augmented AST the instance comes from.
            instance: ``(var_name, locs)``, the variable's real name and the
                nodes where it is used.
            data_encoder: supplies ``name_me_flag``, ``internal_node_flag``,
                ``subtoken_flag``, ``name_to_subtokens``, ``DataPoint`` and
                ``encoder_hash``.
            max_nodes_per_graph: forwarded to ``get_containing_subgraph``.

        Every usage node gets its identifier replaced by ``name_me_flag``, and
        its edges from ``all_adjacent_edges(loc, too_useful_edge_types)`` are
        removed, one location at a time. Subtoken nodes left with degree 0 are
        then deleted. The label is ``data_encoder.name_to_subtokens(var_name)``.
        """
        var_name, locs = instance
        hidden = data_encoder.name_me_flag
        internal_flag = data_encoder.internal_node_flag

        subgraph = graph.get_containing_subgraph(locs, max_nodes_per_graph)

        # Hide each usage's name and cut its give-away edges. Edges are
        # recomputed per location after the previous removals.
        for loc in locs:
            subgraph.nodes[loc]['identifier'] = hidden
            subgraph._graph.remove_edges_from(
                subgraph.all_adjacent_edges(loc, too_useful_edge_types))

        # Drop subtoken nodes left without any edges. They may come from
        # subtokens that only occur in the hidden name, and would be unfair
        # hints. Iterate a snapshot because nodes are removed on the way.
        for node_id, attrs in list(subgraph.nodes):
            is_subtoken = attrs['type'] == data_encoder.subtoken_flag
            if is_subtoken and subgraph._graph.degree(node_id) == 0:
                subgraph._graph.remove_node(node_id)

        # Renumber nodes, then collect per-node types and names.
        subgraph.node_ids_to_ints_from_0()
        node_types, node_names = [], []
        for node_id, attrs in sorted(subgraph.nodes):
            if subgraph.is_variable_node(node_id):
                # Split the reference string on ',' and '.', dedupe, and sort.
                node_types.append(
                    sorted(set(re.split(r'[,.]', attrs['reference']))))
                node_names.append(attrs['identifier'])
                continue
            node_types.append([attrs['type']])
            is_subtoken = attrs['type'] == data_encoder.subtoken_flag
            node_names.append(attrs['identifier'] if is_subtoken
                              else internal_flag)

        label = data_encoder.name_to_subtokens(var_name)

        return data_encoder.DataPoint(subgraph, node_types, node_names,
                                      var_name, label, graph.origin_file,
                                      data_encoder.encoder_hash)