Example #1
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Checks that computation can handle non-necessary fields left None."""
    input_graph = self._get_input_graph([none_field])
    node_block = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = node_block(input_graph)

    model_inputs = []
    if use_edges:
      model_inputs.append(
          blocks.ReceivedEdgesToNodesAggregator(
              tf.unsorted_segment_sum)(input_graph))
      model_inputs.append(
          blocks.SentEdgesToNodesAggregator(
              tf.unsorted_segment_sum)(input_graph))
    if use_nodes:
      model_inputs.append(input_graph.nodes)
    if use_globals:
      model_inputs.append(blocks.broadcast_globals_to_nodes(input_graph))

    model_inputs = tf.concat(model_inputs, axis=-1)
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
      actual_nodes, model_inputs_out = sess.run(
          (output_graph.nodes, model_inputs))

    expected_output_nodes = model_inputs_out * self._scale
    self.assertNDArrayNear(expected_output_nodes, actual_nodes, err=1e-4)
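The harness above relies on helpers that are not shown (`_get_input_graph`, `_node_model_fn`, `self._scale`). A minimal self-contained sketch of the same NodeBlock round trip, with made-up shapes and a plain linear node model (TF1-style, as in the snippets here):

import numpy as np
import sonnet as snt
import tensorflow as tf
from graph_nets import blocks, utils_tf

# A toy graph: 3 nodes, 2 edges (0 -> 1, 1 -> 2); all feature values are assumed.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([{
    "globals": np.zeros([4], dtype=np.float32),
    "nodes": np.random.randn(3, 5).astype(np.float32),
    "edges": np.random.randn(2, 6).astype(np.float32),
    "senders": np.array([0, 1]),
    "receivers": np.array([1, 2]),
}])

node_block = blocks.NodeBlock(
    node_model_fn=lambda: snt.Linear(output_size=8),
    use_received_edges=True,
    use_sent_edges=True,
    use_nodes=True,
    use_globals=True)
output_graph = node_block(graphs_tuple)  # output_graph.nodes is [3, 8]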
Example #2
    def _build(self,
               input_graph,
               hidden_size=50,
               attn_scale=1.0,
               attn_dropout_keep_prob=1.0,
               regularizer=None,
               is_training=False):

        node_values = input_graph.nodes
        edge_values = input_graph.edges

        value_dims = node_values.shape[-1].value
        assert value_dims == edge_values.shape[-1].value

        # Compute edge values from the edge, sender, and receiver features.
        # - edge_values = [total_num_edges, value_dims]
        edge_value_block = blocks.EdgeBlock(
            edge_model_fn=lambda: snt.Linear(
                output_size=value_dims, regularizers={'w': regularizer}),
            use_edges=True,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='update_edge_values')
        edge_values = edge_value_block(input_graph).edges
        tf.summary.histogram('mpnn/edge_values', edge_values)

        logits_block = blocks.EdgeBlock(
            edge_model_fn=lambda: snt.Linear(output_size=1,
                                             regularizers={'w': regularizer}),
            # edge_model_fn=lambda: snt.nets.MLP(output_sizes=[hidden_size, 1],
            #                                    activation=tf.nn.tanh,
            #                                    regularizers={'w': regularizer}),
            use_edges=True,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='update_attention_logits')
        attention_weights_logits = attn_scale * logits_block(input_graph).edges
        tf.summary.histogram('mpnn/logits', attention_weights_logits)

        normalized_attention_weight = modules._received_edges_normalizer(
            input_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)
        normalized_attention_weight = slim.dropout(normalized_attention_weight,
                                                   attn_dropout_keep_prob,
                                                   is_training=is_training)

        # Attending to sender values according to the weights.
        # - attended_edges = [total_num_edges, value_dims]
        attended_edges = edge_values * normalized_attention_weight

        # Summing all of the attended values from each node.
        # aggregated_attended_values = [total_num_nodes, embedding_size]
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.math.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            input_graph.replace(edges=attended_edges))

        return input_graph.replace(nodes=aggregated_attended_values,
                                   edges=edge_values)
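This `_build` presumably belongs to a Sonnet module; a sketch of invoking it, with a hypothetical class name `GraphAttention` (the enclosing class and its constructor are not part of the snippet above):

attention = GraphAttention(name='mpnn_attention')
output_graph = attention(
    input_graph,
    attn_scale=0.125,            # e.g. 1/sqrt(key_dim) for scaled dot-product
    attn_dropout_keep_prob=0.9,
    is_training=True)
# output_graph.nodes holds the attention-weighted sums of edge values;
# output_graph.edges holds the updated edge values.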
Example #3
    def __init__(self, step_size, name="SpringMassSimulator"):
        super(SpringMassSimulator, self).__init__(name=name)
        self._step_size = step_size

        with self._enter_variable_scope():
            self._aggregator = blocks.ReceivedEdgesToNodesAggregator(
                reducer=tf.unsorted_segment_sum)
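The simulator's `_build` is not included above. A minimal sketch of how the stored aggregator and step size are typically combined into one Euler integration step, assuming the edges already hold per-edge forces and each node stacks position and velocity on its last axis:

    def _build(self, graph):
        # Sum per-edge forces onto their receiver nodes: [total_num_nodes, dims].
        force_per_node = self._aggregator(graph)
        # Assumed node layout: position and velocity stacked on the last axis.
        position, velocity = tf.split(graph.nodes, 2, axis=-1)
        velocity += self._step_size * force_per_node
        position += self._step_size * velocity
        return graph.replace(nodes=tf.concat([position, velocity], axis=-1))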
Example #4
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Connects the multi-head self-attention module.

        The self-attention is only computed according to the connectivity of the
        input graphs, with receiver nodes attending to sender nodes.

        Args:
          node_values: Tensor containing the values associated with each of the
            nodes. The expected shape is [total_num_nodes, num_heads, value_size].
          node_keys: Tensor containing the key associated with each of the nodes.
            The expected shape is [total_num_nodes, num_heads, key_size].
          node_queries: Tensor containing the query associated with each of the
            nodes. The expected shape is [total_num_nodes, num_heads, query_size].
            The query size must be equal to the key size.
          attention_graph: Graph containing connectivity information between nodes
            via the senders and receivers fields. Node A will only attempt to
            attend to Node B if `attention_graph` contains an edge received by
            Node A and sent by Node B.

        Returns:
          An output `graphs.GraphsTuple` with updated nodes containing the
          aggregated attended value for each of the nodes with shape
          [total_num_nodes, num_heads, value_size].

        Raises:
          ValueError: if the input graph does not have edges.
        """

        # Sender nodes put their keys and values in the edges.
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Attention weight for each edge.
        attention_weights_logits = tf.reduce_sum(
            sender_keys * receiver_queries, axis=-1)
        normalized_attention_weights = _received_edges_normalizer(
            attention_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights.
        attended_edges = sender_values * normalized_attention_weights[..., None]

        # Summing all of the attended values from each node.
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            attention_graph.replace(edges=attended_edges))

        return attention_graph.replace(nodes=aggregated_attended_values)
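`_received_edges_normalizer` groups the per-edge logits by their receiver and applies `self._normalizer` within each group. When the normalizer is an unsorted-segment softmax (the usual choice), it computes, as a sketch:

def _unsorted_segment_softmax(data, segment_ids, num_segments):
    # Subtract the per-segment max for numerical stability.
    segment_maxes = tf.unsorted_segment_max(data, segment_ids, num_segments)
    data -= tf.gather(segment_maxes, segment_ids)
    exp_data = tf.exp(data)
    # Normalize within each segment (here: within each receiver's edges).
    segment_sums = tf.unsorted_segment_sum(exp_data, segment_ids, num_segments)
    return exp_data / tf.gather(segment_sums, segment_ids)

Here `segment_ids` is `graph.receivers` and `num_segments` is the total node count, so each node's incoming attention weights sum to one.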
Example #5
    def __init__(self,
                 node_model_fn,
                 received_edges_reducer=tf.math.unsorted_segment_sum,
                 sent_edges_reducer=tf.math.unsorted_segment_sum,
                 name='dist_node_block'):
        super(NodeBlock, self).__init__(name=name)
        with self._enter_variable_scope():
            self._received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
                received_edges_reducer)
            self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
                sent_edges_reducer)
            self._node_model = node_model_fn()
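Only the constructor is shown. A matching `_build` (a sketch, mirroring what `blocks.NodeBlock` itself does when edge aggregations and nodes are all enabled) would concatenate the two aggregations with the current node features and apply the node model:

    def _build(self, graph):
        collected = tf.concat([
            self._received_edges_aggregator(graph),  # [total_num_nodes, edge_dims]
            self._sent_edges_aggregator(graph),      # [total_num_nodes, edge_dims]
            graph.nodes,                             # [total_num_nodes, node_dims]
        ], axis=-1)
        return graph.replace(nodes=self._node_model(collected))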
Example #6
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Compares the output of a NodeBlock to an explicit computation."""
        input_graph = self._get_input_graph()
        node_block = blocks.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
        output_graph = node_block(input_graph)

        model_inputs = []
        if use_received_edges:
            model_inputs.append(
                blocks.ReceivedEdgesToNodesAggregator(received_edges_reducer)(
                    input_graph))
        if use_sent_edges:
            model_inputs.append(
                blocks.SentEdgesToNodesAggregator(sent_edges_reducer)(
                    input_graph))
        if use_nodes:
            model_inputs.append(input_graph.nodes)
        if use_globals:
            model_inputs.append(blocks.broadcast_globals_to_nodes(input_graph))

        model_inputs = tf.concat(model_inputs, axis=-1)
        self.assertEqual(input_graph.edges, output_graph.edges)
        self.assertEqual(input_graph.globals, output_graph.globals)

        with self.test_session() as sess:
            output_graph_out, model_inputs_out = sess.run(
                (output_graph, model_inputs))

        expected_output_nodes = model_inputs_out * self._scale
        self.assertNDArrayNear(expected_output_nodes,
                               output_graph_out.nodes,
                               err=1e-4)
Example #7
    def _build(self, labels, graph, num_steps):
        """
        description: Updates each node according to its label, previous state and neighbours

        first, it passes concatenation of states of each adjacent node to the
        :param labels: Embedding of each node [n_nodes,embedding_length]
        :param graph: GraphTuple containing connectivity information between nodes
        via the senders and receivers fields.

        :return ret_graph: Graph after one step of message passing
        """

        ret_graph = graph

        #lstm for updating nodes during each time step
        lstm = sn.LSTM(2 * labels.shape[1])
        state = lstm.initial_state(labels.shape[0])

        for _ in range(num_steps):
            #passing sender and receiver nodes through an MLP
            ret_graph = self._edge_block(ret_graph)

            #aggregating edges to nodes (summing up received edges per node)
            received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
                reducer=tf.unsorted_segment_sum)
            messages = received_edges_aggregator(ret_graph)

            #concatenating messages and labels for each node and then passing the result through an MLP
            hidden = self._node_model_fn(tf.concat([labels, messages], axis=1))

            #passing hidden and state through an LSTM
            hidden, state = lstm(hidden, state)

            #aggregating nodes to global representation
            ret_graph = self._global_block(ret_graph.replace(nodes=hidden))

        return ret_graph
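The helpers `self._edge_block`, `self._node_model_fn` and `self._global_block` are constructed elsewhere. One plausible configuration (an assumption for illustration, not the original code; hidden sizes are made up):

        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=lambda: snt.nets.MLP(output_sizes=[64, 64]),
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False)
        self._node_model_fn = snt.nets.MLP(output_sizes=[64])
        self._global_block = blocks.GlobalBlock(
            global_model_fn=lambda: snt.nets.MLP(output_sizes=[64]),
            use_edges=False,
            use_nodes=True,
            use_globals=False)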
Example #8
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        # Sender nodes put their keys and values in the edges.
        # [total_num_edges, num_heads, key_size]
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))

        # [total_num_edges, num_heads, value_size]
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Attention weight for each edge.
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.reduce_sum(
            sender_keys * receiver_queries, axis=-1)
        normalized_attention_weights = _received_edges_normalizer(
            attention_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights.
        # [total_num_edges, num_heads, embedding_size]
        attended_edges = sender_values * normalized_attention_weights[..., None]

        # Summing all of the attended values from each node.
        # [total_num_nodes, num_heads, embedding_size]
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            attention_graph.replace(edges=attended_edges))

        return attention_graph.replace(nodes=aggregated_attended_values)
Example #9
class FieldAggregatorsTest(GraphModuleTest):

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      ("edges to globals missing nodes",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregator fail if a required field is missing."""
    input_graph = self._get_input_graph([none_field])
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregator fail if a required field is missing."""
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
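For reference, the expected `SEGMENT_SUM_*` constants are plain segment sums. The received-edges-to-nodes case, for instance, is equivalent to this sketch:

# For each node i, sum the features of every edge e with receivers[e] == i.
expected = tf.unsorted_segment_sum(
    input_graph.edges, input_graph.receivers,
    tf.reduce_sum(input_graph.n_node))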
Example #10
    def _build(self, graph_features):
        """Connects the multi-head self-attention module.

    Uses edge_features to compute key, values and node_features
    for queries.

    The self-attention is only computed according to the connectivity of the
    input graphs, with receiver nodes attending to sender nodes.

    Args:
      graph_features: Graph containing connectivity information between nodes
        via the senders and receivers fields. Node A will only attempt to attend
        to Node B if `attention_graph` contains an edge sent by Node A and
        received by Node B.

    Returns:
      An output `graphs.GraphsTuple` with updated nodes containing the
      aggregated attended value for each of the nodes with shape
      [total_num_nodes, num_heads, value_size].

    Raises:
      ValueError: if the input graph does not have edges.
    """
        """
    # TODO(arc): Figure out how to incorporate edge information into
                 attention updates.
    """
        edges = self._edge_block(graph_features).edges
        num_heads = self.num_heads
        key_size = self.key_size
        value_size = self.value_size
        node_embed_dim = tf.shape(graph_features.nodes)[-1]

        # [total_num_nodes, d] => [total_num_nodes, key_size * num_heads]
        q = self._attention_node_projection_model(graph_features.nodes)

        q = tf.reshape(
            q, [tf.reduce_sum(graph_features.n_node), num_heads, key_size])

        # [total_num_edges, (key_size + value_size) * num_heads]
        # project edge features to get key, values
        kv = self._attention_edge_projection_model(edges)
        kv = tf.reshape(kv, [-1, num_heads, key_size + value_size])
        # k => [total_num_edges, num_heads, key_size]
        # v => [total_num_edges, num_heads, value_size]
        k, v = tf.split(kv, [key_size, value_size], -1)

        sender_keys = k
        sender_values = v
        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            graph_features.replace(nodes=q))

        # Attention weight for each edge.
        # [total_num_edges, num_heads, 1]
        attention_weights_logits = snt.BatchApply(
            self._query_key_product_model)(tf.concat(
                [sender_keys, receiver_queries], axis=-1))
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

        # compute softmax weights
        # [total_num_edges, num_heads]
        normalized_attention_weights = _received_edges_normalizer(
            graph_features.replace(edges=attention_weights_logits),
            normalizer=_unsorted_segment_softmax)

        # Attending to sender values according to the weights.
        # [total_num_edges, num_heads, value_size]
        attended_edges = sender_values * normalized_attention_weights[..., None]

        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        # Summing all of the attended values from each node.
        # [total_num_nodes, num_heads, value_size]
        aggregated_attended_values = received_edges_aggregator(
            graph_features.replace(edges=attended_edges))

        # concatenate all the heads and project to required dimension.
        # cast to [total_num_nodes, num_heads * value_size]
        aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                                [-1, num_heads * value_size])
        # -> [total_num_nodes, node_embed_dim]
        aggregated_attended_values = self._node_model(
            aggregated_attended_values)

        return self._global_block(
            graph_features.replace(nodes=aggregated_attended_values,
                                   edges=edges))
Example #11
tvars = graph_network.trainable_variables
print('')
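`data_dict_0` is defined elsewhere in this script. A minimal stand-in with the fields `utils_tf.data_dicts_to_graphs_tuple` expects (values are illustrative, assuming `numpy` is imported as `np`):

data_dict_0 = {
    "globals": np.zeros([3], dtype=np.float32),
    "nodes": np.random.randn(4, 5).astype(np.float32),
    "edges": np.random.randn(6, 7).astype(np.float32),
    "senders": np.array([0, 0, 1, 2, 3, 3]),
    "receivers": np.array([1, 2, 2, 3, 0, 1]),
}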

###############
# broadcast

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
def model_fn(features, labels, mode, params):
    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    depth = graph_nodes.shape[1]
    training = mode == tf.estimator.ModeKeys.TRAIN

    # 0 means the token needs to be masked; 1 means it is not masked.
    padding_mask = tf.cast(
        tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])), tf.int32)
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])
    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])
    # print("sentences: " + str(sentences))
    # print("padding_mask: " + str(padding_mask))
    question_encoder = tf.keras.layers.LSTM(depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    encoded_question = tf.cast(padding_mask, tf.float32) * tf.cast(encoded_question, tf.float32)
    encoded_question = tf.reshape(tf.cast(encoded_question, tf.float32), [-1, depth])

    # The template graph
    nodes = graph_nodes.astype(np.float32)
    edges = np.ones([int(np.sum(graph_edges)), 1]).astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    globals = np.zeros(FLAGS.global_size).astype(np.float32)

    graph_dict = {"globals": globals,
                  "nodes": nodes,
                  "edges": edges,
                  "senders": senders,
                  "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    graph_dict["nodes"] = nodes * 0
    # print("encoded_question.shape[0]: " + str(encoded_question.shape[0]))
    batch_of_tensor_data_dicts = [graph_dict for i in range(sentences.shape[0])]

    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes
    # print("batch_of_nodes: " + str(batch_of_nodes))

    # Euclidean distance to identify closest nodes
    na = tf.reduce_sum(tf.square(tf.math.l2_normalize(encoded_question, -1)), 1)
    nb = tf.reduce_sum(tf.square(tf.math.l2_normalize(nodes, -1)), 1)

    # na as a column vector and nb as a row vector
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # Compute the pairwise Euclidean distance matrix.
    distance = tf.sqrt(tf.maximum(na - 2 * tf.matmul(encoded_question, nodes, False, True) + nb, 0.0))

    # calculate attention over the graph
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)
    # print("closest_nodes: " + str(closest_nodes))

    # Write the signals onto these nodes.
    positions = tf.where(tf.not_equal(tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    # print("positions: " + str(positions))
    # Keep only the batch index column; the node index comes from closest_nodes below.
    positions = tf.slice(positions, [0, 0], [-1, 1])
    # print("positions: " + str(positions))
    positions = tf.cast(positions, tf.int32)
    # print("positions: " + str(positions))
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)
    # print("positions: " + str(positions))
    # print("compressed: " + str(compressed1))
    # print("norm_duplicate: " + str(tf.reshape(norm_duplicate, [-1, 1])))
    projection_signal = tf.reshape(encoded_question, [-1, depth])
    # print("projection_signal: " + str(projection_signal))
    batch_of_nodes = tf.tensor_scatter_nd_add(tf.reshape(batch_of_nodes, [-1, 512, depth]), positions, projection_signal)
    # print("batch_of_nodes: " + str(batch_of_nodes))
    batch_of_graphs = batch_of_graphs.replace(nodes=tf.reshape(batch_of_nodes, [-1, depth]))

    global_block = blocks.NodesToGlobalsAggregator(tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')

    num_recurrent_passes = FLAGS.recurrences
    previous_graphs = batch_of_graphs
    original_nodes = tf.reshape(original_graph.nodes, [1, 512, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals

    model_fn = snt.nets.MLP(output_sizes=[depth])

    for unused_pass in range(num_recurrent_passes):
        # Update the node features with the function
        updated_nodes = model_fn(previous_graphs.nodes)
        updated_nodes = layernorm_node(updated_nodes)
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        graph_sum0 = tf.reduce_sum(tf.reshape(tf.math.abs(temporary_graph.nodes), [-1, 4 * 512 * 300]), -1)

        # Send the node features to the edges that are being sent by that node.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        graph_sum1 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_at_edges), [-1, 4 * 5551 * 300]), -1)

        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

        # Aggregate all of the edges received by every node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_mean)(
            temporary_graph)
        graph_sum2 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_with_aggregated_edges), [-1, 4 * 512 * 300]), -1)
        previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)

        current_nodes = previous_graphs.nodes
        current_nodes = tf.reshape(current_nodes, [-1, 512, depth])
        current_nodes = dropout(current_nodes, training=training)
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(nodes=tf.reshape(new_nodes, [-1, depth]))
        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))

    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = dense_layer(output_global)
    logits = tf.reshape(logits, [-1, num_choices])

    def loss_function(real, pred):
        return tf.nn.sparse_softmax_cross_entropy_with_logits(tf.reshape(real, [-1]), pred)

    # Calculate the loss
    loss = loss_function(features["answer_id"], logits)

    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global, [-1, 4, 300]),
        'initial_global': tf.reshape(initial_global, [-1, 4, 300]),
        'old_global': tf.reshape(old_global, [-1, 4, 300]),
        'new_global': tf.reshape(new_global, [-1, 4, 300]),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes, [-1, 4, FLAGS.seq_len]),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask, [-1, 4, FLAGS.seq_len]),
        'encoded_question': tf.reshape(encoded_question, [-1, 4, FLAGS.seq_len, depth])
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.reduce_mean(loss),
        train_op=train_op)
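A sketch of wiring this `model_fn` into an Estimator; `word_embedding_matrix`, `graph_nodes`, `graph_edges`, `train_input_fn` and the `FLAGS` values are assumptions standing in for definitions elsewhere in the script:

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=FLAGS.model_dir,  # hypothetical flag
    params={
        "word_embedding": word_embedding_matrix,  # assumed np.ndarray
        "graph_nodes": graph_nodes,               # [512, depth] array
        "graph_edges": graph_edges,               # [512, 512] adjacency matrix
    })
estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)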