def _build(self, node_values, node_keys, node_queries, attention_graph):
    """Connects the multi-head self-attention module.

    The self-attention is only computed according to the connectivity of the
    input graphs, with receiver nodes attending to sender nodes.

    Args:
      node_values: Tensor containing the values associated to each of the
        nodes. The expected shape is [total_num_nodes, num_heads, value_size].
      node_keys: Tensor containing the key associated to each of the nodes.
        The expected shape is [total_num_nodes, num_heads, key_size].
      node_queries: Tensor containing the query associated to each of the
        nodes. The expected shape is [total_num_nodes, num_heads, query_size].
        The query size must be equal to the key size.
      attention_graph: Graph containing connectivity information between nodes
        via the senders and receivers fields. Node A will only attempt to
        attend to Node B if `attention_graph` contains an edge sent by Node B
        and received by Node A.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes containing the
      aggregated attended value for each of the nodes with shape
      [total_num_nodes, num_heads, value_size].

    Raises:
      ValueError: if the input graph does not have edges (presumably raised by
        the `blocks` broadcast helpers; this method does not raise directly).
    """
    # Sender nodes put their keys and values in the edges.
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_keys))
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_values))

    # Receiver nodes put their queries in the edges.
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        attention_graph.replace(nodes=node_queries))

    # Attention logit per edge and head: dot product of the receiver's query
    # with the sender's key over the last (key) axis.
    attention_weights_logits = tf.reduce_sum(
        sender_keys * receiver_queries, axis=-1)
    # Normalize the logits over the edges received by each node.
    normalized_attention_weights = _received_edges_normalizer(
        attention_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Attending to sender values according to the weights; the trailing axis
    # is added so each weight scales its whole value vector.
    attended_edges = sender_values * normalized_attention_weights[..., None]

    # Summing all of the attended values received by each node.
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        attention_graph.replace(edges=attended_edges))

    return attention_graph.replace(nodes=aggregated_attended_values)
def _build(self, attended_graph):
    """
    Feed the input through the multi-head edge-attention layer.

    For each head k, every edge gets an unnormalized score computed from
    W[k]-projected sender and receiver features. Each sender's projected
    features are then scaled by these scores and summed at the sender. The
    per-head results are averaged.

    :param attended_graph: the graph to attend to; its senders/receivers give
        the connectivity and its nodes hold the node features.
    :return: `attended_graph` with nodes replaced by the head-averaged result.
    """
    # Per edge: index 0 along axis 1 holds the sender's node features, index 1
    # the receiver's.
    stacked_edges = tf.stack([
        blocks.broadcast_sender_nodes_to_edges(attended_graph),
        blocks.broadcast_receiver_nodes_to_edges(attended_graph)
    ], axis=1)
    his = None
    for k in range(self.heads):
        # Concatenate W[k]·sender and W[k]·receiver for each edge.
        # NOTE(review): tf.map_fn runs this per edge; a single batched
        # tensordot/matmul over all edges would likely be much faster.
        e = tf.map_fn(
            lambda edge: tf.concat([
                tf.tensordot(self.W[k], edge[0], axes=1),
                tf.tensordot(self.W[k], edge[1], axes=1)
            ], axis=0),
            stacked_edges)
        # Unnormalized attention score per edge.
        # NOTE(review): exp() without subtracting a max can overflow for
        # large logits.
        attended_e = tf.exp(tf.nn.leaky_relu(self.attentions[k](e)))
        # Sum of scores over the edges sharing each sender / each receiver.
        e_sender_sum = tf.math.unsorted_segment_sum(
            attended_e, attended_graph.senders,
            num_segments=tf.shape(attended_graph.nodes)[0])
        e_receiver_sum = tf.math.unsorted_segment_sum(
            attended_e, attended_graph.receivers,
            num_segments=tf.shape(attended_graph.nodes)[0])
        # Pair each edge score with its denominator. The denominator is the
        # sender-group sum plus the receiver-group sum.
        # NOTE(review): this differs from a softmax over a single
        # neighbourhood; confirm the mixed denominator is intended.
        stacked_to_avg = tf.stack([
            attended_e,
            tf.add(tf.gather(e_sender_sum, attended_graph.senders),
                   tf.gather(e_receiver_sum, attended_graph.receivers))
        ], axis=1)
        # Normalized score per edge (score / denominator).
        e_avg = tf.map_fn(lambda avg: tf.divide(avg[0], avg[1]), stacked_to_avg)
        # W[k] applied to the *sender's* features on each edge.
        Whi = tf.map_fn(
            lambda edge: tf.tensordot(self.W[k], edge, axes=1),
            blocks.broadcast_sender_nodes_to_edges(attended_graph))
        aWhi = tf.multiply(Whi, e_avg)
        # Summed back at the senders. Each node therefore receives its own
        # projected features times the sum of its outgoing weights.
        # NOTE(review): neighbour (receiver) features only enter through the
        # weights. If GAT-style aggregation was intended, presumably the
        # receiver features should be projected here; confirm.
        hi = tf.math.unsorted_segment_sum(aWhi, attended_graph.senders,
                                          num_segments=tf.shape(
                                              attended_graph.nodes)[0])
        # Accumulate the per-head results.
        if his is None:
            his = hi
        else:
            his = tf.add(his, hi)
    # Average over heads (his stays None if self.heads == 0).
    his = tf.divide(his, self.heads)
    return attended_graph.replace(nodes=his)
def test_unused_field_can_be_none(
        self, use_edges, use_nodes, use_globals, none_field):
    """Checks that computation can handle non-necessary fields left None.

    The field named by `none_field` is left as None in the input graph. The
    EdgeBlock must still produce edges equal to its enabled inputs, which are
    concatenated and scaled by `self._scale`.
    """
    graph = self._get_input_graph([none_field])
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    result = block(graph)

    # Reference inputs, in the order the test expects the block to
    # concatenate them: edges, receiver nodes, sender nodes, globals.
    pieces = []
    if use_edges:
        pieces.append(graph.edges)
    if use_nodes:
        pieces.extend([
            blocks.broadcast_receiver_nodes_to_edges(graph),
            blocks.broadcast_sender_nodes_to_edges(graph),
        ])
    if use_globals:
        pieces.append(blocks.broadcast_globals_to_edges(graph))
    expected_inputs = tf.concat(pieces, axis=-1)

    # Nodes and globals must pass through the edge block untouched.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as session:
        edges_value, inputs_value = session.run(
            (result.edges, expected_inputs))
        self.assertNDArrayNear(
            inputs_value * self._scale, edges_value, err=1e-4)
def test_output_values(
        self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Compares the output of an EdgeBlock to an explicit computation.

    The reference is the enabled inputs (edges, receiver nodes, sender nodes,
    globals, in that order), concatenated on the last axis and scaled by
    `self._scale`.
    """
    graph = self._get_input_graph()
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    result = block(graph)

    # (enabled, builder) pairs. Builders run lazily and only when enabled, so
    # exactly the enabled reference tensors are created, in this order.
    sources = (
        (use_edges, lambda: graph.edges),
        (use_receiver_nodes,
         lambda: blocks.broadcast_receiver_nodes_to_edges(graph)),
        (use_sender_nodes,
         lambda: blocks.broadcast_sender_nodes_to_edges(graph)),
        (use_globals, lambda: blocks.broadcast_globals_to_edges(graph)),
    )
    expected_inputs = tf.concat(
        [build() for enabled, build in sources if enabled], axis=-1)

    # Nodes and globals must pass through the edge block untouched.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as session:
        result_value, inputs_value = session.run((result, expected_inputs))
        self.assertNDArrayNear(
            inputs_value * self._scale, result_value.edges, err=1e-4)
def _build(self, graph):
    """Builds a SpringMassSimulator.

    Args:
      graph: A graphs.GraphsTuple having, for some integers N, E, G:
          - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for each
            edge.
          - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for
            each node.
          - globals: Gx2 tf.Tensor containing the gravitational constant.

    Returns:
      A graphs.GraphsTuple of the same shape as `graph`, but where:
          - edges: Holds the force [f_x, f_y] acting on each edge.
          - nodes: Holds positions and velocities after applying one step of
            Euler integration.
    """
    receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
    sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)

    # Column slices keep a trailing axis of size 1 so the spring constant and
    # rest length broadcast against the per-edge node features.
    spring_force_per_edge = hookes_law(receiver_nodes, sender_nodes,
                                       graph.edges[..., 0:1],
                                       graph.edges[..., 1:2])
    graph = graph.replace(edges=spring_force_per_edge)

    # Aggregate the per-edge forces onto nodes and add the global gravity term.
    spring_force_per_node = self._aggregator(graph)
    gravity = blocks.broadcast_globals_to_nodes(graph)

    # euler_integration returns the full updated node state (positions and
    # velocities), not only velocities.
    updated_nodes = euler_integration(graph.nodes,
                                      spring_force_per_node + gravity,
                                      self._step_size)
    return graph.replace(nodes=updated_nodes)
def _build(self, node_values, node_keys, node_queries, attention_graph):
    """Connects the multi-head self-attention module.

    Attention is only computed along the edges of `attention_graph`, with
    receiver nodes attending to sender nodes.

    Args:
      node_values: values per node, [total_num_nodes, num_heads, value_size].
      node_keys: keys per node, [total_num_nodes, num_heads, key_size].
      node_queries: queries per node, [total_num_nodes, num_heads, key_size]
        (the query size must equal the key size).
      attention_graph: Graph providing connectivity through its senders and
        receivers. Node A attends to Node B only if there is an edge sent by
        Node B and received by Node A.

    Returns:
      `attention_graph` with nodes replaced by the aggregated attended values,
      [total_num_nodes, num_heads, value_size].
    """
    # Sender nodes put their keys and values in the edges.
    # [total_num_edges, num_heads, key_size]
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_keys))
    # [total_num_edges, num_heads, value_size]
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        attention_graph.replace(nodes=node_values))

    # Receiver nodes put their queries in the edges.
    # [total_num_edges, num_heads, key_size]
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        attention_graph.replace(nodes=node_queries))

    # Attention weight for each edge: per-head dot product of key and query.
    # Both operands share the layout [edges, heads, key_size], so they are
    # multiplied directly. Transposing the queries would reverse all three
    # axes and misalign them.
    # [total_num_edges, num_heads]
    attention_weights_logits = tf.reduce_sum(
        sender_keys * receiver_queries, axis=-1)
    normalized_attention_weights = _received_edges_normalizer(
        attention_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Attending to sender values according to the weights.
    # [total_num_edges, num_heads, value_size]
    attended_edges = sender_values * normalized_attention_weights[..., None]

    # Summing all of the attended values received by each node.
    # [total_num_nodes, num_heads, value_size]
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        attention_graph.replace(edges=attended_edges))

    return attention_graph.replace(nodes=aggregated_attended_values)
def set_rest_lengths(graph):
    """Computes and sets rest lengths for the springs in a physical system.

    The rest length is taken to be the distance between each edge's nodes.

    Args:
      graph: a graphs.GraphsTuple having, for some integers N, E:
          - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.
          - edges: Ex2 Tensor of [spring_constant, _] for each edge.

    Returns:
      The input graph, but with [spring_constant, rest_length] for each edge.
    """
    receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
    sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)
    # Euclidean distance between the (x, y) positions of each edge's
    # endpoints, kept as a trailing size-1 axis for the concat below.
    # `keepdims` replaces the deprecated `keep_dims`, which TF2's tf.norm
    # no longer accepts.
    rest_length = tf.norm(
        receiver_nodes[..., :2] - sender_nodes[..., :2],
        axis=-1,
        keepdims=True)
    return graph.replace(
        edges=tf.concat([graph.edges[..., :1], rest_length], axis=-1))
def _build(self, graph):
    """Updates each edge from its own features and its endpoints' features.

    The receiver and sender node features are broadcast onto the edges. Each
    is all-reduced with "sum" across distribution replicas. They are then
    concatenated after the edge features on the last axis and passed through
    `self._edge_model`. Nodes and globals are left unchanged.

    NOTE(review): summing per-edge tensors across replicas presumes every
    replica holds the same edge layout; confirm.
    """
    endpoint_features = [
        blocks.broadcast_receiver_nodes_to_edges(graph),
        blocks.broadcast_sender_nodes_to_edges(graph),
    ]

    # Aggregate the broadcast features across replicas (receivers first).
    context = tf.distribute.get_replica_context()
    reduced_endpoints = [
        context.all_reduce("sum", features) for features in endpoint_features
    ]

    edge_inputs = tf.concat([graph.edges] + reduced_endpoints, axis=-1)
    return graph.replace(edges=self._edge_model(edge_inputs))
def _build(self, graph_features):
    """Connects the multi-head self-attention module.

    Queries, keys and values are projected from the node features of
    `graph_features`. Attention is then computed only along the graph's edges,
    with receiver nodes attending to sender nodes.

    Args:
      graph_features: Graph whose `nodes` hold the features to attend over
        and whose senders and receivers fields give the connectivity. Node A
        will only attempt to attend to Node B if `graph_features` contains an
        edge sent by Node B and received by Node A.

    Returns:
      A copy of `graph_features` whose nodes are the output of
      `self._node_model`. Its input is all heads' aggregated attended values
      concatenated to [total_num_nodes, num_heads * value_size].

    Raises:
      ValueError: if the input graph does not have edges (presumably raised by
        the `blocks` broadcast helpers; this method does not raise directly).
    """
    # TODO(arc): Figure out how to incorporate edge information into attention
    # updates.
    nodes = graph_features.nodes
    num_heads = self.num_heads
    key_size = self.key_size
    value_size = self.value_size
    qkv_size = 2 * key_size + value_size

    # [total_num_nodes, d] => [total_num_nodes, num_heads * qkv_size]
    # (presumably the projection's output width is num_heads * qkv_size;
    # TODO confirm).
    qkv_flat = self._attention_projection_model(nodes)
    qkv = tf.reshape(qkv_flat, [-1, num_heads, qkv_size])
    # q => [total_num_nodes, num_heads, key_size]
    # k => [total_num_nodes, num_heads, key_size]
    # v => [total_num_nodes, num_heads, value_size]
    q, k, v = tf.split(qkv, [key_size, key_size, value_size], -1)

    # Sender nodes put their keys and values in the edges.
    # [total_num_edges, num_heads, key_size]
    sender_keys = blocks.broadcast_sender_nodes_to_edges(
        graph_features.replace(nodes=k))
    # [total_num_edges, num_heads, value_size]
    sender_values = blocks.broadcast_sender_nodes_to_edges(
        graph_features.replace(nodes=v))

    # Receiver nodes put their queries in the edges.
    # [total_num_edges, num_heads, key_size]
    receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
        graph_features.replace(nodes=q))

    # Attention logit per edge and head from a learned model applied to the
    # concatenated key and query.
    # [total_num_edges, num_heads, 1]
    attention_weights_logits = snt.BatchApply(
        self._query_key_product_model)(tf.concat(
            [sender_keys, receiver_queries], axis=-1))
    # [total_num_edges, num_heads]
    attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

    # Normalize the logits over the edges received by each node.
    # [total_num_edges, num_heads]
    normalized_attention_weights = _received_edges_normalizer(
        graph_features.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)

    # Attending to sender values according to the weights.
    # [total_num_edges, num_heads, value_size]
    attended_edges = sender_values * normalized_attention_weights[..., None]

    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    # Summing all of the attended values received by each node.
    # [total_num_nodes, num_heads, value_size]
    aggregated_attended_values = received_edges_aggregator(
        graph_features.replace(edges=attended_edges))

    # Concatenate all the heads and project to the required dimension.
    # [total_num_nodes, num_heads * value_size]
    aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                            [-1, num_heads * value_size])
    aggregated_attended_values = self._node_model(aggregated_attended_values)

    return graph_features.replace(nodes=aggregated_attended_values)
# Inspect the first node of the last computed graphs. `previous_graphs`,
# `graph_network` and `data_dict_0` are defined outside this fragment.
print(previous_graphs.nodes[0])
output_graphs = previous_graphs
# Trainable variables of the graph network (not used in the visible code).
tvars = graph_network.trainable_variables
print('')

###############
# broadcast: each result replaces one field of a fresh single-graph tuple
# with the output of the corresponding `blocks.broadcast_*` helper.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate: each result replaces one field with the output of a `blocks`
# aggregator that reduces with `reducer` (segment sum).
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.math.unsorted_segment_sum
# Edges/nodes -> globals, then sent edges -> nodes.
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
def model_fn(features, labels, mode, params):
    """Estimator model_fn: scores answer choices by propagating over a graph.

    Token ids are embedded and encoded with an LSTM. Each encoded token is
    added onto the template-graph node it is closest to. Several rounds of
    node-MLP / sender-broadcast / mean-aggregation message passing follow.
    Each round is gated by the original node features and updates the
    globals. A dense layer on the final globals yields one logit per choice.

    Args:
      features: dict with "input_ids" (token ids) and "answer_id" (the
        correct choice index; read in every mode, including PREDICT).
      labels: unused.
      mode: a tf.estimator.ModeKeys value.
      params: dict with 'word_embedding', 'graph_nodes' (node feature matrix)
        and 'graph_edges' (presumably a 0/1 adjacency matrix; its nonzero
        entries become sender/receiver pairs).

    Returns:
      A tf.estimator.EstimatorSpec.

    NOTE(review): depends on module-level FLAGS, num_choices, SIGNATURE_NAME.
    The node count 512, the feature size 300, the choice count 4 and the edge
    count 5551 are hard-coded below and must match the data.
    """
    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    depth = graph_nodes.shape[1]
    training = mode == tf.estimator.ModeKeys.TRAIN
    # Token id 1 is treated as padding.
    padding_mask = tf.cast(tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])), tf.int32)
    # 0 means the token needs to be masked. 1 means it is not masked.
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])
    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])
    question_encoder = tf.keras.layers.LSTM(depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    # Zero out the encodings of padding tokens.
    encoded_question = tf.cast(padding_mask, tf.float32) * tf.cast(encoded_question, tf.float32)
    # One row per token: [num_sequences * seq_len, depth].
    encoded_question = tf.reshape(tf.cast(encoded_question, tf.float32), [-1, depth])
    # The template graph
    nodes = graph_nodes.astype(np.float32)
    edges = np.ones([int(np.sum(graph_edges)), 1]).astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    # NOTE(review): `globals` shadows the Python builtin inside this function.
    globals = np.zeros(FLAGS.global_size).astype(np.float32)
    graph_dict = {"globals": globals, "nodes": nodes, "edges": edges, "senders": senders, "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    # The batched copies start with all-zero node features.
    graph_dict["nodes"] = nodes * 0
    # One graph per sequence.
    # NOTE(review): requires a static first dimension; range() fails if it is
    # unknown (None).
    batch_of_tensor_data_dicts = [graph_dict for i in range(sentences.shape[0])]
    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes
    # Euclidean distance to identify closest nodes
    # NOTE(review): na/nb are squared norms of the *l2-normalized* vectors (~1,
    # or 0 for masked tokens), but the cross term below uses the *raw* vectors.
    # The result is therefore not the distance between either raw or
    # normalized vectors. The max(.., 0) clamp can also tie many nodes at 0,
    # and argmin then picks the first of them. Confirm whether cosine
    # (normalize both) or raw Euclidean was intended.
    na = tf.reduce_sum(tf.square(tf.math.l2_normalize(encoded_question, -1)), 1)
    nb = tf.reduce_sum(tf.square(tf.math.l2_normalize(nodes, -1)),
                       1)
    # na as a column and nb as a row vector, so they broadcast to
    # [num_tokens, num_nodes].
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])
    # Pairwise token-to-node distance matrix.
    distance = tf.sqrt(tf.maximum(na - 2 * tf.matmul(encoded_question, nodes, False, True) + nb, 0.0))
    # Index of the closest graph node for every token.
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)
    # Write the signals onto these nodes.
    # The comparison with 99999 is presumably always true (node ids are
    # smaller), so tf.where returns every [sequence, token] coordinate in
    # row-major order, i.e. aligned with the flattened closest_nodes.
    positions = tf.where(tf.not_equal(tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    # Keep only the first column (the sequence/batch index) of each
    # [sequence, token] coordinate.
    positions = tf.slice(positions, [0, 0], [-1, 1])
    positions = tf.cast(positions, tf.int32)
    # Scatter indices of the form [sequence index, closest node index].
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)
    projection_signal = tf.reshape(encoded_question, [-1, depth])
    # Add every token encoding onto its closest node in its sequence's graph
    # (masked tokens add zeros).
    batch_of_nodes = tf.tensor_scatter_nd_add(tf.reshape(batch_of_nodes, [-1, 512, depth]), positions, projection_signal)
    batch_of_graphs = batch_of_graphs.replace(nodes=tf.reshape(batch_of_nodes, [-1, depth]))
    global_block = blocks.NodesToGlobalsAggregator(tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')
    num_recurrent_passes = FLAGS.recurrences
    previous_graphs = batch_of_graphs
    # Template node features, broadcast over the batch to gate node updates.
    original_nodes = tf.reshape(original_graph.nodes, [1, 512, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    # Initial globals: mean of node features -> dense -> layer norm.
    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals
    # NOTE(review): this local shadows the enclosing function's name.
    model_fn = snt.nets.MLP(output_sizes=[depth])
    for unused_pass in range(num_recurrent_passes):
        # Update the node features with the function
        updated_nodes = model_fn(previous_graphs.nodes)
        updated_nodes = layernorm_node(updated_nodes)
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        # Debug statistic: per-example sum of |node features| (hard-coded
        # 4 * 512 * 300).
        graph_sum0 = tf.reduce_sum(tf.reshape(tf.math.abs(temporary_graph.nodes), [-1, 4 * 512 * 300]), -1)
        # Send the node features to the edges that are being sent by that node.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        # Debug statistic over edges (hard-coded 5551 edges per graph).
        graph_sum1 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_at_edges), [-1, 4 * 5551 * 300]), -1)
        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)
        # Average all of the edges received by every node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_mean)(
            temporary_graph)
        graph_sum2 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_with_aggregated_edges), [-1, 4 * 512 * 300]), -1)
        previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)
        current_nodes = previous_graphs.nodes
        current_nodes = tf.reshape(current_nodes, [-1, 512, depth])
        current_nodes = dropout(current_nodes, training=training)
        # Gate the propagated features elementwise by the template features.
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(nodes=tf.reshape(new_nodes, [-1, depth]))
        # Refresh the globals from the new node features.
        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))
    # NOTE(review): old_global and graph_sum0/1/2 are only bound inside the
    # loop. With FLAGS.recurrences == 0 the predictions dict below raises
    # NameError.
    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = dense_layer(output_global)
    # One row of choice logits per question.
    logits = tf.reshape(logits, [-1, num_choices])

    def loss_function(real, pred):
        # NOTE(review): positional (labels, logits) matches the TF2 signature;
        # the TF1 version requires keyword arguments.
        return tf.nn.sparse_softmax_cross_entropy_with_logits(tf.reshape(real, [-1]), pred)

    # Calculate the loss
    loss = loss_function(features["answer_id"], logits)
    # The reshapes below hard-code 4 choices and 300 features.
    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global, [-1, 4, 300]),
        'initial_global': tf.reshape(initial_global, [-1, 4, 300]),
        'old_global': tf.reshape(old_global, [-1, 4, 300]),
        'new_global': tf.reshape(new_global, [-1, 4, 300]),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes, [-1, 4, FLAGS.seq_len]),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask, [-1, 4, FLAGS.seq_len]),
        'encoded_question': tf.reshape(encoded_question, [-1, 4, FLAGS.seq_len, depth])
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)
        # Run any registered UPDATE_OPS (e.g. batch-norm statistics) before
        # each training step.
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.reduce_mean(loss),
        train_op=train_op)