def destandardize_graphs_tuple(self, graphs: GraphsTuple) -> GraphsTuple:
    """Return a copy of ``graphs`` with every feature field de-standardized.

    Globals, nodes and edges are each passed through ``self._destandardize``
    together with the mean/std pair stored for that field. The calls happen
    in that order: globals, then nodes, then edges. All other fields of
    ``graphs`` are carried over unchanged.

    NOTE(review): ``_destandardize`` is not visible here; presumably it
    computes ``x * std + mean`` — confirm against its definition.

    Args:
      graphs: Graph whose ``globals``, ``nodes`` and ``edges`` are in
        standardized units.

    Returns:
      A new graph (built via ``graphs.replace``) holding the restored fields.
    """
    restored_globals = self._destandardize(
        graphs.globals, mean=self.global_mean, std=self.global_std)
    restored_nodes = self._destandardize(
        graphs.nodes, mean=self.nodes_mean, std=self.nodes_std)
    restored_edges = self._destandardize(
        graphs.edges, mean=self.edges_mean, std=self.edges_std)
    return graphs.replace(globals=restored_globals,
                          nodes=restored_nodes,
                          edges=restored_edges)
def connect_graph_dynamic(graph: GraphsTuple, is_edge_func, name="connect_graph_dynamic"):
    """Connects a graph using a boolean edge mask to create edges.

    Each graph in the batch is processed in a ``tf.while_loop``. For every
    graph, ``_create_functional_connect_edges_dynamic`` builds the local
    senders/receivers/n_edge from its node count and ``is_edge_func``. The
    per-graph results are accumulated in TensorArrays and concatenated. Node
    offsets are then added so the indices refer to the stacked node tensor.

    Args:
      graph: GraphsTuple whose edge fields must all be None (checked).
      is_edge_func: callable(sender: int, receiver: int) -> bool, should
        broadcast.
      name: Name scope for the created ops.

    Returns:
      connected GraphsTuple with ``senders``, ``receivers`` and ``n_edge``
      filled in (wrapped in ``tf.stop_gradient``).
    """
    utils_tf._validate_edge_fields_are_all_none(graph)

    with tf.name_scope(name):

        def _connect_one(index, senders_ta, receivers_ta, n_edge_ta):
            # Edges of graph `index`, indexed locally (starting at 0).
            per_graph = _create_functional_connect_edges_dynamic(
                graph.n_node[index], is_edge_func)
            return (index + 1,
                    senders_ta.write(index, per_graph['senders']),
                    receivers_ta.write(index, per_graph['receivers']),
                    n_edge_ta.write(index, per_graph['n_edge']))

        dynamic_num_graphs = utils_tf.get_num_graphs(graph)

        def _keep_going(index, *_):
            return tf.less(index, dynamic_num_graphs)

        # One accumulator each for senders, receivers and n_edge.
        accumulators = [
            tf.TensorArray(dtype=tf.int32, size=dynamic_num_graphs, infer_shape=False)
            for _ in range(3)
        ]
        _, senders_ta, receivers_ta, n_edge_ta = tf.while_loop(
            _keep_going, _connect_one, [0] + accumulators)

        n_edge = n_edge_ta.concat()
        # Shift local node indices so they point into the stacked node tensor.
        offsets = utils_tf._compute_stacked_offsets(graph.n_node, n_edge)
        senders = senders_ta.concat() + offsets
        receivers = receivers_ta.concat() + offsets
        for index_tensor in (senders, receivers):
            index_tensor.set_shape(offsets.shape)
            index_tensor.set_shape([None])

        static_num_graphs = graph.n_node.get_shape().as_list()[0]
        n_edge.set_shape([static_num_graphs])

        return graph.replace(senders=tf.stop_gradient(senders),
                             receivers=tf.stop_gradient(receivers),
                             n_edge=tf.stop_gradient(n_edge))
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
    """Returns randomly rotated graph representation.

    The rotation is an element of O(3) with rotation angles multiple of pi/2.
    It is a random permutation of the three coordinate axes, followed by an
    independent random sign flip (reflection) of each axis. This function
    assumes that the relative particle distances are stored in the edge
    features.

    NOTE(review): ``graph.edges`` is assumed to be 2-D with exactly three
    feature columns (x, y, z). The gather below keeps only three rows of the
    transposed tensor, so any extra columns would be silently dropped.

    Args:
      graph: The graphs tuple as defined in `graph_nets.graphs`.

    Returns:
      A copy of ``graph`` with transformed ``edges``; all other fields are
      unchanged.
    """
    # Transposes edge features, so that the axes are in the first dimension.
    # Outputs a tensor of shape [3, n_edges].
    xyz = tf.transpose(graph.edges)
    # Random pi/2 rotation(s): permute the coordinate axes.
    permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
    xyz = tf.gather(xyz, permutation)
    # Random reflections: multiply each axis by +1 or -1. Use `tf.random.uniform`
    # (the TF1-only alias `tf.random_uniform` is gone in TF2). Cast to the edge
    # dtype so non-float32 edge features also work.
    flips = tf.random.uniform([3], minval=0, maxval=2, dtype=tf.int32)
    symmetry = 1 - 2 * tf.cast(tf.reshape(flips, [3, 1]), xyz.dtype)
    xyz = xyz * symmetry
    edges = tf.transpose(xyz)
    return graph.replace(edges=edges)
def _build(self, graph):
    """Encode ``graph`` into ``self.num_output`` token nodes.

    Steps:
      1. Fill every edge of the input graph with the learned
         ``intra_graph_edge_variable`` and project its nodes with
         ``projection_node_block``.
      2. Build a fully connected token graph of ``self.num_output`` nodes,
         initialised from ``empty_node_variable``, with edges filled from
         ``intra_token_graph_edge_variable``.
      3. Merge both into one graph (input nodes first, then tokens). Then add
         random bi-directional edges: about ``inter_graph_connect_prob *
         n_node`` input nodes are sampled without replacement, and each is
         connected to a uniformly chosen token node.
      4. Run ``self.crossing_steps`` rounds of edge -> node -> global
         updates, with a residual connection on the nodes.
      5. Keep only the token nodes and apply ``output_projection_node_block``.

    NOTE(review): only ``graph.n_edge[0]`` is read, so ``graph`` is assumed
    to contain a single graph — confirm callers never pass a batch.

    Args:
      graph: Input GraphsTuple; its edge features are overwritten in step 1.

    Returns:
      GraphsTuple with ``self.num_output`` nodes and no edges/globals.
    """
    # Give every input edge the same learned embedding, then project nodes.
    graph = graph.replace(
        edges=tf.tile(self.intra_graph_edge_variable[None, :], [graph.n_edge[0], 1]))
    graph = self.projection_node_block(graph)  # [n_nodes, node_size]
    n_node = tf.shape(graph.nodes)[0]

    # Create fully connected output token nodes.
    token_start_nodes = tf.tile(self.empty_node_variable[None, :], [self.num_output, 1])
    token_graph = GraphsTuple(nodes=token_start_nodes,
                              edges=None,
                              globals=tf.constant([0.], dtype=tf.float32),
                              senders=None,
                              receivers=None,
                              n_node=tf.constant([self.num_output], dtype=tf.int32),
                              n_edge=tf.constant([0], dtype=tf.int32))
    token_graph = fully_connect_graph_dynamic(token_graph)
    token_graph = token_graph.replace(
        edges=tf.tile(self.intra_token_graph_edge_variable[None, :],
                      [token_graph.n_edge[0], 1]))

    concat_graph = concat([graph, token_graph], axis=0)  # n_node = [n_nodes, n_tokens]
    # Treat the merged result as a single graph: n_node = [n_nodes + n_tokens].
    concat_graph = concat_graph.replace(
        n_node=tf.reduce_sum(concat_graph.n_node, keepdims=True),
        n_edge=tf.reduce_sum(concat_graph.n_edge, keepdims=True))

    # Sample input nodes without replacement: top-k over i.i.d. Gumbel noise
    # is a uniformly random subset.
    gumbel = -tf.math.log(-tf.math.log(tf.random.uniform((n_node,))))
    n_connect_edges = tf.cast(
        tf.multiply(tf.constant([self.inter_graph_connect_prob]),
                    tf.cast(n_node, tf.float32)),
        tf.int32)  # shape [1]
    _, graph_senders = tf.nn.top_k(gumbel, n_connect_edges[0])
    # Token nodes come after the n_node input nodes in the merged graph.
    token_graph_receivers = n_node + tf.random.uniform(
        shape=n_connect_edges, minval=0, maxval=self.num_output, dtype=tf.int32)

    # Add the new edges in both directions.
    senders = tf.concat(
        [concat_graph.senders, graph_senders, token_graph_receivers], axis=0)
    receivers = tf.concat(
        [concat_graph.receivers, token_graph_receivers, graph_senders], axis=0)
    inter_edges = tf.tile(
        self.inter_graph_edge_variable[None, :],
        tf.concat([2 * n_connect_edges, tf.constant([1], dtype=tf.int32)], axis=0))
    edges = tf.concat([concat_graph.edges, inter_edges], axis=0)
    concat_graph = concat_graph.replace(
        senders=senders,
        receivers=receivers,
        edges=edges,
        # Keep n_edge rank-1 ([n_graphs]) as the GraphsTuple contract requires;
        # both operands have shape [1].
        n_edge=concat_graph.n_edge + 2 * n_connect_edges,
        globals=self.starting_global_variable[None, :])

    latent_graph = concat_graph
    # This would be the theoretical crossing time for information through the graph.
    for _ in range(self.crossing_steps):
        input_nodes = latent_graph.nodes
        latent_graph = self.edge_block(latent_graph)
        latent_graph = self.node_block(latent_graph)
        latent_graph = self.global_block(latent_graph)
        # Residual connection on the nodes.
        latent_graph = latent_graph.replace(nodes=latent_graph.nodes + input_nodes)

    # Read out only the token nodes, which follow the n_node input nodes.
    latent_graph = latent_graph.replace(
        nodes=latent_graph.nodes[n_node:],
        edges=None,
        receivers=None,
        senders=None,
        globals=None,
        n_node=tf.constant([self.num_output], dtype=tf.int32),
        n_edge=tf.constant([0], dtype=tf.int32))
    return self.output_projection_node_block(latent_graph)
def _build(self, graph, crossing_steps=None):
    """Merge ``graph`` with a token graph and run message passing over both.

    Steps:
      1. Fill every edge of ``graph`` with ``intra_graph_edge_variable`` and
         project its nodes with ``projection_node_block``.
      2. Build a fully connected token graph of ``self.num_output`` nodes
         (static connection), with edges filled from
         ``intra_token_graph_edge_variable``.
      3. Merge both into one graph (input nodes first, then tokens). Then add
         random bi-directional edges between about
         ``inter_graph_connect_prob * n_node`` input nodes, sampled without
         replacement, and uniformly chosen token nodes.
      4. Run ``crossing_steps`` rounds of edge -> node -> global updates with a
         residual connection on the nodes.
      5. Apply ``output_projection_node_block`` to the full latent graph.

    NOTE(review): unlike the token-readout variant, this does not slice out
    the token nodes before the output projection — confirm that is intended.
    Only ``graph.n_edge[0]`` is read, so a single input graph is assumed.

    Args:
      graph: Input GraphsTuple; its edge features are overwritten.
      crossing_steps: Number of message-passing rounds. Previously this
        argument was silently ignored. It is now honoured, and ``None`` falls
        back to ``self.crossing_steps``.

    Returns:
      The GraphsTuple produced by ``output_projection_node_block``.
    """
    if crossing_steps is None:
        crossing_steps = self.crossing_steps

    n_edge = graph.n_edge[0]
    graph = graph.replace(edges=tf.tile(
        self.intra_graph_edge_variable[None, :], [n_edge, 1]))
    graph = self.projection_node_block(graph)  # [n_nodes, node_size]
    n_node = tf.shape(graph.nodes)[0]

    # Create fully connected output token nodes.
    token_start_nodes = tf.tile(self.empty_node_variable[None, :], [self.num_output, 1])
    token_graph = GraphsTuple(nodes=token_start_nodes,
                              edges=None,
                              globals=None,
                              senders=None,
                              receivers=None,
                              n_node=[self.num_output],
                              n_edge=None)
    token_graph = fully_connect_graph_static(token_graph)
    n_edge = token_graph.n_edge[0]
    token_graph = token_graph.replace(edges=tf.tile(
        self.intra_token_graph_edge_variable[None, :], [n_edge, 1]))

    concat_graph = concat([graph, token_graph], axis=0)  # n_node = [n_nodes, n_tokens]
    # Treat the merged result as one graph. Sum n_node (the original summed the
    # whole GraphsTuple by mistake), and sum n_edge so the token-graph edges
    # are counted too.
    concat_graph = concat_graph.replace(
        n_node=tf.reduce_sum(concat_graph.n_node, keepdims=True),
        n_edge=tf.reduce_sum(concat_graph.n_edge, keepdims=True))

    # Sample input nodes without replacement: top-k over i.i.d. Gumbel noise
    # is a uniformly random subset.
    gumbel = -tf.math.log(-tf.math.log(tf.random.uniform((n_node, ))))
    # Cast n_node to float first: a Python float times an int32 tensor raises.
    n_connect_edges = tf.cast(
        self.inter_graph_connect_prob * tf.cast(n_node, tf.float32), tf.int32)
    _, graph_senders = tf.nn.top_k(gumbel, n_connect_edges)
    # Token nodes come after the n_node input nodes in the merged graph.
    token_graph_receivers = n_node + tf.random.uniform(
        (n_connect_edges, ), minval=0, maxval=self.num_output, dtype=tf.int32)

    # Add the new edges in both directions.
    senders = tf.concat(
        [concat_graph.senders, graph_senders, token_graph_receivers], axis=0)
    receivers = tf.concat(
        [concat_graph.receivers, token_graph_receivers, graph_senders], axis=0)
    inter_edges = tf.tile(self.inter_graph_edge_variable[None, :],
                          [2 * n_connect_edges, 1])
    edges = tf.concat([concat_graph.edges, inter_edges], axis=0)
    concat_graph = concat_graph.replace(
        senders=senders,
        receivers=receivers,
        edges=edges,
        # Shape [1] + scalar -> shape [1]: total edges of the merged graph.
        n_edge=concat_graph.n_edge + 2 * n_connect_edges,
        globals=self.starting_global_variable[None, :])

    latent_graph = concat_graph
    # This would be the theoretical crossing time for information through the graph.
    for _ in range(crossing_steps):
        input_nodes = latent_graph.nodes
        latent_graph = self.edge_block(latent_graph)
        latent_graph = self.node_block(latent_graph)
        latent_graph = self.global_block(latent_graph)
        # Residual connection on the nodes.
        latent_graph = latent_graph.replace(nodes=latent_graph.nodes + input_nodes)

    output_graph = self.output_projection_node_block(latent_graph)
    return output_graph