def act(self, defState, defNode, eps):
    """
    Pick the edge the defender should follow from ``defNode``.

    A random out-edge is returned while ``eps`` is below ``OBSERVEEPS`` or
    when a uniform draw falls below ``EPSILON``. Otherwise the network is
    evaluated on ``defState`` and the out-edge of ``defNode`` whose first
    edge feature is largest is returned. In the latter case the decoded
    network output is also stored on ``self.outg``.

    :param defState: graph exposing ``nodes`` and ``out_edges``
        (presumably a networkx directed graph -- confirm with callers)
    :param defNode: node whose ``"isDef"`` attribute must equal 1
    :param eps: value compared against ``OBSERVEEPS``; looks like an
        episode counter rather than an exploration rate -- TODO confirm
    :return: one element of ``defState.out_edges([defNode])``
    :raises ValueError: if ``defNode`` is not flagged with ``isDef == 1``
    """
    if defState.nodes[defNode]["isDef"] != 1:
        raise ValueError("def location doesn't match")

    # Exploration: still observing, or the epsilon draw says "go random".
    explore = (eps < OBSERVEEPS) or (random.random() < EPSILON)
    if explore:
        moves = list(defState.out_edges([defNode]))
        pick = self.random.randint(0, len(moves), size=1)[0]
        return moves[pick]

    # Exploitation: run the network on the encoded state.
    encoded_state = self._gtmp2intmp(defState)
    feed = {self.inputPh: utils_np.networkxs_to_graphs_tuple([encoded_state])}
    fetched = self.sess.run({"outputs": self.output_ops_tr}, feed_dict=feed)

    # Only the output of the last processing step is decoded.
    decoded = utils_np.graphs_tuple_to_networkxs(fetched["outputs"][-1])[0]
    self.outg = nx.DiGraph(decoded)

    moves = list(defState.out_edges([defNode]))
    q_values = {
        edge: self.outg.get_edge_data(*edge)["features"][0] for edge in moves
    }
    return max(q_values, key=q_values.get)
def test_graphs_tuple_to_networkxs(self, none_fields):
    """Check the networkx graphs produced from a GraphsTuple.

    Fields named in ``none_fields`` are replaced by ``None`` in the tuple
    before conversion; the corresponding ``"features"`` entries of the
    networkx graphs are then expected to be ``None``. All other features
    must be close to the values in ``self.graphs_dicts_out``.
    """
    nodes_blank = "nodes" in none_fields
    if nodes_blank:
        # Record the node count explicitly before the node features are
        # replaced by None.
        for source_dict in self.graphs_dicts_in:
            source_dict["n_node"] = source_dict["nodes"].shape[0]

    graphs_tuple = utils_np.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs_tuple = graphs_tuple.map(lambda _: None, none_fields)
    converted = utils_np.graphs_tuple_to_networkxs(graphs_tuple)

    for expected, nx_graph in zip(self.graphs_dicts_out, converted):
        # Globals.
        if "globals" not in none_fields:
            self.assertAllClose(expected["globals"], nx_graph.graph["features"])
        else:
            self.assertEqual(None, nx_graph.graph["features"])

        # Nodes: ids must be consecutive and features must match.
        node_view = nx_graph.nodes(data=True)
        paired_nodes = zip(expected["nodes"], node_view)
        for position, (want, (node_id, attrs)) in enumerate(paired_nodes):
            self.assertEqual(position, node_id)
            if "nodes" not in none_fields:
                self.assertAllClose(want, attrs["features"])
            else:
                self.assertEqual(None, attrs["features"])

        # Edges, restored to their original order through the "index" attribute.
        ordered_edges = list(nx_graph.edges(data=True))
        ordered_edges.sort(key=lambda item: item[2]["index"])
        for want, (_, _, attrs) in zip(expected["edges"], ordered_edges):
            if "edges" not in none_fields:
                self.assertAllClose(want, attrs["features"])
            else:
                self.assertEqual(None, attrs["features"])

        # Edge endpoints.
        endpoints = zip(expected["receivers"], expected["senders"], ordered_edges)
        for receiver, sender, (source, target, _) in endpoints:
            self.assertEqual(sender, source)
            self.assertEqual(receiver, target)
def plot_graphs_tuple_np(graphs_tuple):
    """Draw every graph of ``graphs_tuple`` in its own subplot of one row.

    The tuple is converted to networkx graphs and each one is handed to
    ``plot_graph_networkx`` together with its axis.
    """
    nx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
    count = len(nx_graphs)
    _, panels = plt.subplots(1, count, figsize=(5 * count, 5))
    # With a single column plt.subplots returns a bare Axes; wrap it so the
    # loop below can iterate.
    panels = (panels,) if count == 1 else panels
    for nx_graph, panel in zip(nx_graphs, panels):
        plot_graph_networkx(nx_graph, panel)
def plot_graph_structure(graphs_tuple):
    """
    Show the structure of each graph in ``graphs_tuple``.

    The tuple is converted to networkx graphs, each graph is drawn in its own
    subplot of a single row via ``plot_graph_networkx``, and the figure is
    displayed with ``plt.show()``.
    """
    nx_graphs = graphs_tuple_to_networkxs(graphs_tuple)
    count = len(nx_graphs)
    _, panels = plt.subplots(1, count, figsize=(5 * count, 5))
    if count == 1:
        # A single subplot comes back as a bare Axes, not a sequence.
        panels = (panels,)
    for index, nx_graph in enumerate(nx_graphs):
        plot_graph_networkx(nx_graph, panels[index])
    plt.show()
def plot_compare_graphs(graphs_tuples, labels):
    """Plot the first graph of each tuple side by side, one subplot per tuple.

    The value returned by ``plot_graph_networkx`` for one subplot is passed
    as ``pos`` to the next call, and each subplot is titled with the
    matching entry of ``labels``.
    """
    count = len(graphs_tuples)
    _, axes = plt.subplots(1, count, figsize=(5*count, 5))
    if count == 1:
        # A single subplot comes back as a bare Axes, not a sequence.
        axes = (axes,)

    layout = None
    for title, graphs_tuple, axis in zip(labels, graphs_tuples, axes):
        first_graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
        layout = plot_graph_networkx(first_graph, axis, pos=layout)
        axis.set_title(title)
def visualize_original_graph(graph_dict, file_name, use_edges=True):
    """
    Draw the full graph described by ``graph_dict`` and save it.

    Nodes are relabelled "<first feature>; <second feature>", drawn in red,
    saved to ``file_name`` and exported as a DOT file with the same stem.
    The figure is then shown.

    :param graph_dict: An instance of a graph dictionary
    :param file_name: The path to save the image
    :param use_edges: Unused; kept so the signature matches visualize_graph
    """
    graphs_tuple = utils_np.data_dicts_to_graphs_tuple([graph_dict])
    nx_graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]

    plt.figure(1, figsize=(25, 25))

    # Integer node id -> label built from the node's first two features.
    labels = {}
    for node_id, features in enumerate(graph_dict["nodes"]):
        labels[node_id] = "; ".join([str(features[0]), str(features[1])])
    labelled = nx.relabel_nodes(nx_graph, labels)

    nx.draw_networkx(labelled, node_color='red')
    plt.savefig(file_name)

    stem = os.path.splitext(file_name)[0]
    nx.drawing.nx_pydot.write_dot(labelled, "{}.dot".format(stem))
    plt.show()
def plot_compare_graphs(graphs_tuples, labels, color, use_pos_node=False, pos=None):
    """Plot the first graph of each tuple side by side with shared settings.

    ``color``, ``use_pos_node`` and ``pos`` are forwarded unchanged to every
    ``plot_graph_networkx`` call; each subplot is titled with the matching
    entry of ``labels``.
    """
    count = len(graphs_tuples)
    _, axes = plt.subplots(1, count, figsize=(5 * count, 5))
    # A single subplot comes back as a bare Axes; wrap it for iteration.
    axes = (axes,) if count == 1 else axes

    for title, graphs_tuple, axis in zip(labels, graphs_tuples, axes):
        first_graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
        plot_graph_networkx(
            first_graph,
            axis,
            colors=color,
            use_pos_node=use_pos_node,
            pos=pos,
        )
        axis.set_title(title)
def draw_graph(graph, node_pos_dict, col_lims=None, is_normed=False, normfile=None):
    """Draw the first graph of a GraphsTuple, coloured by feature 0.

    Nodes are coloured by ``graph.nodes[:, 0]`` and edges by the first
    feature of the matching row of ``graph.edges``.

    Args:
        graph: GraphsTuple-like object with ``nodes``, ``edges``, ``senders``
            and ``receivers`` arrays.
        node_pos_dict: node -> position mapping passed to ``nx.draw`` as ``pos``.
        col_lims: optional sequence ``(vmin, vmax, edge_vmin, edge_vmax)``.
            When falsy, ``(-0.5, 10, -0.5, 5)`` is used.
        is_normed: if true, ``graph`` is un-normalised with ``unnorm_graph``
            using the statistics stored in ``normfile`` before plotting.
        normfile: path of an HDF5 file holding the ``edge_stats`` and
            ``node_stats`` datasets; only read when ``is_normed`` is true.

    Returns:
        ``(fig, ax)`` of the created figure.

    Raises:
        ValueError: if an edge of the networkx graph does not match exactly
            one ``(sender, receiver)`` row of ``graph``.
    """
    if col_lims:
        vmin, vmax, e_vmin, e_vmax = (col_lims[0], col_lims[1],
                                      col_lims[2], col_lims[3])
    else:
        vmin, vmax = -0.5, 10
        e_vmin, e_vmax = -0.5, 5

    if is_normed:
        # Need to unnorm for plotting. The context manager guarantees the
        # file is closed even if unnorm_graph raises.
        with h5py.File(normfile, 'r') as hf:
            edgestats = hf['edge_stats']
            nodestats = hf['node_stats']
            graph = unnorm_graph(graph, nodestats, edgestats)

    graphs_nx = utils_np.graphs_tuple_to_networkxs(graph)
    nodecols = graph.nodes[:, 0]
    edges = graph.edges
    edgecols = np.zeros((len(edges), ))
    # networkx may iterate edges in a different order than the tuple stores
    # them, so look each edge up by its endpoints.
    for i, e in enumerate(graphs_nx[0].edges):
        j = np.argwhere((graph.senders == e[0]) & (graph.receivers == e[1]))
        # .item() turns the single matching value into a plain scalar and
        # raises ValueError unless exactly one row matched. Assigning the
        # (1, 1) array directly relies on deprecated size-1 conversion.
        edgecols[i] = edges[j, 0].item()

    fig, ax = plt.subplots(figsize=(15, 15))
    nx.draw(graphs_nx[0], ax=ax, pos=node_pos_dict, node_color=nodecols,
            edge_color=edgecols, node_size=100,
            cmap=plt.cm.winter, edge_cmap=plt.cm.winter,
            vmin=vmin, vmax=vmax, edge_vmin=e_vmin, edge_vmax=e_vmax,
            arrowsize=10)
    return fig, ax
def predicted_graphs_to_nxs(gnn_output, input_graphs, target_graphs, **kwargs):
    """Build networkx graphs that carry the network's edge predictions.

    For each graph in ``gnn_output`` a networkx graph is built from the
    matching input/target data dicts with ``data_dict_to_nx`` (``kwargs``
    are forwarded). Every edge then receives a ``'predict'`` attribute
    holding the ``'features'`` of the corresponding key-0 edge of the
    network output.

    Returns:
        list of the annotated networkx graphs, in tuple order.
    """
    predicted = utils_np.graphs_tuple_to_networkxs(gnn_output)
    input_dds = utils_np.graphs_tuple_to_data_dicts(input_graphs)
    target_dds = utils_np.graphs_tuple_to_data_dicts(target_graphs)

    print("total_graphs", len(predicted))

    annotated = []
    for idx, predicted_graph in enumerate(predicted):
        nx_graph = data_dict_to_nx(input_dds[idx], target_dds[idx], **kwargs)
        # Copy the TF output onto the edges; the output graph is keyed as a
        # multigraph, hence the extra (0,) edge key.
        for edge in nx_graph.edges():
            multi_edge = edge + (0, )
            nx_graph.edges[edge]['predict'] = predicted_graph.edges[multi_edge]['features']
        annotated.append(nx_graph)
    return annotated
def visualize_graph(graph_dict, file_name, use_edges=True):
    """
    Creates a visualization of the given graph using only its 1 valued nodes and edges

    A node is kept when its last feature is >= 0.5. An edge is kept when both
    endpoints are kept and, if ``use_edges`` is true, its last feature equals 1.
    The last feature is stripped from kept nodes and edges. The reduced graph
    is drawn in blue, saved to ``file_name``, exported as a DOT file with the
    same stem, and shown.

    Kept nodes are tracked by their position in ``graph_dict["nodes"]``
    rather than by feature value. This keeps each node label attached to the
    right node after filtering, and keeps edges attached to the right node
    when two nodes share identical features.

    :param graph_dict: An instance of a graph dictionary
    :param file_name: The path to save the image
    :param use_edges: Whether to take into account the value of the edges, or just the value of the nodes
    """
    all_nodes = graph_dict["nodes"]
    # Original positions of the nodes that survive the filter, and the
    # translation from original position to position in the reduced graph.
    kept_positions = [i for i, node in enumerate(all_nodes) if node[-1] >= 0.5]
    new_index = {old: new for new, old in enumerate(kept_positions)}

    graph = {"edges": [], "senders": [], "receivers": [],
             "nodes": [all_nodes[i][:-1] for i in kept_positions],
             "globals": [1.0]}

    for edge, sender, receiver in zip(graph_dict["edges"], graph_dict["senders"],
                                      graph_dict["receivers"]):
        if not (edge[-1] == 1. or not use_edges):
            continue
        new_sender = new_index.get(int(sender))
        new_receiver = new_index.get(int(receiver))
        if new_sender is None or new_receiver is None:
            # At least one endpoint was filtered out.
            continue
        graph["edges"].append(edge[:-1])
        graph["senders"].append(new_sender)
        graph["receivers"].append(new_receiver)

    graphs_nx = utils_np.graphs_tuple_to_networkxs(utils_np.data_dicts_to_graphs_tuple([graph]))
    plt.figure(1, figsize=(25, 25))
    # Label each node of the reduced graph with the first two features of
    # the original node it came from.
    mapping0 = {new: "; ".join([str(all_nodes[old][0]), str(all_nodes[old][1])])
                for new, old in enumerate(kept_positions)}
    graphs_nx[0] = nx.relabel_nodes(graphs_nx[0], mapping0)
    nx.draw_networkx(graphs_nx[0], node_color='blue')
    plt.savefig(file_name)
    nx.drawing.nx_pydot.write_dot(graphs_nx[0], "{}.dot".format(os.path.splitext(file_name)[0]))
    plt.show()
def pipeline(graphs,
             tr_ge_split,
             node_types,
             edge_types,
             num_processing_steps_tr=10,
             num_processing_steps_ge=10,
             num_training_iterations=10000,
             continuous_attributes=None,
             categorical_attributes=None,
             type_embedding_dim=5,
             attr_embedding_dim=6,
             edge_output_size=3,
             node_output_size=3,
             output_dir=None):
    """Encode the graphs, train a KGCN on them and attach its predictions.

    The first ``tr_ge_split`` graphs are used for training, the rest for
    generalisation. Two plots are written using ``output_dir`` as a plain
    string prefix. Each element of the returned generalisation graphs gains
    ``'probabilities'`` (softmax of its logits) and ``'prediction'`` (argmax
    of those probabilities).

    Returns:
        ``(ge_graphs, solveds_tr, solveds_ge)`` where the last two are the
        sixth and seventh entries of the learner's training info.
    """
    # ---- Prepare the graph data -------------------------------------------
    value_encoded = [
        encode_values(g, categorical_attributes, continuous_attributes)
        for g in graphs
    ]
    indexed_graphs = [
        nx.convert_node_labels_to_integers(g, label_attribute='concept')
        for g in value_encoded
    ]

    typed_graphs = list(map(duplicate_edges_in_reverse, indexed_graphs))
    # Type-encode nodes first, then edges, each as a full pass over all graphs.
    typed_graphs = [
        encode_types(g, multidigraph_node_data_iterator, node_types)
        for g in typed_graphs
    ]
    typed_graphs = [
        encode_types(g, multidigraph_edge_data_iterator, edge_types)
        for g in typed_graphs
    ]

    all_inputs = list(map(create_input_graph, typed_graphs))
    all_targets = list(map(create_target_graph, typed_graphs))

    tr_inputs, ge_inputs = all_inputs[:tr_ge_split], all_inputs[tr_ge_split:]
    tr_targets, ge_targets = all_targets[:tr_ge_split], all_targets[tr_ge_split:]

    # ---- Build and run the KGCN -------------------------------------------
    thing_embedder = ThingEmbedder(node_types, type_embedding_dim,
                                   attr_embedding_dim, categorical_attributes,
                                   continuous_attributes)
    role_embedder = RoleEmbedder(len(edge_types), type_embedding_dim)
    model = KGCN(thing_embedder,
                 role_embedder,
                 edge_output_size=edge_output_size,
                 node_output_size=node_output_size)
    learner = KGCNLearner(model,
                          num_processing_steps_tr=num_processing_steps_tr,
                          num_processing_steps_ge=num_processing_steps_ge)

    _train_values, test_values, tr_info = learner(
        tr_inputs,
        tr_targets,
        ge_inputs,
        ge_targets,
        num_training_iterations=num_training_iterations,
        log_dir=output_dir)

    plot_across_training(*tr_info, output_file=f'{output_dir}learning.png')
    plot_predictions(typed_graphs[tr_ge_split:],
                     test_values,
                     num_processing_steps_ge,
                     output_file=f'{output_dir}graph.png')

    # ---- Attach predictions to the generalisation graphs ------------------
    logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1])
    ge_graphs = []
    for indexed, logits in zip(indexed_graphs[tr_ge_split:], logit_graphs):
        ge_graphs.append(apply_logits_to_graphs(indexed, logits))

    for ge_graph in ge_graphs:
        for element in multidigraph_data_iterator(ge_graph):
            element['probabilities'] = softmax(element['logits'])
            element['prediction'] = int(np.argmax(element['probabilities']))

    # tr_info must hold exactly seven entries; only the last two are returned.
    _, _, _, _, _, solveds_tr, solveds_ge = tr_info
    return ge_graphs, solveds_tr, solveds_ge
def pipeline(graphs,
             tr_ge_split,
             node_types,
             edge_types,
             num_processing_steps_tr=10,
             num_processing_steps_ge=10,
             num_training_iterations=10000,
             continuous_attributes=None,
             categorical_attributes=None,
             type_embedding_dim=5,
             attr_embedding_dim=6,
             edge_output_size=3,
             node_output_size=3,
             output_dir=None):
    """Encode the graphs in place, train a KGCN and attach its predictions.

    Node values are encoded as a category index (categorical attributes),
    min-max scaled (continuous attributes) or 0 (everything else); edges
    always get 0. The first ``tr_ge_split`` graphs train the model, the rest
    are used for generalisation. Two plots are written using ``output_dir``
    as a plain string prefix.

    Returns:
        ``(ge_graphs, solveds_tr, solveds_ge)`` where the last two are the
        sixth and seventh entries of the learner's training info.
    """
    # ---- Encode attribute values ------------------------------------------
    for graph in graphs:
        for node_attrs in multidigraph_node_data_iterator(graph):
            node_type = node_attrs['type']
            if (categorical_attributes is not None
                    and node_type in categorical_attributes):
                # Position of the value within the attribute's category list.
                encoded = categorical_attributes[node_type].index(
                    node_attrs['value'])
            elif (continuous_attributes is not None
                    and node_type in continuous_attributes):
                lower, upper = continuous_attributes[node_type]
                encoded = (node_attrs['value'] - lower) / (upper - lower)
            else:
                encoded = 0
            node_attrs['encoded_value'] = encoded

        for edge_attrs in multidigraph_edge_data_iterator(graph):
            edge_attrs['encoded_value'] = 0

    # ---- Index, mirror and type-encode the graphs -------------------------
    indexed_graphs = [
        nx.convert_node_labels_to_integers(g, label_attribute='concept')
        for g in graphs
    ]
    prepared = list(map(duplicate_edges_in_reverse, indexed_graphs))
    prepared = [encode_types(g, node_types, edge_types) for g in prepared]

    all_inputs = list(map(create_input_graph, prepared))
    all_targets = list(map(create_target_graph, prepared))

    tr_inputs, ge_inputs = all_inputs[:tr_ge_split], all_inputs[tr_ge_split:]
    tr_targets, ge_targets = all_targets[:tr_ge_split], all_targets[tr_ge_split:]

    # ---- Build and run the KGCN -------------------------------------------
    attr_embedders = configure_embedders(node_types, attr_embedding_dim,
                                         categorical_attributes,
                                         continuous_attributes)
    model = KGCN(len(node_types),
                 len(edge_types),
                 type_embedding_dim,
                 attr_embedding_dim,
                 attr_embedders,
                 edge_output_size=edge_output_size,
                 node_output_size=node_output_size)
    learner = KGCNLearner(model,
                          num_processing_steps_tr=num_processing_steps_tr,
                          num_processing_steps_ge=num_processing_steps_ge)

    _train_values, test_values, tr_info = learner(
        tr_inputs,
        tr_targets,
        ge_inputs,
        ge_targets,
        num_training_iterations=num_training_iterations,
        log_dir=output_dir)

    plot_across_training(*tr_info, output_file=f'{output_dir}learning.png')
    plot_predictions(ge_inputs,
                     test_values,
                     num_processing_steps_ge,
                     output_file=f'{output_dir}graph.png')

    # ---- Attach predictions to the generalisation graphs ------------------
    logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1])
    ge_graphs = []
    for indexed, logits in zip(indexed_graphs[tr_ge_split:], logit_graphs):
        ge_graphs.append(apply_logits_to_graphs(indexed, logits))

    for ge_graph in ge_graphs:
        for element in multidigraph_data_iterator(ge_graph):
            element['probabilities'] = softmax(element['logits'])
            element['prediction'] = int(np.argmax(element['probabilities']))

    # tr_info must hold exactly seven entries; only the last two are returned.
    _, _, _, _, _, solveds_tr, solveds_ge = tr_info
    return ge_graphs, solveds_tr, solveds_ge
def generate_example(positions, properties, k_mean=26, plot=False):
    """
    Generate a hierarchical geometric graph from positions.

    Each pass of the main loop (a) adds one node per row of ``properties``,
    (b) links every pair of points closer than the screening length returned
    by ``find_screen_length`` with "sibling" edges (features ``[1., 0.]``, both
    directions), (c) groups the points into connected components, (d) adds one
    "virtual" parent node per component via ``make_virtual_node`` and links it
    to the component's members with "parent" edges (features ``[0., 1.]``, both
    directions). The virtual nodes then become the points of the next pass.
    The loop stops once a single point is left.

    Args:
        positions: [num_points, 3] positions used for graph construction.
        properties: [num_points, F0,...,Fd] each node will have these properties of shape [F0,...,Fd]
        k_mean: float, forwarded to ``find_screen_length``.
        plot: whether to plot graph.

    Returns: GraphTuple built from the accumulated networkx graph.
    """
    graph = nx.DiGraph()
    sibling_edgelist = []
    parent_edgelist = []
    pos = dict()  # for plotting node positions.
    real_nodes = list(np.arange(positions.shape[0]))
    # NOTE(review): if a pass yields as many components as there are points
    # (no two points within the screening length) the point count does not
    # shrink -- confirm find_screen_length rules this out, otherwise this
    # loop would not terminate.
    while positions.shape[0] > 1:
        # Pairwise Euclidean distances, shape [n_nodes, n_nodes].
        dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
        opt_screen_length = find_screen_length(dist, k_mean)
        print("Found optimal screening length {}".format(opt_screen_length))

        # Zero distances (self-pairs, and any coincident points) are pushed to
        # infinity so they never pass the threshold below.
        distance_matrix_no_loops = np.where(dist == 0., np.inf, dist)
        A = distance_matrix_no_loops < opt_screen_length

        senders, receivers = np.where(A)
        n_edge = senders.size

        # [1,0] for siblings, [0,1] for parent-child
        # NOTE(review): sibling_edges is never read afterwards.
        sibling_edges = np.tile([[1., 0.]], [n_edge, 1])

        # num_points, F0,...Fd
        # If positions are meant to be part of the features they should already
        # be included in properties. We don't concatenate here, mainly because
        # properties could be an image, etc.
        sibling_nodes = properties
        n_nodes = sibling_nodes.shape[0]

        # New node ids continue after whatever is already in the graph.
        # NOTE(review): from the second pass on, `properties` holds the virtual
        # nodes that were already added as parents in the previous pass, so
        # they appear to be added a second time under new ids -- confirm this
        # is intended.
        sibling_node_offset = len(graph.nodes)
        for node, feature, position in zip(np.arange(sibling_node_offset, sibling_node_offset + n_nodes), sibling_nodes, positions):
            graph.add_node(node, features=feature)
            pos[node] = position[:2]  # only the first two coordinates are used for plotting

        # edges = np.stack([senders, receivers], axis=-1) + sibling_node_offset
        for u, v in zip(senders + sibling_node_offset, receivers + sibling_node_offset):
            graph.add_edge(u, v, features=np.array([1., 0.]))
            graph.add_edge(v, u, features=np.array([1., 0.]))
            sibling_edgelist.append((u, v))
            sibling_edgelist.append((v, u))

        # For the virtual nodes: a structure-only graph of this level, with
        # level-local node ids (no offset), used to find connected components.
        sibling_graph = GraphsTuple(nodes=None,  # sibling_nodes,
                                    edges=None,
                                    senders=senders,
                                    receivers=receivers,
                                    globals=None,
                                    n_node=np.array([n_nodes]),
                                    n_edge=np.array([n_edge]))

        sibling_graph = graphs_tuple_to_networkxs(sibling_graph)[0]
        # Connected components of the undirected version, smallest first.
        connected_components = sorted(nx.connected_components(nx.Graph(sibling_graph)), key=len)
        _positions = []
        _properties = []
        for connected_component in connected_components:
            print("Found connected component {}".format(connected_component))
            indices = list(sorted(list(connected_component)))
            virtual_position, virtual_property = make_virtual_node(positions[indices, :], properties[indices, ...])
            _positions.append(virtual_position)
            _properties.append(virtual_property)

        virtual_positions = np.stack(_positions, axis=0)
        virtual_properties = np.stack(_properties, axis=0)

        ###
        # add virtual nodes
        # num_parents, 3+F
        parent_nodes = virtual_properties
        n_nodes = parent_nodes.shape[0]
        parent_node_offset = len(graph.nodes)
        parent_indices = np.arange(parent_node_offset, parent_node_offset + n_nodes)
        # adding the nodes to global graph
        for node, feature, virtual_position in zip(parent_indices, parent_nodes, virtual_positions):
            graph.add_node(node, features=feature)
            print("new virtual {}".format(node))
            pos[node] = virtual_position[:2]

        # Link every parent to the members of its component; member ids are
        # level-local, so shift them by this level's offset.
        for parent_idx, connected_component in zip(parent_indices, connected_components):
            child_node_indices = [idx + sibling_node_offset for idx in list(sorted(list(connected_component)))]
            for child_node_idx in child_node_indices:
                graph.add_edge(parent_idx, child_node_idx, features=np.array([0., 1.]))
                graph.add_edge(child_node_idx, parent_idx, features=np.array([0., 1.]))
                parent_edgelist.append((parent_idx, child_node_idx))
                parent_edgelist.append((child_node_idx, parent_idx))
                print("connecting {}<->{}".format(parent_idx, child_node_idx))

        # The virtual nodes are the points of the next level.
        positions = virtual_positions
        properties = virtual_properties

    # plotting

    # Every node beyond the initial point ids counts as virtual.
    virutal_nodes = list(set(graph.nodes) - set(real_nodes))
    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(12, 12))
        draw(graph, ax=ax, pos=pos, node_color='green', edgelist=[], nodelist=real_nodes)
        draw(graph, ax=ax, pos=pos, node_color='purple', edgelist=[], nodelist=virutal_nodes)
        draw(graph, ax=ax, pos=pos, edge_color='blue', edgelist=sibling_edgelist, nodelist=[])
        draw(graph, ax=ax, pos=pos, edge_color='red', edgelist=parent_edgelist, nodelist=[])
        plt.show()

    # NOTE(review): the node shape hint adds the position width to the
    # property width, while the node features stored above are the property
    # rows only -- confirm the hint is what the caller expects.
    return networkxs_to_graphs_tuple([graph],
                                     node_shape_hint=[positions.shape[1] + properties.shape[1]],
                                     edge_shape_hint=[2])
def train(self, oldState, action, reward, newState, terminal, eps):
    """
    Store a transition in replay memory and run one training step.

    The transition ``(oldState, action, reward, newState, terminal)`` is
    appended to ``self.D``. While ``eps`` is below ``OBSERVEEPS`` nothing else
    happens. Otherwise ``self.BATCH`` transitions are sampled; for each one the
    network is evaluated on its new state, the feature of the taken edge is
    overwritten with the target value, and the resulting graphs are used as
    targets for a single optimisation step.

    The target for a sampled transition is ``r`` when that transition was
    terminal and ``r + max_q`` otherwise, where ``max_q`` is the largest first
    edge feature among the out-edges of the node the action leads to. The
    terminal flag used is the one stored with the sampled transition, not the
    ``terminal`` argument of the current call.

    :param oldState: state before the action
    :param action: edge tuple taken; ``action[1]`` is the node it leads to
    :param reward: reward received for the action
    :param newState: state after the action
    :param terminal: whether ``newState`` ended the episode
    :param eps: value compared against ``OBSERVEEPS``; looks like an episode
        counter -- TODO confirm
    :return: always 42
    :raises ValueError: from ``random.sample`` if fewer than ``self.BATCH``
        transitions are stored, or from ``max`` if a non-terminal sample's
        next node has no out-edges in the network output
    """
    self.D.append((oldState, action, reward, newState, terminal))

    # Observation phase: only collect experience, do not train yet.
    if eps < OBSERVEEPS:
        return 42

    minibatch = random.sample(self.D, self.BATCH)

    inputs = []
    targets = []
    for s0, a, r, s1, term in minibatch:
        # Evaluate the network on the new state of this transition.
        gs1 = self._gtmp2intmp(s1)
        testValues = self.sess.run({
            "outputs": self.output_ops_tr
        }, feed_dict={self.inputPh: utils_np.networkxs_to_graphs_tuple([gs1])})
        outs1 = utils_np.graphs_tuple_to_networkxs(testValues["outputs"][-1])[0]
        outs1 = nx.DiGraph(outs1)

        # Overwrite the taken edge with the target value. The decision must
        # use this sample's own terminal flag.
        if not term:
            validActions = list(outs1.out_edges([a[1]]))
            best_next = max(
                outs1.get_edge_data(*e)["features"][0] for e in validActions
            )
            outs1.add_edge(*a, features=np.array([r + best_next]))
        else:
            outs1.add_edge(*a, features=np.array([r]))

        # now outs1 can be trained as target for s0
        outs1.graph['features'] = [0.0]
        # TODO: write new functions to parse state to standard graphs
        inputs.append(self._gtmp2intmp(s0))
        targets.append(outs1)

    feed_dict = {self.inputPh: utils_np.networkxs_to_graphs_tuple(inputs),
                 self.targetPh: utils_np.networkxs_to_graphs_tuple(targets)}

    # Apply one gradient-descent step; the fetched values are not used.
    self.sess.run({
        "step": self.step_op,
        "target": self.targetPh,
        "loss": self.loss_op_tr,
        "outputs": self.output_ops_tr
    }, feed_dict=feed_dict)
    return 42
def plot_graphs_tuple_np(graphs_tuple, phase):
    """Draw up to ten graphs of a trajectory in one row, colouring nodes.

    Each graph of ``graphs_tuple`` is drawn in its own subplot (ten subplots
    are created; extra graphs are ignored). Nodes whose index is reported by
    ``getBlackNodesIndecesPrediction`` are red, all others green. Node
    positions are a fixed layout for node ids 0-8. The black/white node
    indices and the colour list are printed for every graph.

    Args:
        graphs_tuple: GraphsTuple holding one graph per step.
        phase: ``1`` titles the window 'True trajectory'; any other value
            titles it 'Predicted trajectory'.
    """
    # Fixed node layout shared by all subplots.
    fixed_pos = {
        0: (10, 20),
        1: (20, 30),
        2: (30, 40),
        3: (30, 50),
        4: (50, 50),
        5: (60, 60),
        6: (70, 70),
        7: (80, 80),
        8: (40, 80)
    }

    graphs_nx = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
    fig, axs = plt.subplots(ncols=10, figsize=(20, 2))

    for iax, (graph_nx, ax) in enumerate(zip(graphs_nx, axs)):
        # Round-trip through a data dict to get this graph's node features.
        graph_t = utils_np.networkxs_to_graphs_tuple([graph_nx])
        graph_d = utils_np.graphs_tuple_to_data_dicts(graph_t)
        nodes = graph_d[0]["nodes"]

        black_nodes = getBlackNodesIndecesPrediction(nodes)
        white_nodes = getWhiteNodesIndeces(nodes)
        print("black_nodes")
        print(black_nodes)
        print("white_nodes")
        print(white_nodes)

        color_map = ['r' if i in black_nodes else 'g' for i in range(len(nodes))]
        print("color_map")
        print(color_map)

        nx.draw(graph_nx,
                pos=fixed_pos,
                ax=ax,
                node_color=color_map,
                node_size=100,
                alpha=0.8)
        ax.set_title("Step {}".format(iax))

    window_title = 'True trajectory' if phase == 1 else 'Predicted trajectory'
    fig = plt.gcf()
    # FigureCanvasBase.set_window_title was removed from Matplotlib; the
    # window title is set through the canvas manager instead. The manager
    # can be absent for figures not created by pyplot.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(window_title)
"senders": senders_0,
"receivers": receivers_0
}

# Second example graph, assembled from its separately defined fields.
data_dict_1 = {
    "globals": globals_1,
    "nodes": nodes_1,
    "edges": edges_1,
    "senders": senders_1,
    "receivers": receivers_1
}

# Pack both example graphs into one GraphsTuple and draw them side by side.
data_dict_list = [data_dict_0, data_dict_1]
graphs_tuple = utils_np.data_dicts_to_graphs_tuple(data_dict_list)
graphs_nx = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
_, axs = plt.subplots(ncols=2, figsize=(6, 3))
for iax, (graph_nx, ax) in enumerate(zip(graphs_nx, axs)):
    nx.draw(graph_nx, ax=ax)
    ax.set_title("Graph {}".format(iax))
#plt.show()


def plot_graphs_tuple_np(graphs_tuple):
    """Draw every graph of ``graphs_tuple`` in its own subplot of one row.

    The tuple is converted to networkx graphs and each one is passed to
    ``plot_graph_networkx`` together with its axis.
    """
    networkx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
    num_graphs = len(networkx_graphs)
    _, axes = plt.subplots(1, num_graphs, figsize=(5 * num_graphs, 5))
    if num_graphs == 1:
        # A single subplot comes back as a bare Axes; wrap it in a tuple so
        # the zip below can iterate over it.
        axes = axes,
    for graph, ax in zip(networkx_graphs, axes):
        plot_graph_networkx(graph, ax)
def pipeline(graphs,
             tr_ge_split,
             node_types,
             edge_types,
             num_processing_steps_tr=10,
             num_processing_steps_ge=10,
             num_training_iterations=10000,
             continuous_attributes=None,
             categorical_attributes=None,
             type_embedding_dim=5,
             attr_embedding_dim=6,
             edge_output_size=3,
             node_output_size=3,
             output_dir=None,
             do_test=False,
             save_fle="test_model.ckpt",
             reload_fle=""):
    """Encode the graphs, then train or only validate a KGCN on them.

    The first ``tr_ge_split`` graphs are the training set, the rest the
    generalisation set. The learner only runs inference when ``do_test`` is
    exactly ``True`` and ``output_dir / reload_fle`` is not a directory;
    in every other case it trains. Plots are written below ``output_dir``.
    Each element of the returned generalisation graphs gains
    ``'probabilities'`` (softmax of its logits) and ``'prediction'`` (argmax
    of those probabilities).

    Returns:
        ``(ge_graphs, solveds_tr, solveds_ge)`` where the last two are the
        sixth and seventh entries of the learner's info tuple.
    """
    # ---- Prepare the graph data -------------------------------------------
    value_encoded = [
        encode_values(g, categorical_attributes, continuous_attributes)
        for g in graphs
    ]
    indexed_graphs = [
        nx.convert_node_labels_to_integers(g, label_attribute='concept')
        for g in value_encoded
    ]

    typed_graphs = list(map(duplicate_edges_in_reverse, indexed_graphs))
    # Type-encode nodes first, then edges, each as a full pass over all graphs.
    typed_graphs = [
        encode_types(g, multidigraph_node_data_iterator, node_types)
        for g in typed_graphs
    ]
    typed_graphs = [
        encode_types(g, multidigraph_edge_data_iterator, edge_types)
        for g in typed_graphs
    ]

    all_inputs = list(map(create_input_graph, typed_graphs))
    all_targets = list(map(create_target_graph, typed_graphs))

    tr_inputs, ge_inputs = all_inputs[:tr_ge_split], all_inputs[tr_ge_split:]
    tr_targets, ge_targets = all_targets[:tr_ge_split], all_targets[tr_ge_split:]

    # ---- Build the KGCN ---------------------------------------------------
    thing_embedder = ThingEmbedder(node_types, type_embedding_dim,
                                   attr_embedding_dim, categorical_attributes,
                                   continuous_attributes)
    role_embedder = RoleEmbedder(len(edge_types), type_embedding_dim)
    model = KGCN(thing_embedder,
                 role_embedder,
                 edge_output_size=edge_output_size,
                 node_output_size=node_output_size)

    # The processing-step counts say how many message-passing iterations are
    # done for every training / testing step.
    learner = KGCNLearner(
        model,
        num_processing_steps_tr=num_processing_steps_tr,
        num_processing_steps_ge=num_processing_steps_ge,
        log_dir=output_dir,
        save_fle=f'{output_dir}/{save_fle}',
        reload_fle=f'{output_dir}/{reload_fle}')

    # ---- Validate only, or train ------------------------------------------
    reload_target = Path(output_dir) / reload_fle
    validate_only = (not reload_target.is_dir()) and do_test is True
    if not validate_only:
        print("\n\nTRAINING\n\n")
        _train_values, test_values, tr_info = learner.train(
            tr_inputs,
            tr_targets,
            ge_inputs,
            ge_targets,
            num_training_iterations=num_training_iterations)
    else:
        print("\n\nVALIDATION ONLY\n\n")
        test_values, tr_info = learner.infer(ge_inputs, ge_targets)

    plot_across_training(*tr_info, output_file=f'{output_dir}/learning.png')
    plot_predictions(typed_graphs[tr_ge_split:],
                     test_values,
                     num_processing_steps_ge,
                     output_file=f'{output_dir}/graph.png')

    # ---- Attach predictions to the generalisation graphs ------------------
    logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1])
    ge_graphs = []
    for indexed, logits in zip(indexed_graphs[tr_ge_split:], logit_graphs):
        ge_graphs.append(apply_logits_to_graphs(indexed, logits))

    for ge_graph in ge_graphs:
        for element in multidigraph_data_iterator(ge_graph):
            element['probabilities'] = softmax(element['logits'])
            # Class 0, 1 or 2 by argmax of the probabilities.
            # TODO: use a threshold instead.
            element['prediction'] = int(np.argmax(element['probabilities']))

    # tr_info must hold exactly seven entries; only the last two are returned.
    _, _, _, _, _, solveds_tr, solveds_ge = tr_info
    return ge_graphs, solveds_tr, solveds_ge
def pipeline(graphs,
             tr_ge_split,
             node_types,
             edge_types,
             num_processing_steps_tr=10,
             num_processing_steps_ge=10,
             num_training_iterations=10000,
             categorical_attributes=None,
             type_embedding_dim=5,
             attr_embedding_dim=6,
             edge_output_size=3,
             node_output_size=3):
    """Encode the graphs in place, train a KGCN and attach its predictions.

    Every graph element gets ``encoded_value = 0``; nodes whose type is a key
    of ``categorical_attributes`` get the index of their value within that
    attribute's category list instead. The first ``tr_ge_split`` graphs train
    the model, the rest are used for generalisation.

    Args:
        graphs: graphs to learn from; their element data is modified in place.
        tr_ge_split: index separating training from generalisation graphs.
        node_types: list of node type names; positions are the type indices.
        edge_types: list of edge type names.
        num_processing_steps_tr: forwarded to ``KGCNLearner``.
        num_processing_steps_ge: forwarded to ``KGCNLearner`` and
            ``plot_predictions``.
        num_training_iterations: forwarded to the learner call.
        categorical_attributes: mapping of attribute type name to its list of
            category values; ``None`` is treated as "no categorical
            attributes".
        type_embedding_dim: forwarded to ``KGCN``.
        attr_embedding_dim: forwarded to ``KGCN`` and the attribute embedders.
        edge_output_size: forwarded to ``KGCN``.
        node_output_size: forwarded to ``KGCN``.

    Returns:
        ``(ge_graphs, solveds_tr, solveds_ge)`` where the last two are the
        sixth and seventh entries of the learner's training info.

    Raises:
        ValueError: if a categorical attribute type is missing from
            ``node_types`` or a node value is missing from its category list.
    """
    # The signature defaults to None, which has no .items(); treat it as empty.
    if categorical_attributes is None:
        categorical_attributes = {}

    ############################################################
    # Manipulate the graph data
    ############################################################

    # Encode attribute values
    for graph in graphs:
        for data in multidigraph_data_iterator(graph):
            data['encoded_value'] = 0

        for node_data in multidigraph_node_data_iterator(graph):
            # Add the integer value of the category for each categorical attribute instance
            category_values = categorical_attributes.get(node_data['type'])
            if category_values is not None:
                node_data['encoded_value'] = category_values.index(
                    node_data['value'])

    indexed_graphs = [
        nx.convert_node_labels_to_integers(graph, label_attribute='concept')
        for graph in graphs
    ]
    graphs = [duplicate_edges_in_reverse(graph) for graph in indexed_graphs]

    graphs = [encode_types(graph, node_types, edge_types) for graph in graphs]

    input_graphs = [create_input_graph(graph) for graph in graphs]
    target_graphs = [create_target_graph(graph) for graph in graphs]

    tr_input_graphs = input_graphs[:tr_ge_split]
    tr_target_graphs = target_graphs[:tr_ge_split]
    ge_input_graphs = input_graphs[tr_ge_split:]
    ge_target_graphs = target_graphs[tr_ge_split:]

    ############################################################
    # Build and run the KGCN
    ############################################################

    # Indices of all node types; attribute types are removed as they receive
    # their own embedder, leaving the non-attribute types.
    non_attribute_nodes = list(range(len(node_types)))

    attr_embedders = dict()

    # Construct categorical attribute embedders
    for attr_typ, category_values in categorical_attributes.items():

        # Bind the per-attribute values as defaults: a plain closure would be
        # evaluated lazily and every factory would see the values of the last
        # loop iteration.
        def make_embedder(num_categories=len(category_values),
                          name=attr_typ + '_cat_embedder'):
            return CategoricalAttribute(num_categories,
                                        attr_embedding_dim,
                                        name=name)

        attr_typ_index = node_types.index(attr_typ)

        # Record the embedder, and the index of the type that it should encode
        attr_embedders[make_embedder] = [attr_typ_index]

        # Remove by value: list.pop(i) removes by position, which no longer
        # matches the type index once an earlier entry has been removed.
        non_attribute_nodes.remove(attr_typ_index)

    # All entities and relations (non-attributes) also need an embedder with matching output dimension, which does
    # nothing. This is provided as a list of their indices
    def make_blank_embedder():
        return BlankAttribute(attr_embedding_dim)

    attr_embedders[make_blank_embedder] = non_attribute_nodes

    kgcn = KGCN(len(node_types),
                len(edge_types),
                type_embedding_dim,
                attr_embedding_dim,
                attr_embedders,
                edge_output_size=edge_output_size,
                node_output_size=node_output_size)

    learner = KGCNLearner(kgcn,
                          num_processing_steps_tr=num_processing_steps_tr,
                          num_processing_steps_ge=num_processing_steps_ge)

    _train_values, test_values, tr_info = learner(
        tr_input_graphs,
        tr_target_graphs,
        ge_input_graphs,
        ge_target_graphs,
        num_training_iterations=num_training_iterations)

    plot_across_training(*tr_info)
    plot_predictions(ge_input_graphs, test_values, num_processing_steps_ge)

    # Only the output of the last processing step is decoded.
    logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1])

    indexed_ge_graphs = indexed_graphs[tr_ge_split:]
    ge_graphs = [
        apply_logits_to_graphs(graph, logit_graph)
        for graph, logit_graph in zip(indexed_ge_graphs, logit_graphs)
    ]

    for ge_graph in ge_graphs:
        for data in multidigraph_data_iterator(ge_graph):
            data['probabilities'] = softmax(data['logits'])
            data['prediction'] = int(np.argmax(data['probabilities']))

    # tr_info must hold exactly seven entries; only the last two are returned.
    _, _, _, _, _, solveds_tr, solveds_ge = tr_info
    return ge_graphs, solveds_tr, solveds_ge