def test_unused_field_can_be_none(
        self, use_edges, use_nodes, use_globals, none_field):
    """Verify a GlobalBlock copes with non-necessary fields left as None.

    ``none_field`` names a field that is set to None in the input graph; per
    this test's intent it is one the configured block does not need. The
    test asserts that edges and nodes come out of the block unchanged, and
    that the output globals equal the hand-assembled model inputs multiplied
    by ``self._scale``.
    """
    graph_in = self._get_input_graph([none_field])
    block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    graph_out = block(graph_in)

    # Rebuild by hand the features the global model should have received:
    # summed edges, summed nodes and the input globals, in that order, each
    # included only when its ``use_*`` flag is set.
    candidate_inputs = (
        (use_edges,
         lambda: blocks.EdgesToGlobalsAggregator(
             tf.unsorted_segment_sum)(graph_in)),
        (use_nodes,
         lambda: blocks.NodesToGlobalsAggregator(
             tf.unsorted_segment_sum)(graph_in)),
        (use_globals, lambda: graph_in.globals),
    )
    expected_inputs = tf.concat(
        [build() for enabled, build in candidate_inputs if enabled], axis=-1)

    # The block must hand edges and nodes through untouched.
    self.assertEqual(graph_in.edges, graph_out.edges)
    self.assertEqual(graph_in.nodes, graph_out.nodes)

    with self.test_session() as session:
        globals_value, inputs_value = session.run(
            (graph_out.globals, expected_inputs))
    self.assertNDArrayNear(
        inputs_value * self._scale, globals_value, err=1e-4)
def test_output_values(
        self, use_edges, use_nodes, use_globals, edges_reducer, nodes_reducer):
    """Compare GlobalBlock output globals against an explicit recomputation.

    The expected globals are the enabled inputs (edges reduced with
    ``edges_reducer``, nodes reduced with ``nodes_reducer``, raw input
    globals) concatenated on the last axis and multiplied by
    ``self._scale``. Edges and nodes must come out of the block unchanged.
    """
    graph_in = self._get_input_graph()
    block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals,
        edges_reducer=edges_reducer,
        nodes_reducer=nodes_reducer)
    graph_out = block(graph_in)

    # (enabled?, builder) pairs in the concatenation order the block uses.
    candidate_inputs = (
        (use_edges,
         lambda: blocks.EdgesToGlobalsAggregator(edges_reducer)(graph_in)),
        (use_nodes,
         lambda: blocks.NodesToGlobalsAggregator(nodes_reducer)(graph_in)),
        (use_globals, lambda: graph_in.globals),
    )
    expected_inputs = tf.concat(
        [build() for enabled, build in candidate_inputs if enabled], axis=-1)

    self.assertEqual(graph_in.edges, graph_out.edges)
    self.assertEqual(graph_in.nodes, graph_out.nodes)

    with self.test_session() as session:
        graph_value, inputs_value = session.run((graph_out, expected_inputs))
    self.assertNDArrayNear(
        inputs_value * self._scale, graph_value.globals, err=1e-4)
def _build(self, inputs):
    """Run one attention + feed-forward layer over the graph and pool nodes.

    Args:
      inputs: mapping with a ``'graph'`` entry whose ``nodes`` and
        ``n_node`` fields are read (presumably a graph_nets ``GraphsTuple``;
        confirm against callers).

    Returns:
      One row per graph: the node features reduced with the reducer selected
      by ``self.FLAGS['tf_reducer']``, divided by that graph's node count.
      A node count of 0 is treated as 1, so empty graphs do not yield NaN.

    Raises:
      KeyError: if ``self.FLAGS['tf_reducer']`` is not ``'sum'`` or ``'max'``.
    """
    graph = inputs['graph']
    nodes = self._input_dense(graph.nodes)

    # Query/key/value projections, then split into attention heads.
    q = self._q_dense(nodes)
    k = self._k_dense(nodes)
    v = self._v_dense(nodes)
    q = self.split_heads(q, self.FLAGS)
    k = self.split_heads(k, self.FLAGS)
    v = self.split_heads(v, self.FLAGS)

    attention_graph = self._sa(node_values=v, node_keys=k, node_queries=q,
                               attention_graph=graph)
    attention_output = self.combine_heads(attention_graph.nodes, self.FLAGS)
    attention_output = self._output_dense(attention_output)

    sa_skip = nodes + attention_output  # residual/skip connection
    sa_normed = self._sa_laynorm(sa_skip)  # apply layer norm
    ff_skip = sa_normed + self._ff(sa_normed)  # residual/skip connection
    ff_normed = self._ff_laynorm(ff_skip)  # apply layer norm

    if self.FLAGS['tf_gate']:
        # Split the projection at hidden_size and combine the parts as
        # tanh(first part) * sigmoid(remainder). Presumably _doub_dense
        # outputs 2 * hidden_size features -- confirm.
        ff_normed_doub = self._doub_dense(ff_normed)
        # TODO: try raw activation to make sure this is not just helping
        # because of more weights
        weights = tf.tanh(ff_normed_doub[:, :self.hidden_size])
        vals = tf.sigmoid(ff_normed_doub[:, self.hidden_size:])
        out_graph = attention_graph.replace(nodes=weights * vals)
    else:
        out_graph = attention_graph.replace(nodes=ff_normed)

    # Pool node features into one row per graph.
    reducer = {
        'sum': tf.unsorted_segment_sum,
        'max': blocks.unsorted_segment_max_or_zero,
    }[self.FLAGS['tf_reducer']]
    agged = blocks.NodesToGlobalsAggregator(reducer=reducer)(out_graph)

    # Divide by each graph's node count, clamped to at least 1 so a graph
    # with no nodes is not divided by zero (which would give NaN).
    # NOTE(review): with the 'max' reducer this divides a max by the node
    # count, which is not a mean -- confirm that is intended.
    n_node = tf.cast(tf.maximum(out_graph.n_node, 1), tf.float32)
    return agged / n_node[:, None]
def _segment_sum_aggregator_cases():
    """Build ``(name, aggregator, expected)`` cases for the segment-sum tests.

    Called once per decorator, so each parameterized test gets its own
    aggregator instances, exactly as when every decorator spelled the list
    out inline.
    """
    return (
        ("edges_to_globals",
         blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
         SEGMENT_SUM_EDGES_TO_GLOBALS),
        ("nodes_to_globals",
         blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
         SEGMENT_SUM_NODES_TO_GLOBALS),
        ("sent_edges_to_nodes",
         blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
         SEGMENT_SUM_SENT_EDGES_TO_NODES),
        ("received_edges_to_nodes",
         blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
         SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
    )


class FieldAggregatorsTest(GraphModuleTest):
    """Tests for the edge/node field aggregators in ``blocks``."""

    @parameterized.named_parameters(*_segment_sum_aggregator_cases())
    def test_output_values(self, aggregator, expected):
        """Aggregated values match the precomputed segment-sum results."""
        input_graph = self._get_input_graph()
        aggregated = aggregator(input_graph)
        with self.test_session() as sess:
            aggregated_out = sess.run(aggregated)
        self.assertNDArrayNear(
            np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

    @parameterized.named_parameters(*_segment_sum_aggregator_cases())
    def test_output_values_larger_rank(self, aggregator, expected):
        """Same check after reshaping mapped fields' trailing axis to [2, -1]."""
        input_graph = self._get_input_graph()
        input_graph = input_graph.map(
            lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
        aggregated = aggregator(input_graph)
        with self.test_session() as sess:
            aggregated_out = sess.run(aggregated)
        self.assertNDArrayNear(
            np.reshape(np.array(expected, dtype=np.float32),
                       [len(expected)] + [2, -1]),
            aggregated_out,
            err=1e-4)

    @parameterized.named_parameters(
        ("received edges to nodes missing edges",
         blocks.ReceivedEdgesToNodesAggregator, "edges"),
        ("sent edges to nodes missing edges",
         blocks.SentEdgesToNodesAggregator, "edges"),
        ("nodes to globals missing nodes",
         blocks.NodesToGlobalsAggregator, "nodes"),
        ("edges to globals missing edges",
         blocks.EdgesToGlobalsAggregator, "edges"),
    )
    def test_missing_field_raises_exception(self, constructor, none_field):
        """Tests that aggregator fail if a required field is missing."""
        input_graph = self._get_input_graph([none_field])
        # The error message is expected to mention the missing field's name.
        with self.assertRaisesRegex(ValueError, none_field):
            constructor(tf.unsorted_segment_sum)(input_graph)

    @parameterized.named_parameters(
        ("received edges to nodes missing nodes and globals",
         blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
        ("sent edges to nodes missing nodes and globals",
         blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
        ("nodes to globals missing edges and globals",
         blocks.NodesToGlobalsAggregator,
         ["edges", "receivers", "senders", "globals"]),
        ("edges to globals missing globals",
         blocks.EdgesToGlobalsAggregator, ["globals"]),
    )
    def test_unused_field_can_be_none(self, constructor, none_fields):
        """Tests that aggregators accept graphs whose unused fields are None."""
        input_graph = self._get_input_graph(none_fields)
        # Passes as long as building the aggregation raises nothing.
        constructor(tf.unsorted_segment_sum)(input_graph)
tvars = graph_network.trainable_variables
print('')

# Built once and shared by the broadcast and aggregate examples below.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

###############
# broadcast
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` scoring answer choices with a recurrent GNN.

    Each input sequence is LSTM-encoded token by token. Every token encoding
    is added onto the template-graph node that minimises the distance score
    below, in that sequence's own copy of the graph. The graphs are then
    refined for ``FLAGS.recurrences`` rounds, and each graph's global vector
    is mapped to one logit. Logits are reshaped to ``[-1, num_choices]``.

    Args:
      features: dict with ``"input_ids"`` (token ids; id 1 is masked as
        padding) and ``"answer_id"`` (gold choice index).
      labels: unused; labels are read from ``features["answer_id"]``.
      mode: a ``tf.estimator.ModeKeys`` value.
      params: dict with ``'word_embedding'`` (embedding matrix),
        ``'graph_nodes'`` (numpy array, one feature row per template node)
        and ``'graph_edges'`` (numpy adjacency matrix; each non-zero entry
        becomes one edge).

    Returns:
      A ``tf.estimator.EstimatorSpec`` for ``mode``.

    Raises:
      ValueError: if ``FLAGS.recurrences`` is less than 1.
    """
    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    depth = graph_nodes.shape[1]
    training = mode == tf.estimator.ModeKeys.TRAIN

    # 0 means the token needs to be masked (token id 1); 1 means it is kept.
    padding_mask = tf.cast(
        tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])),
        tf.int32)
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])

    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])

    question_encoder = tf.keras.layers.LSTM(
        depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    encoded_question = (tf.cast(padding_mask, tf.float32)
                        * tf.cast(encoded_question, tf.float32))
    encoded_question = tf.reshape(
        tf.cast(encoded_question, tf.float32), [-1, depth])

    # The template graph.
    nodes = graph_nodes.astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    # One unit feature per edge, sized from the sender list so the edge count
    # always matches senders/receivers even if graph_edges is not 0/1.
    edges = np.ones([len(senders), 1]).astype(np.float32)
    global_features = np.zeros(FLAGS.global_size).astype(np.float32)
    graph_dict = {"globals": global_features, "nodes": nodes, "edges": edges,
                  "senders": senders, "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    num_nodes = nodes.shape[0]
    num_edges = edges.shape[0]

    # One copy of the template per sequence, with zeroed node features.
    graph_dict["nodes"] = nodes * 0
    batch_of_tensor_data_dicts = [
        graph_dict for _ in range(sentences.shape[0])]
    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(
        batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes

    # Score every (token, node) pair as
    #   ||normalize(q)||^2 - 2 * q . n + ||normalize(n)||^2
    # and take the lowest-scoring node for each token.
    # NOTE(review): the squared-norm terms use l2-normalised vectors while the
    # cross term uses the raw vectors, so this is not a true Euclidean
    # distance -- confirm which distance is intended.
    na = tf.reduce_sum(
        tf.square(tf.math.l2_normalize(encoded_question, -1)), 1)
    nb = tf.reduce_sum(tf.square(tf.math.l2_normalize(nodes, -1)), 1)
    # na as a column and nb as a row vector, broadcast over [tokens, nodes].
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])
    distance = tf.sqrt(tf.maximum(
        na - 2 * tf.matmul(encoded_question, nodes, False, True) + nb, 0.0))
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)

    # Write the signals onto these nodes. tf.where over this mask (true for
    # every entry while node indices stay below 99999) lists each
    # (sequence, token) coordinate in row-major order; keep only the
    # sequence index column.
    positions = tf.where(tf.not_equal(
        tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    positions = tf.slice(positions, [0, 0], [-1, 1])
    positions = tf.cast(positions, tf.int32)
    # One [sequence index, node index] pair per token.
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)

    projection_signal = tf.reshape(encoded_question, [-1, depth])
    batch_of_nodes = tf.tensor_scatter_nd_add(
        tf.reshape(batch_of_nodes, [-1, num_nodes, depth]),
        positions, projection_signal)
    batch_of_graphs = batch_of_graphs.replace(
        nodes=tf.reshape(batch_of_nodes, [-1, depth]))

    global_block = blocks.NodesToGlobalsAggregator(
        tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')
    num_recurrent_passes = FLAGS.recurrences
    if num_recurrent_passes < 1:
        # The predictions below read values that only the loop defines.
        raise ValueError(
            "FLAGS.recurrences must be at least 1, got %r"
            % (num_recurrent_passes,))
    previous_graphs = batch_of_graphs
    original_nodes = tf.reshape(original_graph.nodes, [1, num_nodes, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(
        globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals

    # Named node_mlp so it does not shadow this function's own name.
    node_mlp = snt.nets.MLP(output_sizes=[depth])
    # Flattened per-example sizes (num_choices sequences per example) used
    # by the graph_sum diagnostics.
    node_diag_size = num_choices * num_nodes * depth
    edge_diag_size = num_choices * num_edges * depth

    for _ in range(num_recurrent_passes):
        # Update the node features with the MLP followed by layer norm.
        updated_nodes = layernorm_node(node_mlp(previous_graphs.nodes))
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        graph_sum0 = tf.reduce_sum(tf.reshape(
            tf.math.abs(temporary_graph.nodes), [-1, node_diag_size]), -1)

        # Send the node features to the edges that are being sent by that node.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        graph_sum1 = tf.reduce_sum(tf.reshape(
            tf.math.abs(nodes_at_edges), [-1, edge_diag_size]), -1)
        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

        # Average all of the edges received by every node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(
            tf.math.unsorted_segment_mean)(temporary_graph)
        graph_sum2 = tf.reduce_sum(tf.reshape(
            tf.math.abs(nodes_with_aggregated_edges), [-1, node_diag_size]), -1)
        previous_graphs = previous_graphs.replace(
            nodes=nodes_with_aggregated_edges)

        # Multiply node states elementwise by the template node features
        # (broadcast over the batch of graphs).
        current_nodes = tf.reshape(
            previous_graphs.nodes, [-1, num_nodes, depth])
        current_nodes = dropout(current_nodes, training=training)
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(
            nodes=tf.reshape(new_nodes, [-1, depth]))

        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(
            globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(
            globals=layernorm_global(previous_graphs.globals))

    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(
        previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = tf.reshape(dense_layer(output_global), [-1, num_choices])

    # Pass labels/logits by keyword: the TF1 signature begins with a
    # `_sentinel` parameter that rejects positional arguments.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(features["answer_id"], [-1]), logits=logits)

    per_choice_global = [-1, num_choices, depth]
    per_choice_tokens = [-1, num_choices, FLAGS.seq_len]
    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global, per_choice_global),
        'initial_global': tf.reshape(initial_global, per_choice_global),
        'old_global': tf.reshape(old_global, per_choice_global),
        'new_global': tf.reshape(new_global, per_choice_global),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes, per_choice_tokens),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask, per_choice_tokens),
        'encoded_question': tf.reshape(
            encoded_question, [-1, num_choices, FLAGS.seq_len, depth]),
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)
        # Batch norm requires update ops to be added as a dependency to the
        # train_op.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=tf.reduce_mean(loss),
        train_op=train_op)