def instance_to_datapoint(graph: AugmentedAST,
                          instance,
                          data_encoder: FITBCharCNNDataEncoder,
                          max_nodes_per_graph: int = None):
    """Build a fill-in-the-blank ``DataPoint`` from one instance.

    The ``var_use`` node has its identifier replaced by the encoder's
    fill-in flag, and the edges returned by
    ``subgraph.all_adjacent_edges(var_use, too_useful_edge_types)`` are
    removed. Before removal, for every type in ``edge_types_to_rewire``,
    each node with such an edge into ``var_use`` is linked directly to each
    node ``var_use`` had such an edge to.

    Args:
        graph: Source graph; supplies the containing subgraph and
            ``origin_file``.
        instance: Pair ``(var_use, other_uses)``. ``other_uses`` has to be
            a tuple because it is concatenated onto ``(var_use, )``.
        data_encoder: Supplies ``fill_in_flag``, ``internal_node_flag``,
            ``encoder_hash`` and the ``DataPoint`` constructor.
        max_nodes_per_graph: Forwarded unchanged to
            ``graph.get_containing_subgraph``.

    Returns:
        ``data_encoder.DataPoint`` built from the pruned subgraph, the
        per-node type lists, the per-node names, the label (ids of the
        nodes flagged as other uses, read after
        ``node_ids_to_ints_from_0()``), ``graph.origin_file`` and
        ``data_encoder.encoder_hash``.
    """
    target, candidates = instance
    blank = data_encoder.fill_in_flag
    internal = data_encoder.internal_node_flag

    sub = graph.get_containing_subgraph((target, ) + candidates,
                                        max_nodes_per_graph)

    # Hide the name of the variable to be filled in.
    sub.nodes[target]['identifier'] = blank

    # Edges are read positionally: e[0] and e[1] are the endpoints and e[3]
    # is the attribute mapping carrying the edge 'type'.
    doomed_edges = sub.all_adjacent_edges(target, too_useful_edge_types)
    typed_endpoints = [(e[0], e[1], e[3]['type']) for e in doomed_edges]

    for edge_type in edge_types_to_rewire:
        # Self-loops are never rewired.
        relevant = [(src, dst) for src, dst, kind in typed_endpoints
                    if kind == edge_type and src != dst]
        successors = [dst for src, dst in relevant if src == target]
        predecessors = [src for src, dst in relevant
                        if src != target and dst == target]
        # Bridge over the target so the chain survives the pruning below.
        for src, dst in itertools.product(predecessors, successors):
            sub.add_edge(src, dst, type=edge_type)

    sub._graph.remove_edges_from(doomed_edges)

    # Mark candidates before ids are rewritten so the mark travels with
    # the node data.
    for candidate in candidates:
        sub.nodes[candidate]['other_use'] = True

    # Assemble node types, node names, and label.
    sub.node_ids_to_ints_from_0()
    node_types, node_names, label = [], [], []
    for node, attrs in sorted(sub.nodes):
        if attrs.get('other_use') is True:
            label.append(node)

        if not sub.is_variable_node(node):
            node_types.append([attrs['type']])
            node_names.append(internal)
            continue

        if attrs['identifier'] == blank:
            node_types.append([blank])
        else:
            # 'reference' is split on commas and dots, then de-duplicated.
            node_types.append(
                sorted(set(re.split(r'[,.]', attrs['reference']))))
        node_names.append(attrs['identifier'])

    return data_encoder.DataPoint(sub, node_types, node_names, label,
                                  graph.origin_file,
                                  data_encoder.encoder_hash)
def instance_to_datapoint(graph: AugmentedAST,
                          instance,
                          data_encoder: FITBGSCVocabDataEncoder,
                          max_nodes_per_graph: int = None):
    """Build a fill-in-the-blank ``DataPoint`` from one instance.

    The ``var_use`` node has its identifier replaced by the encoder's
    fill-in flag, and the edges returned by
    ``subgraph.all_adjacent_edges(var_use, too_useful_edge_types)`` are
    removed. Nodes whose type equals ``data_encoder.subtoken_flag`` and
    that have degree zero after that pruning are dropped.

    Args:
        graph: Source graph; supplies the containing subgraph and
            ``origin_file``.
        instance: Pair ``(var_use, other_uses)``. ``other_uses`` has to be
            a tuple because it is concatenated onto ``(var_use, )``.
        data_encoder: Supplies ``fill_in_flag``, ``internal_node_flag``,
            ``subtoken_flag``, ``encoder_hash`` and the ``DataPoint``
            constructor.
        max_nodes_per_graph: Forwarded unchanged to
            ``graph.get_containing_subgraph``.

    Returns:
        ``data_encoder.DataPoint`` built from the pruned subgraph, the
        per-node type lists, the per-node names, the label (ids of the
        nodes flagged as other uses, read after
        ``node_ids_to_ints_from_0()``), ``graph.origin_file`` and
        ``data_encoder.encoder_hash``.
    """
    target, candidates = instance
    blank = data_encoder.fill_in_flag
    internal = data_encoder.internal_node_flag

    sub = graph.get_containing_subgraph((target, ) + candidates,
                                        max_nodes_per_graph)

    # Hide the name of the variable to be filled in, and prune its edges.
    sub.nodes[target]['identifier'] = blank
    doomed_edges = sub.all_adjacent_edges(target, too_useful_edge_types)
    sub._graph.remove_edges_from(doomed_edges)

    for candidate in candidates:
        sub.nodes[candidate]['other_use'] = True

    # Drop subtoken nodes left with no edges (they could be unfair hints).
    # Iterate over a snapshot because nodes are removed along the way.
    for node, attrs in list(sub.nodes):
        if attrs['type'] != data_encoder.subtoken_flag:
            continue
        if sub._graph.degree(node) == 0:
            sub._graph.remove_node(node)

    # Assemble node types, node names, and label.
    sub.node_ids_to_ints_from_0()
    node_types, node_names, label = [], [], []
    for node, attrs in sorted(sub.nodes):
        if attrs.get('other_use') is True:
            label.append(node)

        if not sub.is_variable_node(node):
            node_types.append([attrs['type']])
            # Subtoken nodes keep their own text; all other non-variable
            # nodes share the internal-node placeholder.
            node_names.append(
                attrs['identifier']
                if attrs['type'] == data_encoder.subtoken_flag
                else internal)
            continue

        if attrs['identifier'] == blank:
            node_types.append([blank])
        else:
            # 'reference' is split on commas and dots, then de-duplicated.
            node_types.append(
                sorted(set(re.split(r'[,.]', attrs['reference']))))
        node_names.append(attrs['identifier'])

    return data_encoder.DataPoint(sub, node_types, node_names, label,
                                  graph.origin_file,
                                  data_encoder.encoder_hash)
def instance_to_datapoint(graph: AugmentedAST,
                          instance,
                          data_encoder: VarNamingClosedVocabDataEncoder,
                          max_nodes_per_graph: int = None):
    """Build a variable-naming ``DataPoint`` from one instance.

    Every node in ``locs`` has its identifier replaced by the encoder's
    name-me flag, and the edges returned by
    ``subgraph.all_adjacent_edges(loc, too_useful_edge_types)`` are
    removed. Nodes whose type equals ``data_encoder.subtoken_flag`` and
    that have degree zero after that pruning are dropped.

    Args:
        graph: Source graph; supplies the containing subgraph and
            ``origin_file``.
        instance: Pair ``(var_name, locs)``; ``locs`` is both passed to
            ``graph.get_containing_subgraph`` and iterated over.
        data_encoder: Supplies ``name_me_flag``, ``internal_node_flag``,
            ``subtoken_flag``, ``name_to_subtokens``, ``encoder_hash`` and
            the ``DataPoint`` constructor.
        max_nodes_per_graph: Forwarded unchanged to
            ``graph.get_containing_subgraph``.

    Returns:
        ``data_encoder.DataPoint`` built from the pruned subgraph, the
        per-node type lists, the per-node names, ``var_name``, the label
        ``data_encoder.name_to_subtokens(var_name)``, ``graph.origin_file``
        and ``data_encoder.encoder_hash``.
    """
    var_name, locs = instance
    placeholder = data_encoder.name_me_flag
    internal = data_encoder.internal_node_flag

    sub = graph.get_containing_subgraph(locs, max_nodes_per_graph)

    # Flag each occurrence of the variable to be named and prune its edges.
    for loc in locs:
        sub.nodes[loc]['identifier'] = placeholder
        sub._graph.remove_edges_from(
            sub.all_adjacent_edges(loc, too_useful_edge_types))

    # Drop subtoken nodes left with no edges: they could come from
    # subtokens that are only in the name, and thus be unfair hints.
    # Iterate over a snapshot because nodes are removed along the way.
    for node, attrs in list(sub.nodes):
        if attrs['type'] != data_encoder.subtoken_flag:
            continue
        if sub._graph.degree(node) == 0:
            sub._graph.remove_node(node)

    # Assemble node types, node names, and label.
    sub.node_ids_to_ints_from_0()
    node_types, node_names = [], []
    for node, attrs in sorted(sub.nodes):
        if sub.is_variable_node(node):
            # 'reference' is split on commas and dots, then de-duplicated.
            node_types.append(
                sorted(set(re.split(r'[,.]', attrs['reference']))))
            node_names.append(attrs['identifier'])
            continue

        node_types.append([attrs['type']])
        # Subtoken nodes keep their own text; all other non-variable nodes
        # share the internal-node placeholder.
        node_names.append(
            attrs['identifier']
            if attrs['type'] == data_encoder.subtoken_flag
            else internal)

    label = data_encoder.name_to_subtokens(var_name)

    return data_encoder.DataPoint(sub, node_types, node_names, var_name,
                                  label, graph.origin_file,
                                  data_encoder.encoder_hash)