def test_unused_field_can_be_none(
    self, use_edges, use_nodes, use_globals, none_field):
  """Checks that computation can handle non-necessary fields left None.

  `none_field` is set to None on the input graph. A NodeBlock is then built
  that reads received/sent edges (both toggled by `use_edges`), nodes and
  globals as requested. Its node output is compared with a reference built
  from the same enabled inputs: concatenated in order and multiplied by
  `self._scale`. Edges and globals must be passed through untouched.
  """
  graph = self._get_input_graph([none_field])
  block = blocks.NodeBlock(
      node_model_fn=self._node_model_fn,
      use_received_edges=use_edges,
      use_sent_edges=use_edges,
      use_nodes=use_nodes,
      use_globals=use_globals)
  result = block(graph)

  # Reference inputs, in the same order the NodeBlock concatenates them.
  pieces = []
  if use_edges:
    for aggregator_cls in (blocks.ReceivedEdgesToNodesAggregator,
                           blocks.SentEdgesToNodesAggregator):
      pieces.append(aggregator_cls(tf.unsorted_segment_sum)(graph))
  if use_nodes:
    pieces.append(graph.nodes)
  if use_globals:
    pieces.append(blocks.broadcast_globals_to_nodes(graph))
  reference_inputs = tf.concat(pieces, axis=-1)

  self.assertEqual(graph.edges, result.edges)
  self.assertEqual(graph.globals, result.globals)

  with self.test_session() as sess:
    nodes_value, inputs_value = sess.run((result.nodes, reference_inputs))
    self.assertNDArrayNear(inputs_value * self._scale, nodes_value, err=1e-4)
def _build(self,
           input_graph,
           hidden_size=50,
           attn_scale=1.0,
           attn_dropout_keep_prob=1.0,
           regularizer=None,
           is_training=False):
  """Runs one round of attention-weighted message passing over edges.

  Edge values are recomputed from each edge's own features and its sender
  and receiver node features. A scalar attention logit is computed per edge
  in the same way and normalized over the edges received by each node with
  `self._normalizer`. Each node then receives the attention-weighted sum of
  its incoming edge values.

  Args:
    input_graph: `graphs.GraphsTuple` whose node and edge features have the
      same (statically known) last dimension.
    hidden_size: Unused. Kept for backward compatibility; it only fed an
      earlier MLP variant of the logits model.
    attn_scale: Multiplier applied to the attention logits before
      normalization.
    attn_dropout_keep_prob: Keep probability for dropout on the normalized
      attention weights.
    regularizer: Optional callable applied to the `w` weights of both
      linear layers. `None` disables weight regularization.
    is_training: Forwarded to `slim.dropout`. Dropout is only active when
      this is True.

  Returns:
    `input_graph` with nodes replaced by the aggregated attended values,
    shape [total_num_nodes, value_dims]. Edges are replaced by the updated
    edge values, shape [total_num_edges, value_dims].

  Raises:
    ValueError: if node and edge features have different last dimensions.
  """
  node_values = input_graph.nodes
  edge_values = input_graph.edges
  value_dims = node_values.shape[-1].value
  edge_dims = edge_values.shape[-1].value
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if value_dims != edge_dims:
    raise ValueError(
        'Node and edge feature sizes must match, got {} and {}.'.format(
            value_dims, edge_dims))

  # snt.Linear rejects non-callable regularizer entries, so `{'w': None}`
  # would raise a TypeError. Only pass the dict when a regularizer is given.
  regularizers = {'w': regularizer} if regularizer is not None else None

  # Edge values from edge, sender and receiver features.
  # - edge_values = [total_num_edges, value_dims]
  edge_value_block = blocks.EdgeBlock(
      edge_model_fn=lambda: snt.Linear(
          output_size=value_dims, regularizers=regularizers),
      use_edges=True,
      use_receiver_nodes=True,
      use_sender_nodes=True,
      use_globals=False,
      name='update_edge_values')
  edge_values = edge_value_block(input_graph).edges
  tf.summary.histogram('mpnn/edge_values', edge_values)

  # One attention logit per edge, from the same edge/sender/receiver inputs.
  # - attention_weights_logits = [total_num_edges, 1]
  logits_block = blocks.EdgeBlock(
      edge_model_fn=lambda: snt.Linear(
          output_size=1, regularizers=regularizers),
      use_edges=True,
      use_receiver_nodes=True,
      use_sender_nodes=True,
      use_globals=False,
      name='update_attention_logits')
  attention_weights_logits = attn_scale * logits_block(input_graph).edges
  tf.summary.histogram('mpnn/logits', attention_weights_logits)

  # Normalize the logits across the edges received by each node.
  normalized_attention_weight = modules._received_edges_normalizer(
      input_graph.replace(edges=attention_weights_logits),
      normalizer=self._normalizer)
  normalized_attention_weight = slim.dropout(
      normalized_attention_weight,
      attn_dropout_keep_prob,
      is_training=is_training)

  # Weight each edge value by its attention weight (broadcast over features).
  # - attended_edges = [total_num_edges, value_dims]
  attended_edges = edge_values * normalized_attention_weight

  # Sum the attended values arriving at each node.
  # - aggregated_attended_values = [total_num_nodes, value_dims]
  received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
      reducer=tf.math.unsorted_segment_sum)
  aggregated_attended_values = received_edges_aggregator(
      input_graph.replace(edges=attended_edges))

  return input_graph.replace(nodes=aggregated_attended_values,
                             edges=edge_values)
def __init__(self, step_size, name="SpringMassSimulator"):
  """Constructs the simulator module.

  Args:
    step_size: Stored as `self._step_size` (presumably the integration time
      step — confirm against the `_build` that uses it).
    name: Module name forwarded to the parent constructor.
  """
  super(SpringMassSimulator, self).__init__(name=name)
  self._step_size = step_size
  with self._enter_variable_scope():
    # Reduces received-edge features onto nodes with a segment sum.
    received_sum = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    self._aggregator = received_sum
def _build(self, node_values, node_keys, node_queries, attention_graph):
  """Connects the multi-head self-attention module.

  Attention follows the connectivity of `attention_graph`. For every edge,
  the receiver node's query is compared with the sender node's key, and the
  sender's value is weighted accordingly. Node A therefore attends to Node B
  only if `attention_graph` contains an edge sent by Node B and received by
  Node A.

  Args:
    node_values: Per-node values, expected shape
      [total_num_nodes, num_heads, value_size].
    node_keys: Per-node keys, expected shape
      [total_num_nodes, num_heads, key_size].
    node_queries: Per-node queries, expected shape
      [total_num_nodes, num_heads, query_size]. The query size must equal
      the key size, because keys and queries are multiplied elementwise.
    attention_graph: Graph whose senders/receivers define which nodes may
      attend to which.

  Returns:
    `attention_graph` with nodes replaced by the sum of attended values
    received by each node, shape [total_num_nodes, num_heads, value_size].

  Raises:
    ValueError: not raised here directly. Presumably raised by the `blocks`
      broadcast helpers or the normalizer when the graph lacks connectivity
      or edges — TODO confirm.
  """
  # Move keys and values onto edges from their senders...
  keys_on_edges = blocks.broadcast_sender_nodes_to_edges(
      attention_graph.replace(nodes=node_keys))
  values_on_edges = blocks.broadcast_sender_nodes_to_edges(
      attention_graph.replace(nodes=node_values))
  # ...and queries onto edges from their receivers.
  queries_on_edges = blocks.broadcast_receiver_nodes_to_edges(
      attention_graph.replace(nodes=node_queries))

  # Per-edge, per-head dot product of key and query.
  logits = tf.reduce_sum(keys_on_edges * queries_on_edges, axis=-1)
  weights = _received_edges_normalizer(
      attention_graph.replace(edges=logits),
      normalizer=self._normalizer)

  # Scale each sender value by its weight, then sum per receiver.
  weighted_values = values_on_edges * weights[..., None]
  summed = blocks.ReceivedEdgesToNodesAggregator(
      reducer=tf.unsorted_segment_sum)(
          attention_graph.replace(edges=weighted_values))
  return attention_graph.replace(nodes=summed)
def __init__(self,
             node_model_fn,
             received_edges_reducer=tf.math.unsorted_segment_sum,
             sent_edges_reducer=tf.math.unsorted_segment_sum,
             name='dist_node_block'):
  """Builds the node block's sub-modules inside its variable scope.

  Args:
    node_model_fn: Zero-argument callable. Its result is stored as
      `self._node_model`.
    received_edges_reducer: Reducer used to aggregate the edges received by
      each node (segment sum by default).
    sent_edges_reducer: Reducer used to aggregate the edges sent by each
      node (segment sum by default).
    name: Module name forwarded to the parent constructor.
  """
  super(NodeBlock, self).__init__(name=name)
  with self._enter_variable_scope():
    # Construction order is kept: received aggregator, sent aggregator, model.
    received = blocks.ReceivedEdgesToNodesAggregator(received_edges_reducer)
    sent = blocks.SentEdgesToNodesAggregator(sent_edges_reducer)
    self._received_edges_aggregator = received
    self._sent_edges_aggregator = sent
    self._node_model = node_model_fn()
def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                       use_globals, received_edges_reducer,
                       sent_edges_reducer):
  """Compares the output of a NodeBlock to an explicit computation.

  The reference concatenates the enabled inputs in this order: received-edge
  aggregation, sent-edge aggregation, nodes, broadcast globals. The block's
  node output is expected to equal that concatenation times `self._scale`.
  Edges and globals must pass through unchanged.
  """
  graph = self._get_input_graph()
  block = blocks.NodeBlock(
      node_model_fn=self._node_model_fn,
      use_received_edges=use_received_edges,
      use_sent_edges=use_sent_edges,
      use_nodes=use_nodes,
      use_globals=use_globals,
      received_edges_reducer=received_edges_reducer,
      sent_edges_reducer=sent_edges_reducer)
  result = block(graph)

  reference_parts = []
  if use_received_edges:
    received = blocks.ReceivedEdgesToNodesAggregator(received_edges_reducer)
    reference_parts.append(received(graph))
  if use_sent_edges:
    sent = blocks.SentEdgesToNodesAggregator(sent_edges_reducer)
    reference_parts.append(sent(graph))
  if use_nodes:
    reference_parts.append(graph.nodes)
  if use_globals:
    reference_parts.append(blocks.broadcast_globals_to_nodes(graph))
  reference = tf.concat(reference_parts, axis=-1)

  self.assertEqual(graph.edges, result.edges)
  self.assertEqual(graph.globals, result.globals)

  with self.test_session() as sess:
    result_out, reference_out = sess.run((result, reference))
    self.assertNDArrayNear(reference_out * self._scale, result_out.nodes,
                           err=1e-4)
def _build(self, labels, graph, num_steps):
    """
    description: Updates each node according to its label, previous state
        and neighbours, repeated for `num_steps` steps. Each step:
          1. applies `self._edge_block` to the current graph;
          2. sums the resulting edges received by each node into a message;
          3. concatenates each node's label with its message and passes the
             result through `self._node_model_fn`;
          4. passes that through an LSTM whose state carries across steps;
          5. applies `self._global_block` to the graph with the LSTM outputs
             as its nodes.
    :param labels: Embedding of each node [n_nodes, embedding_length]
    :param graph: GraphsTuple containing connectivity information between
        nodes via the senders and receivers fields.
    :param num_steps: Number of message-passing steps to run.
    :return ret_graph: Graph after `num_steps` steps of message passing
        (the input `graph` itself when `num_steps` is 0).
    """
    ret_graph = graph
    # LSTM for updating nodes during each time step.
    lstm = sn.LSTM(2 * labels.shape[1])
    state = lstm.initial_state(labels.shape[0])
    # Sums received edges per node. It is configured only by its reducer, so
    # one instance is built once and reused every step.
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.unsorted_segment_sum)
    for _ in range(num_steps):
        # Update edges from their sender/receiver nodes.
        ret_graph = self._edge_block(ret_graph)
        messages = received_edges_aggregator(ret_graph)
        # tf.concat takes a *list* of tensors. Passing them positionally binds
        # `messages` to `axis` and raises a TypeError.
        hidden = self._node_model_fn(tf.concat([labels, messages], axis=1))
        hidden, state = lstm(hidden, state)
        # Aggregate the updated nodes into the global representation.
        ret_graph = self._global_block(ret_graph.replace(nodes=hidden))
    return ret_graph
def _build(self, node_values, node_keys, node_queries, attention_graph):
  """Connects the multi-head self-attention module.

  Each receiver node attends to the sender nodes of its incoming edges in
  `attention_graph`.

  Args:
    node_values: Expected shape [total_num_nodes, num_heads, value_size].
    node_keys: Expected shape [total_num_nodes, num_heads, key_size].
    node_queries: Expected shape [total_num_nodes, num_heads, key_size].
      Queries must have the same shape as keys, because they are multiplied
      elementwise.
    attention_graph: Graph whose senders/receivers define the attention
      connectivity.

  Returns:
    `attention_graph` with nodes replaced by the summed attended values,
    shape [total_num_nodes, num_heads, value_size].
  """
  # Sender nodes put their keys and values in the edges.
  # [total_num_edges, num_heads, key_size]
  sender_keys = blocks.broadcast_sender_nodes_to_edges(
      attention_graph.replace(nodes=node_keys))
  # [total_num_edges, num_heads, value_size]
  sender_values = blocks.broadcast_sender_nodes_to_edges(
      attention_graph.replace(nodes=node_values))
  # Receiver nodes put their queries in the edges.
  # [total_num_edges, num_heads, key_size]
  receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
      attention_graph.replace(nodes=node_queries))

  # Per-edge, per-head dot product. Keys and queries are already aligned as
  # [total_num_edges, num_heads, key_size]. Transposing the queries would
  # reverse all three axes and break that alignment.
  # [total_num_edges, num_heads]
  attention_weights_logits = tf.reduce_sum(
      sender_keys * receiver_queries, axis=-1)
  normalized_attention_weights = _received_edges_normalizer(
      attention_graph.replace(edges=attention_weights_logits),
      normalizer=self._normalizer)

  # Attending to sender values according to the weights.
  # [total_num_edges, num_heads, value_size]
  attended_edges = sender_values * normalized_attention_weights[..., None]

  # Summing all of the attended values received by each node.
  # [total_num_nodes, num_heads, value_size]
  received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
      reducer=tf.unsorted_segment_sum)
  aggregated_attended_values = received_edges_aggregator(
      attention_graph.replace(edges=attended_edges))

  return attention_graph.replace(nodes=aggregated_attended_values)
class FieldAggregatorsTest(GraphModuleTest):
  """Tests for the `blocks` aggregators onto nodes and globals."""

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    """Compares each aggregator's output with precomputed segment sums."""
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    """Same check as `test_output_values`, with rank-3 features.

    Tensors mapped by `input_graph.map` are reshaped from [n, ...] to
    [n, 2, -1]. The expected values are reshaped the same way.
    """
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      # This case nulls "edges", so it is named accordingly (was "...nodes").
      ("edges to globals missing edges",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregator fail if a required field is missing."""
    input_graph = self._get_input_graph([none_field])
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregators do not fail when only unused fields are None."""
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
def _build(self, graph_features):
  """Connects the edge-conditioned multi-head self-attention module.

  Keys and values are projected from the edge features produced by
  `self._edge_block`. Queries are projected from the node features. Each
  receiver node attends over its incoming edges, so attention follows the
  senders/receivers connectivity of `graph_features`.

  Args:
    graph_features: `graphs.GraphsTuple` with nodes, edges (consumed by
      `self._edge_block`), senders/receivers and `n_node`.

  Returns:
    The output of `self._global_block` applied to `graph_features` with:
      - edges replaced by the `self._edge_block` output;
      - nodes replaced by the per-node sums of attended values. The heads are
        concatenated to [total_num_nodes, num_heads * value_size] and then
        projected by `self._node_model`.

  Raises:
    ValueError: not raised here directly. Presumably raised by the `blocks`
      helpers when the input graph lacks edges or connectivity — TODO
      confirm.
  """
  # TODO(arc): Figure out how to incorporate edge information into attention
  # updates.
  edges = self._edge_block(graph_features).edges
  num_heads = self.num_heads
  key_size = self.key_size
  value_size = self.value_size

  # [total_num_nodes, d] => [total_num_nodes, num_heads, key_size]
  queries = self._attention_node_projection_model(graph_features.nodes)
  queries = tf.reshape(
      queries, [tf.reduce_sum(graph_features.n_node), num_heads, key_size])

  # Project edge features to keys and values:
  # [total_num_edges, (key_size + value_size) * num_heads]
  #   => [total_num_edges, num_heads, key_size + value_size]
  kv = self._attention_edge_projection_model(edges)
  kv = tf.reshape(kv, [-1, num_heads, key_size + value_size])
  # edge_keys => [total_num_edges, num_heads, key_size]
  # edge_values => [total_num_edges, num_heads, value_size]
  edge_keys, edge_values = tf.split(kv, [key_size, value_size], -1)

  # Receiver nodes put their queries in the edges.
  # [total_num_edges, num_heads, key_size]
  receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
      graph_features.replace(nodes=queries))

  # Score each edge by applying `self._query_key_product_model` to the
  # concatenated key and query: [total_num_edges, num_heads, 1].
  attention_weights_logits = snt.BatchApply(self._query_key_product_model)(
      tf.concat([edge_keys, receiver_queries], axis=-1))
  # [total_num_edges, num_heads]
  attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

  # Softmax-normalize the logits over the edges received by each node.
  # [total_num_edges, num_heads]
  normalized_attention_weights = _received_edges_normalizer(
      graph_features.replace(edges=attention_weights_logits),
      normalizer=_unsorted_segment_softmax)

  # Attending to edge values according to the weights.
  # [total_num_edges, num_heads, value_size]
  attended_edges = edge_values * normalized_attention_weights[..., None]

  # Summing all of the attended values received by each node.
  # [total_num_nodes, num_heads, value_size]
  received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
      reducer=tf.unsorted_segment_sum)
  aggregated_attended_values = received_edges_aggregator(
      graph_features.replace(edges=attended_edges))

  # Concatenate the heads: [total_num_nodes, num_heads * value_size], then
  # project with the node model.
  aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                          [-1, num_heads * value_size])
  aggregated_attended_values = self._node_model(aggregated_attended_values)

  return self._global_block(
      graph_features.replace(nodes=aggregated_attended_values, edges=edges))
# NOTE(review): `tvars` is not used in the visible part of this script.
tvars = graph_network.trainable_variables
print('')

###############
# Broadcast examples: each result is `graphs_tuple` with one field replaced
# by the output of the corresponding `blocks.broadcast_*` helper.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# Aggregation examples: each result is `graphs_tuple` with its globals or
# nodes replaced by the output of a `blocks.*Aggregator` that uses `reducer`.
# NOTE(review): this rebuilds the same `graphs_tuple` as above.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.math.unsorted_segment_sum
# Edges/nodes -> globals, then sent/received edges -> nodes.
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
def model_fn(features, labels, mode, params):
    """Estimator model function scoring answer choices with a template graph.

    Token sequences are encoded with an LSTM. Each token encoding is added
    onto the template-graph node closest to it (in a per-sequence copy of
    the graph). Then `FLAGS.recurrences` rounds of message passing run, and
    each graph's global features are mapped to one logit. Logits are grouped
    into rows of `num_choices` and trained with softmax cross-entropy
    against `features["answer_id"]`.

    Args:
        features: dict with "input_ids" (token ids) and "answer_id" (index
            of the correct choice).
        labels: unused; targets are read from `features["answer_id"]`.
        mode: a `tf.estimator.ModeKeys` value.
        params: dict with 'word_embedding' (embedding matrix), 'graph_nodes'
            (numpy [num_nodes, depth] node features) and 'graph_edges'
            (numpy adjacency; each nonzero [i, j] is an edge i -> j).

    Returns:
        A `tf.estimator.EstimatorSpec` for `mode`.

    Raises:
        ValueError: if `FLAGS.recurrences` is smaller than 1. The diagnostic
            outputs below are only defined inside the message-passing loop.
    """
    num_recurrent_passes = FLAGS.recurrences
    if num_recurrent_passes < 1:
        raise ValueError(
            "FLAGS.recurrences must be >= 1, got %r" % (num_recurrent_passes,))

    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    num_nodes = graph_nodes.shape[0]
    depth = graph_nodes.shape[1]
    # One edge per nonzero adjacency entry, matching np.nonzero below.
    num_edges = int(np.count_nonzero(graph_edges))
    # Sequences (graphs) per example used by the diagnostic reshapes.
    # NOTE(review): presumably equal to `num_choices` — confirm.
    graphs_per_example = 4
    training = mode == tf.estimator.ModeKeys.TRAIN

    # 0 means the token needs to be masked (token id 1, presumably padding).
    # 1 means it is kept.
    padding_mask = tf.cast(
        tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])),
        tf.int32)
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])

    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])

    question_encoder = tf.keras.layers.LSTM(
        depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    encoded_question = (tf.cast(padding_mask, tf.float32)
                        * tf.cast(encoded_question, tf.float32))
    encoded_question = tf.reshape(
        tf.cast(encoded_question, tf.float32), [-1, depth])

    # The template graph.
    nodes = graph_nodes.astype(np.float32)
    edges = np.ones([num_edges, 1]).astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    global_features = np.zeros(FLAGS.global_size).astype(np.float32)
    graph_dict = {"globals": global_features, "nodes": nodes, "edges": edges,
                  "senders": senders, "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    # The per-sequence copies start from all-zero node features.
    graph_dict["nodes"] = nodes * 0
    # NOTE(review): requires a statically known leading dimension.
    batch_of_tensor_data_dicts = [
        graph_dict for _ in range(sentences.shape[0])]
    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(
        batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes

    # Euclidean distance between L2-normalized token encodings and node
    # features: ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. All three terms
    # must use the normalized vectors. Combining unit norms with a raw dot
    # product clamps many distances to 0, and argmin then returns the
    # lowest tied index.
    question_unit = tf.math.l2_normalize(encoded_question, -1)
    nodes_unit = tf.math.l2_normalize(nodes, -1)
    # na as a column and nb as a row vector.
    na = tf.reshape(tf.reduce_sum(tf.square(question_unit), 1), [-1, 1])
    nb = tf.reshape(tf.reduce_sum(tf.square(nodes_unit), 1), [1, -1])
    distance = tf.sqrt(tf.maximum(
        na - 2 * tf.matmul(question_unit, nodes_unit, transpose_b=True) + nb,
        0.0))
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)

    # Write each token encoding onto its closest node, in its own graph.
    # argmin indices are < num_nodes, so the mask is all-true. tf.where
    # therefore enumerates every [graph_index, token_index] pair in
    # row-major order. Keep only the graph index column.
    positions = tf.where(
        tf.not_equal(tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    positions = tf.cast(tf.slice(positions, [0, 0], [-1, 1]), tf.int32)
    # -> [graph_index, closest_node_index] for every token.
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)
    projection_signal = tf.reshape(encoded_question, [-1, depth])
    batch_of_nodes = tf.tensor_scatter_nd_add(
        tf.reshape(batch_of_nodes, [-1, num_nodes, depth]),
        positions, projection_signal)
    batch_of_graphs = batch_of_graphs.replace(
        nodes=tf.reshape(batch_of_nodes, [-1, depth]))

    global_block = blocks.NodesToGlobalsAggregator(
        tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')
    previous_graphs = batch_of_graphs
    original_nodes = tf.reshape(original_graph.nodes, [1, num_nodes, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(
        globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals

    # Named `node_mlp` so it does not shadow this function's own name.
    node_mlp = snt.nets.MLP(output_sizes=[depth])
    for _ in range(num_recurrent_passes):
        # Update the node features with the MLP.
        updated_nodes = node_mlp(previous_graphs.nodes)
        updated_nodes = layernorm_node(updated_nodes)
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        graph_sum0 = tf.reduce_sum(
            tf.reshape(tf.math.abs(temporary_graph.nodes),
                       [-1, graphs_per_example * num_nodes * depth]), -1)

        # Send the node features to the edges that are being sent by that node.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        graph_sum1 = tf.reduce_sum(
            tf.reshape(tf.math.abs(nodes_at_edges),
                       [-1, graphs_per_example * num_edges * depth]), -1)
        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

        # Average all of the edges received by every node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(
            tf.math.unsorted_segment_mean)(temporary_graph)
        graph_sum2 = tf.reduce_sum(
            tf.reshape(tf.math.abs(nodes_with_aggregated_edges),
                       [-1, graphs_per_example * num_nodes * depth]), -1)
        previous_graphs = previous_graphs.replace(
            nodes=nodes_with_aggregated_edges)

        # Dropout, then gate by the template's original node features.
        current_nodes = tf.reshape(previous_graphs.nodes, [-1, num_nodes, depth])
        current_nodes = dropout(current_nodes, training=training)
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(
            nodes=tf.reshape(new_nodes, [-1, depth]))

        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(
            globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(
            globals=layernorm_global(previous_graphs.globals))

    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(
        previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = dense_layer(output_global)
    logits = tf.reshape(logits, [-1, num_choices])

    def loss_function(real, pred):
        # Keyword arguments work on both TF1 (which rejects positional
        # arguments here) and TF2.
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(real, [-1]), logits=pred)

    # Calculate the loss.
    loss = loss_function(features["answer_id"], logits)

    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global,
                                    [-1, graphs_per_example, depth]),
        'initial_global': tf.reshape(initial_global,
                                     [-1, graphs_per_example, depth]),
        'old_global': tf.reshape(old_global, [-1, graphs_per_example, depth]),
        'new_global': tf.reshape(new_global, [-1, graphs_per_example, depth]),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes,
                                    [-1, graphs_per_example, FLAGS.seq_len]),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask,
                           [-1, graphs_per_example, FLAGS.seq_len]),
        'encoded_question': tf.reshape(
            encoded_question,
            [-1, graphs_per_example, FLAGS.seq_len, depth]),
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)
        # Batch norm requires update ops to be added as a dependency to the
        # train_op.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.reduce_mean(loss),
        train_op=train_op)