def test_unused_field_can_be_none(
        self, use_edges, use_nodes, use_globals, none_field):
    """Verifies EdgeBlock still runs when the fields it does not use are None.

    The graph is built with `none_field` missing. The expected edges are the
    enabled inputs (edges, receiver nodes, sender nodes, globals, in that
    order) concatenated and multiplied by `self._scale`. Nodes and globals
    must come back unchanged.
    """
    input_graph = self._get_input_graph([none_field])
    block_under_test = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = block_under_test(input_graph)

    # Reference edge-model input, assembled in the same order as the block.
    pieces = []
    if use_edges:
        pieces.append(input_graph.edges)
    if use_nodes:
        pieces += [blocks.broadcast_receiver_nodes_to_edges(input_graph),
                   blocks.broadcast_sender_nodes_to_edges(input_graph)]
    if use_globals:
        pieces.append(blocks.broadcast_globals_to_edges(input_graph))
    reference_inputs = tf.concat(pieces, axis=-1)

    # Fields the block does not update must be passed through as-is.
    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        actual_edges, reference_value = sess.run(
            (output_graph.edges, reference_inputs))
    self.assertNDArrayNear(
        reference_value * self._scale, actual_edges, err=1e-4)
def _build(self, graph):
    """Builds a SpringMassSimulator.

    Args:
      graph: A graphs.GraphsTuple having, for some integers N, E, G:
        - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for each
          edge.
        - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for
          each node.
        - globals: Gx2 tf.Tensor containing the gravitational constant.

    Returns:
      A graphs.GraphsTuple of the same shape as `graph`, but where:
        - edges: Holds the force [f_x, f_y] acting on each edge.
        - nodes: Holds positions and velocities after applying one step of
          Euler integration.
    """
    receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
    sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)

    # Column 0 of each edge is the spring constant and column 1 the rest
    # length. Range slicing keeps a trailing axis of size 1.
    spring_force_per_edge = hookes_law(receiver_nodes, sender_nodes,
                                       graph.edges[..., 0:1],
                                       graph.edges[..., 1:2])
    graph = graph.replace(edges=spring_force_per_edge)

    # Aggregate per-edge forces onto the nodes, add the globals broadcast to
    # every node, and take one integration step.
    spring_force_per_node = self._aggregator(graph)
    gravity = blocks.broadcast_globals_to_nodes(graph)
    # Holds whole node states (positions and velocities), not only velocities.
    updated_nodes = euler_integration(
        graph.nodes, spring_force_per_node + gravity, self._step_size)
    return graph.replace(nodes=updated_nodes)
def test_output_values(
        self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Verifies EdgeBlock output against a hand-built reference.

    The reference concatenates the enabled inputs (edges, receiver nodes,
    sender nodes, globals, in that order) along the last axis. The expected
    output edges are that concatenation times `self._scale`. Nodes and
    globals must be returned unchanged.
    """
    input_graph = self._get_input_graph()
    block_under_test = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    output_graph = block_under_test(input_graph)

    # (enabled flag, builder) pairs in the order the block concatenates them.
    candidates = (
        (use_edges, lambda g: g.edges),
        (use_receiver_nodes, blocks.broadcast_receiver_nodes_to_edges),
        (use_sender_nodes, blocks.broadcast_sender_nodes_to_edges),
        (use_globals, blocks.broadcast_globals_to_edges),
    )
    reference_inputs = tf.concat(
        [build(input_graph) for enabled, build in candidates if enabled],
        axis=-1)

    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        evaluated_graph, reference_value = sess.run(
            (output_graph, reference_inputs))
    self.assertNDArrayNear(
        reference_value * self._scale, evaluated_graph.edges, err=1e-4)
def _build(self, node_values, node_keys, node_queries, attention_graph):
    """Connects the multi-head self-attention module.

    The self-attention is only computed according to the connectivity of the
    input graphs, with receiver nodes attending to sender nodes.

    Args:
      node_values: Tensor containing the values associated to each of the
        nodes. The expected shape is [total_num_nodes, num_heads, value_size].
      node_keys: Tensor containing the key associated to each of the nodes.
        The expected shape is [total_num_nodes, num_heads, key_size].
      node_queries: Tensor containing the query associated to each of the
        nodes. The expected shape is [total_num_nodes, num_heads, query_size].
        The query size must be equal to the key size.
      attention_graph: Graph containing connectivity information between nodes
        via the senders and receivers fields. Node B will only attend to
        Node A if `attention_graph` contains an edge sent by Node A and
        received by Node B.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes containing the
      aggregated attended value for each of the nodes with shape
      [total_num_nodes, num_heads, value_size].

    Raises:
      ValueError: presumably propagated from the `blocks` broadcast helpers
        when `attention_graph` lacks the fields they need (e.g. no edges or
        connectivity). TODO confirm.
    """
    # Sender nodes put their keys and values in the edges.
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_keys))
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_values))

    # Receiver nodes put their queries in the edges.
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        attention_graph.replace(nodes=node_queries))

    # Attention logit per edge and head: dot product of query and key over
    # the last axis.
    attention_weights_logits = tf.reduce_sum(
        sender_keys * receiver_queries, axis=-1)
    # Normalised with self._normalizer via _received_edges_normalizer
    # (presumably over each receiver's incoming edges).
    normalized_attention_weights = _received_edges_normalizer(
        attention_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Weight the sender values; [..., None] adds a trailing axis so the
    # per-head weight broadcasts over the value dimension.
    attended_edges = sender_values * normalized_attention_weights[..., None]

    # Sum the attended values arriving at each receiver node.
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        attention_graph.replace(edges=attended_edges))

    return attention_graph.replace(nodes=aggregated_attended_values)
def _build(self, attended_graph):
    """
    Feed the input through the layer: one attention pass per head, averaged
    over heads.

    For each head ``k``:

    1. Every edge is scored from its sender and receiver features projected
       by ``self.W[k]``.
    2. The exponentiated scores are normalised.
    3. The sender features projected by ``self.W[k]`` are weighted by those
       scores and summed back onto the *sender* nodes.

    :param attended_graph: the graph to attend to; its ``nodes``,
        ``senders`` and ``receivers`` fields are read.
    :return: ``attended_graph`` with ``nodes`` replaced by the head-averaged
        result.
    :raises ValueError: if ``self.heads`` is less than 1.
    """
    if self.heads < 1:
        raise ValueError("attention layer needs at least one head")

    num_nodes = tf.shape(attended_graph.nodes)[0]
    senders = attended_graph.senders
    receivers = attended_graph.receivers

    # Loop-invariant per-edge copies of the endpoint node features.
    sender_features = blocks.broadcast_sender_nodes_to_edges(attended_graph)
    receiver_features = blocks.broadcast_receiver_nodes_to_edges(attended_graph)
    # Axis 1: index 0 = sender features, index 1 = receiver features.
    stacked_edges = tf.stack([sender_features, receiver_features], axis=1)

    his = None
    for k in range(self.heads):
        # NOTE(review): assumes W[k] maps one node-feature vector through
        # tensordot(axes=1). TODO confirm shapes.
        w_k = self.W[k]
        e = tf.map_fn(
            lambda edge: tf.concat([
                tf.tensordot(w_k, edge[0], axes=1),
                tf.tensordot(w_k, edge[1], axes=1)
            ], axis=0), stacked_edges)

        logits = tf.nn.leaky_relu(self.attentions[k](e))
        # Shift by the global max before exp so large logits cannot overflow
        # to inf (inf / inf = NaN). Numerators and denominators below are all
        # scaled by the same exp(-max), so the normalised weights are
        # mathematically unchanged.
        logits = logits - tf.stop_gradient(tf.reduce_max(logits))
        attended_e = tf.exp(logits)

        e_sender_sum = tf.math.unsorted_segment_sum(
            attended_e, senders, num_segments=num_nodes)
        e_receiver_sum = tf.math.unsorted_segment_sum(
            attended_e, receivers, num_segments=num_nodes)
        # NOTE(review): the normaliser is (sum over edges sharing this edge's
        # sender) + (sum over edges sharing its receiver), not a per-node
        # softmax. Confirm this is intended.
        denominator = tf.add(tf.gather(e_sender_sum, senders),
                             tf.gather(e_receiver_sum, receivers))
        # Elementwise divide; same result as the per-edge map_fn it replaces.
        e_avg = tf.divide(attended_e, denominator)

        Whi = tf.map_fn(
            lambda edge: tf.tensordot(w_k, edge, axes=1), sender_features)
        aWhi = tf.multiply(Whi, e_avg)
        hi = tf.math.unsorted_segment_sum(aWhi, senders,
                                          num_segments=num_nodes)
        his = hi if his is None else tf.add(his, hi)

    # Average the per-head outputs.
    his = tf.divide(his, self.heads)
    return attended_graph.replace(nodes=his)
def set_rest_lengths(graph):
    """Computes and sets rest lengths for the springs in a physical system.

    The rest length is taken to be the distance between each edge's nodes.

    Args:
      graph: a graphs.GraphsTuple having, for some integers N, E:
        - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.
        - edges: Ex2 Tensor of [spring_constant, _] for each edge.

    Returns:
      The input graph, but with [spring_constant, rest_length] for each edge.
    """
    receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
    sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)
    # Euclidean distance between the (x, y) positions of each edge's
    # endpoints. `keepdims` (not the removed `keep_dims`) keeps a trailing
    # axis of size 1 so the result can be concatenated as a column.
    rest_length = tf.norm(
        receiver_nodes[..., :2] - sender_nodes[..., :2],
        axis=-1, keepdims=True)
    # Keep the spring constant (column 0) and replace column 1.
    return graph.replace(
        edges=tf.concat([graph.edges[..., :1], rest_length], axis=-1))
def _build(self, graph):
    """Updates every edge from its own and its endpoints' features.

    The receiver- and sender-node features broadcast onto the edges are each
    summed across distribution replicas. They are then concatenated after the
    edge's own features and passed through `self._edge_model`.

    NOTE(review): summing per-edge tensors across replicas presumably assumes
    every replica holds a graph with the same edge layout. Confirm.

    Args:
      graph: a graphs.GraphsTuple. Its edges are read directly; its nodes and
        connectivity are read through the `blocks` broadcast helpers.

    Returns:
      `graph` with its edges replaced by the edge model's output.
    """
    replica_ctx = tf.distribute.get_replica_context()

    receiver_feats = blocks.broadcast_receiver_nodes_to_edges(graph)
    sender_feats = blocks.broadcast_sender_nodes_to_edges(graph)

    # Cross-replica sums, kept in the original order (receiver first).
    receiver_feats = replica_ctx.all_reduce("sum", receiver_feats)
    sender_feats = replica_ctx.all_reduce("sum", sender_feats)

    edge_inputs = tf.concat(
        [graph.edges, receiver_feats, sender_feats], axis=-1)
    return graph.replace(edges=self._edge_model(edge_inputs))
def _build(self, node_values, node_keys, node_queries, attention_graph):
    """Connects the multi-head self-attention module.

    Receiver nodes attend to sender nodes along the edges of
    `attention_graph`. For each edge and head, the sender's key is dotted
    with the receiver's query. The scores are normalised with
    `self._normalizer` via `_received_edges_normalizer`, and the weighted
    sender values are summed onto the receiver.

    Args:
      node_values: [total_num_nodes, num_heads, value_size] tensor.
      node_keys: [total_num_nodes, num_heads, key_size] tensor.
      node_queries: [total_num_nodes, num_heads, key_size] tensor; the query
        size must equal the key size.
      attention_graph: graph whose senders/receivers define which nodes
        attend to which.

    Returns:
      `attention_graph` with nodes replaced by the aggregated attended values,
      shape [total_num_nodes, num_heads, value_size].
    """
    # Sender nodes put their keys and values in the edges.
    # [total_num_edges, num_heads, key_size]
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_keys))
    # [total_num_edges, num_heads, value_size]
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_values))

    # Receiver nodes put their queries in the edges.
    # [total_num_edges, num_heads, key_size]
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        attention_graph.replace(nodes=node_queries))

    # Attention weight for each edge: per-head dot product of key and query.
    # Both operands already share the [edges, heads, key_size] layout. They
    # must NOT be transposed: tf.transpose without `perm` reverses every axis.
    # [total_num_edges, num_heads]
    attention_weights_logits = tf.reduce_sum(
        sender_keys * receiver_queries, axis=-1)
    normalized_attention_weights = _received_edges_normalizer(
        attention_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Attending to sender values according to the weights.
    # [total_num_edges, num_heads, value_size]
    attended_edges = sender_values * normalized_attention_weights[..., None]

    # Summing all of the attended values arriving at each node.
    # [total_num_nodes, num_heads, value_size]
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        attention_graph.replace(edges=attended_edges))

    return attention_graph.replace(nodes=aggregated_attended_values)
def _build(self, graph_features):
    """Connects the multi-head self-attention module.

    Keys and values are projected from the edge features produced by
    `self._edge_block`; queries are projected from the node features.
    Attention follows the connectivity of the input graphs: each receiver
    node attends over the edges it receives.

    Args:
      graph_features: Graph containing node features and connectivity
        information via the senders and receivers fields. Node B attends to
        the edge from Node A only if `graph_features` contains an edge sent
        by Node A and received by Node B.

    Returns:
      The output of `self._global_block` applied to `graph_features`, with
      its edges replaced by the output of `self._edge_block`. Its nodes are
      replaced by the attended values, with all heads concatenated to
      [total_num_nodes, num_heads * value_size] and passed through
      `self._node_model`.

    Raises:
      ValueError: presumably propagated from the `blocks` helpers when the
        graph lacks the fields they need (e.g. no edges). TODO confirm.
    """
    # TODO(arc): Figure out how to incorporate edge information into
    # attention updates.
    edges = self._edge_block(graph_features).edges
    num_heads = self.num_heads
    key_size = self.key_size
    value_size = self.value_size

    # [total_num_nodes, d] => [total_num_nodes, num_heads, key_size]
    q = self._attention_node_projection_model(graph_features.nodes)
    q = tf.reshape(
        q, [tf.reduce_sum(graph_features.n_node), num_heads, key_size])

    # Project edge features to get keys and values.
    # [total_num_edges, (key_size + value_size) * num_heads]
    kv = self._attention_edge_projection_model(edges)
    kv = tf.reshape(kv, [-1, num_heads, key_size + value_size])
    # k => [total_num_edges, num_heads, key_size]
    # v => [total_num_edges, num_heads, value_size]
    k, v = tf.split(kv, [key_size, value_size], -1)
    sender_keys = k
    sender_values = v

    # Receiver nodes put their queries in the edges.
    # [total_num_edges, num_heads, key_size]
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        graph_features.replace(nodes=q))

    # Attention logit for each edge from a learned model over [key, query].
    # [total_num_edges, num_heads, 1]
    attention_weights_logits = snt.BatchApply(
        self._query_key_product_model)(tf.concat(
            [sender_keys, receiver_queries], axis=-1))
    # [total_num_edges, num_heads]
    attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

    # Softmax over each receiver's incoming edges.
    # [total_num_edges, num_heads]
    normalized_attention_weights = _received_edges_normalizer(
        graph_features.replace(edges=attention_weights_logits),
        normalizer=_unsorted_segment_softmax)

    # Attending to sender values according to the weights.
    # [total_num_edges, num_heads, value_size]
    attended_edges = sender_values * normalized_attention_weights[..., None]

    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    # Summing all of the attended values arriving at each node.
    # [total_num_nodes, num_heads, value_size]
    aggregated_attended_values = received_edges_aggregator(
        graph_features.replace(edges=attended_edges))

    # Concatenate all heads, then project to the required node dimension.
    # [total_num_nodes, num_heads * value_size]
    aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                            [-1, num_heads * value_size])
    aggregated_attended_values = self._node_model(aggregated_attended_values)

    return self._global_block(
        graph_features.replace(nodes=aggregated_attended_values,
                               edges=edges))
# NOTE(review): `tvars` is not used in the visible code; it may be used
# further down the file.
tvars = graph_network.trainable_variables
print('')

###############
# broadcast
# Each statement copies `graphs_tuple` with one field (nodes or edges)
# overwritten by the output of the matching `blocks.broadcast_*` helper.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
# `graphs_tuple` is rebuilt from `data_dict_0`. Each statement below copies
# it with one field (globals or nodes) overwritten by a `blocks.*Aggregator`
# that reduces with `reducer`.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.math.unsorted_segment_sum
# Edges -> globals, nodes -> globals, sent edges -> nodes,
# received edges -> nodes.
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))