Ejemplo n.º 1
0
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Connects the multi-head self-attention module.

        The self-attention is only computed according to the connectivity of the
        input graphs, with receiver nodes attending to sender nodes: for every
        edge, the receiver's query is scored against the sender's key, and the
        sender's value, weighted by the normalized score, is summed into the
        receiver.

        Args:
          node_values: Tensor containing the values associated to each of the nodes.
            The expected shape is [total_num_nodes, num_heads, value_size].
          node_keys: Tensor containing the key associated to each of the nodes. The
            expected shape is [total_num_nodes, num_heads, key_size].
          node_queries: Tensor containing the query associated to each of the nodes.
            The expected shape is [total_num_nodes, num_heads, query_size]. The
            query size must be equal to the key size.
          attention_graph: Graph containing connectivity information between nodes
            via the senders and receivers fields. Node B will only attend to
            Node A if `attention_graph` contains an edge sent by Node A and
            received by Node B.

        Returns:
          An output `graphs.GraphsTuple` with updated nodes containing the
          aggregated attended value for each of the nodes with shape
          [total_num_nodes, num_heads, value_size].

        Raises:
          ValueError: presumably raised by the `blocks` broadcast helpers when
            `attention_graph` lacks the connectivity (senders/receivers) they
            need.
        """

        # Sender nodes put their keys and values in the edges.
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Attention logit for each edge and head: dot product over the last
        # (key/query) axis.
        attention_weights_logits = tf.reduce_sum(sender_keys *
                                                 receiver_queries,
                                                 axis=-1)
        normalized_attention_weights = _received_edges_normalizer(
            attention_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights; the trailing
        # axis added to the weights broadcasts them over value_size.
        attended_edges = sender_values * \
            normalized_attention_weights[..., None]

        # Summing all of the attended values from each node.
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            attention_graph.replace(edges=attended_edges))

        return attention_graph.replace(nodes=aggregated_attended_values)
Ejemplo n.º 2
0
    def _build(self, attended_graph):
        """Applies multi-head graph attention to `attended_graph`.

        For every head ``k``:
          * the sender and receiver features of each edge are projected with
            ``self.W[k]`` (via ``tf.tensordot(..., axes=1)``) and concatenated;
          * ``self.attentions[k]`` scores that pair and the score is made
            positive with ``exp(leaky_relu(score))``;
          * each edge score is divided by (sum of scores of all edges sharing
            its sender) + (sum of scores of all edges sharing its receiver);
          * the projected sender features, weighted by those ratios, are summed
            onto the sender nodes.
        The per-head node results are averaged over ``self.heads``.

        NOTE(review): this normalisation differs from the usual per-neighbourhood
        softmax of GAT, and the per-edge ``tf.map_fn`` calls are slow on large
        graphs; both are kept as-is here.

        :param attended_graph: the graph to attend to
        :return: `attended_graph` with its nodes replaced by the head-averaged
            attention output
        """
        num_nodes = tf.shape(attended_graph.nodes)[0]
        senders = attended_graph.senders
        receivers = attended_graph.receivers

        sender_feats = blocks.broadcast_sender_nodes_to_edges(attended_graph)
        receiver_feats = blocks.broadcast_receiver_nodes_to_edges(attended_graph)
        # Sender and receiver features side by side on axis 1, one row per edge.
        edge_pairs = tf.stack([sender_feats, receiver_feats], axis=1)

        head_sum = None
        for head in range(self.heads):
            weight = self.W[head]

            # map_fn traces these closures immediately, so capturing the
            # loop-local `weight` is safe.
            def project_pair(pair):
                return tf.concat([
                    tf.tensordot(weight, pair[0], axes=1),
                    tf.tensordot(weight, pair[1], axes=1)
                ], axis=0)

            def project_one(feat):
                return tf.tensordot(weight, feat, axes=1)

            projected_pairs = tf.map_fn(project_pair, edge_pairs)
            scores = tf.exp(tf.nn.leaky_relu(self.attentions[head](projected_pairs)))

            out_totals = tf.math.unsorted_segment_sum(
                scores, senders, num_segments=num_nodes)
            in_totals = tf.math.unsorted_segment_sum(
                scores, receivers, num_segments=num_nodes)
            denominators = (tf.gather(out_totals, senders) +
                            tf.gather(in_totals, receivers))
            ratios = tf.map_fn(lambda frac: frac[0] / frac[1],
                               tf.stack([scores, denominators], axis=1))

            weighted = tf.map_fn(project_one, sender_feats) * ratios
            head_out = tf.math.unsorted_segment_sum(
                weighted, senders, num_segments=num_nodes)

            head_sum = head_out if head_sum is None else head_sum + head_out

        return attended_graph.replace(nodes=head_sum / self.heads)
Ejemplo n.º 3
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Checks that computation can handle non-necessary fields left None.

    `none_field` is handed to `_get_input_graph`, presumably so that field of
    the input graph is left as None. The EdgeBlock must pass nodes and globals
    through unchanged, and its output edges must equal `self._scale` times the
    concatenation of the inputs it was configured to use.
    """
    input_graph = self._get_input_graph([none_field])
    output_graph = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)(input_graph)

    # Rebuild the edge model's input. The getters are lazy so that disabled
    # fields (which may be None) are never touched.
    getters = [
        (use_edges, lambda: input_graph.edges),
        (use_nodes,
         lambda: blocks.broadcast_receiver_nodes_to_edges(input_graph)),
        (use_nodes,
         lambda: blocks.broadcast_sender_nodes_to_edges(input_graph)),
        (use_globals,
         lambda: blocks.broadcast_globals_to_edges(input_graph)),
    ]
    expected_inputs = tf.concat(
        [getter() for wanted, getter in getters if wanted], axis=-1)

    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
      edges_value, inputs_value = sess.run(
          (output_graph.edges, expected_inputs))

    self.assertNDArrayNear(inputs_value * self._scale, edges_value, err=1e-4)
Ejemplo n.º 4
0
  def test_output_values(
      self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Compares the output of an EdgeBlock to an explicit computation.

    The expected edges are `self._scale` times the concatenation, in order, of
    the edge, receiver-node, sender-node and global features selected by the
    `use_*` flags; nodes and globals must pass through unchanged.
    """
    input_graph = self._get_input_graph()
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    output_graph = block(input_graph)

    flags = (use_edges, use_receiver_nodes, use_sender_nodes, use_globals)
    sources = (
        lambda graph: graph.edges,
        blocks.broadcast_receiver_nodes_to_edges,
        blocks.broadcast_sender_nodes_to_edges,
        blocks.broadcast_globals_to_edges,
    )
    pieces = [source(input_graph)
              for enabled, source in zip(flags, sources) if enabled]
    model_inputs = tf.concat(pieces, axis=-1)

    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
      result, inputs_value = sess.run((output_graph, model_inputs))

    self.assertNDArrayNear(
        inputs_value * self._scale, result.edges, err=1e-4)
Ejemplo n.º 5
0
    def _build(self, graph):
        """Builds a SpringMassSimulator.

        Args:
          graph: A graphs.GraphsTuple having, for some integers N, E, G:
              - edges: Ex2 tf.Tensor of [spring_constant, rest_length] for each
                edge.
              - nodes: Nx5 tf.Tensor of [x, y, v_x, v_y, is_fixed] features for
                each node.
              - globals: Gx2 tf.Tensor containing the gravitational constant.

        Returns:
          A graphs.GraphsTuple of the same shape as `graph`, but where:
              - edges: Holds the force [f_x, f_y] acting on each edge.
              - nodes: Holds positions and velocities after applying one step of
                Euler integration.
        """
        receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
        sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)

        # Columns 0:1 and 1:2 of the edge features are the spring constant and
        # rest length (slices keep a trailing axis of size 1).
        spring_force_per_edge = hookes_law(receiver_nodes, sender_nodes,
                                           graph.edges[..., 0:1],
                                           graph.edges[..., 1:2])
        graph = graph.replace(edges=spring_force_per_edge)

        # Sum the per-edge forces onto nodes and add the broadcast gravity term.
        spring_force_per_node = self._aggregator(graph)
        gravity = blocks.broadcast_globals_to_nodes(graph)
        updated_nodes = euler_integration(graph.nodes,
                                          spring_force_per_node + gravity,
                                          self._step_size)
        graph = graph.replace(nodes=updated_nodes)
        return graph
Ejemplo n.º 6
0
    def _build(self, node_values, node_keys, node_queries, attention_graph):
        """Connects the multi-head self-attention module.

        Receiver nodes attend to sender nodes along the edges of
        `attention_graph`. Each edge's logit is the per-head dot product of the
        receiver's query and the sender's key. The logits are normalized with
        `_received_edges_normalizer` (using `self._normalizer`), and the
        weighted sender values are summed at the receivers.

        Args:
          node_values: [total_num_nodes, num_heads, value_size] node values.
          node_keys: [total_num_nodes, num_heads, key_size] node keys.
          node_queries: [total_num_nodes, num_heads, key_size] node queries
            (the query size must equal the key size).
          attention_graph: graph providing senders/receivers connectivity.

        Returns:
          `attention_graph` with nodes replaced by the aggregated attended
          values, shape [total_num_nodes, num_heads, value_size].
        """
        # Sender nodes put their keys and values in the edges.
        # [total_num_edges, num_heads, key_size]
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_keys))

        # [total_num_edges, num_heads, value_size]
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            attention_graph.replace(nodes=node_values))

        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            attention_graph.replace(nodes=node_queries))

        # Attention logit for each edge and head. Keys and queries already share
        # the layout [total_num_edges, num_heads, key_size], so they are
        # multiplied elementwise directly; transposing the queries would reverse
        # all three axes and misalign the product.
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.reduce_sum(
            sender_keys * receiver_queries, axis=-1)
        normalized_attention_weights = _received_edges_normalizer(
            attention_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights.
        # [total_num_edges, num_heads, value_size]
        attended_edges = sender_values * normalized_attention_weights[...,
                                                                      None]

        # Summing all of the attended values from each node.
        # [total_num_nodes, num_heads, value_size]
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            attention_graph.replace(edges=attended_edges))

        return attention_graph.replace(nodes=aggregated_attended_values)
Ejemplo n.º 7
0
def set_rest_lengths(graph):
  """Computes and sets rest lengths for the springs in a physical system.

  The rest length is taken to be the distance between each edge's nodes.

  Args:
    graph: a graphs.GraphsTuple having, for some integers N, E:
        - nodes: Nx5 Tensor of [x, y, _, _, _] for each node.
        - edges: Ex2 Tensor of [spring_constant, _] for each edge.

  Returns:
    The input graph, but with [spring_constant, rest_length] for each edge.
  """
  receiver_nodes = blocks.broadcast_receiver_nodes_to_edges(graph)
  sender_nodes = blocks.broadcast_sender_nodes_to_edges(graph)
  # Euclidean distance between the (x, y) positions of each edge's endpoints,
  # kept as an [E, 1] column. `keepdims` is accepted by both TF1 (>=1.5) and
  # TF2; the older `keep_dims` spelling no longer exists in TF2.
  rest_length = tf.norm(
      receiver_nodes[..., :2] - sender_nodes[..., :2], axis=-1, keepdims=True)
  return graph.replace(
      edges=tf.concat([graph.edges[..., :1], rest_length], axis=-1))
Ejemplo n.º 8
0
    def _build(self, graph):
        """Updates edges from their features and replica-summed endpoint features.

        Receiver and sender node features are broadcast onto the edges, each is
        summed across replicas with the current replica context's
        ``all_reduce("sum", ...)``, and both are concatenated after
        ``graph.edges`` along the last axis before being fed to
        ``self._edge_model``.

        Args:
          graph: input ``GraphsTuple``.

        Returns:
          ``graph`` with ``edges`` replaced by the edge model's output.
        """
        receiver_feats = blocks.broadcast_receiver_nodes_to_edges(graph)
        sender_feats = blocks.broadcast_sender_nodes_to_edges(graph)

        # Aggregate across replicas; the receiver features are reduced first,
        # then the sender features.
        ctx = tf.distribute.get_replica_context()
        reduced = [ctx.all_reduce("sum", feats)
                   for feats in (receiver_feats, sender_feats)]

        collected = tf.concat([graph.edges] + reduced, axis=-1)
        return graph.replace(edges=self._edge_model(collected))
Ejemplo n.º 9
0
    def _build(self, graph_features):
        """Connects the multi-head self-attention module.

        The self-attention is only computed according to the connectivity of the
        input graphs, with receiver nodes attending to sender nodes. Queries,
        keys and values are projected from the node features by
        `self._attention_projection_model`. Attention logits come from
        `self._query_key_product_model` applied to the concatenated sender key
        and receiver query of every edge, not from a plain dot product.

        Args:
          graph_features: Graph containing node features and connectivity
            information between nodes via the senders and receivers fields.
            Node B will only attend to Node A if `graph_features` contains an
            edge sent by Node A and received by Node B.

        Returns:
          `graph_features` with updated nodes: the attended values of all heads
          are concatenated to [total_num_nodes, num_heads * value_size] and
          passed through `self._node_model` (whose output width is presumably
          the node embedding size).

        Raises:
          ValueError: presumably raised by the `blocks` broadcast helpers when
            `graph_features` lacks the connectivity (senders/receivers) they
            need.
        """
        # TODO(arc): Figure out how to incorporate edge information into
        # attention updates.
        nodes = graph_features.nodes

        num_heads = self.num_heads
        key_size = self.key_size
        value_size = self.value_size

        # Per-head width of the concatenated [query, key, value] projection.
        qkv_size = 2 * key_size + value_size

        # [total_num_nodes, d] => [total_num_nodes, num_heads * qkv_size]
        qkv_flat = self._attention_projection_model(nodes)

        qkv = tf.reshape(qkv_flat, [-1, num_heads, qkv_size])
        # q => [total_num_nodes, num_heads, key_size]
        # k => [total_num_nodes, num_heads, key_size]
        # v => [total_num_nodes, num_heads, value_size]
        q, k, v = tf.split(qkv, [key_size, key_size, value_size], -1)

        # Sender nodes put their keys and values in the edges.
        # [total_num_edges, num_heads, key_size]
        sender_keys = blocks.broadcast_sender_nodes_to_edges(
            graph_features.replace(nodes=k))
        # [total_num_edges, num_heads, value_size]
        sender_values = blocks.broadcast_sender_nodes_to_edges(
            graph_features.replace(nodes=v))

        # Receiver nodes put their queries in the edges.
        # [total_num_edges, num_heads, key_size]
        receiver_queries = blocks.broadcast_receiver_nodes_to_edges(
            graph_features.replace(nodes=q))

        # Attention logit for each edge and head from a learned model over the
        # concatenated [key, query]; BatchApply merges the leading edge/head
        # dimensions into one batch dimension for the model and splits them back.
        # [total_num_edges, num_heads, 1]
        attention_weights_logits = snt.BatchApply(
            self._query_key_product_model)(tf.concat(
                [sender_keys, receiver_queries], axis=-1))
        # [total_num_edges, num_heads]
        attention_weights_logits = tf.squeeze(attention_weights_logits, -1)

        # Normalize the logits into attention weights.
        # [total_num_edges, num_heads]
        normalized_attention_weights = _received_edges_normalizer(
            graph_features.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)

        # Attending to sender values according to the weights.
        # [total_num_edges, num_heads, value_size]
        attended_edges = sender_values * normalized_attention_weights[...,
                                                                      None]

        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.unsorted_segment_sum)
        # Summing all of the attended values from each node.
        # [total_num_nodes, num_heads, value_size]
        aggregated_attended_values = received_edges_aggregator(
            graph_features.replace(edges=attended_edges))

        # Concatenate all the heads and project to the required dimension.
        # [total_num_nodes, num_heads * value_size]
        aggregated_attended_values = tf.reshape(aggregated_attended_values,
                                                [-1, num_heads * value_size])
        aggregated_attended_values = self._node_model(
            aggregated_attended_values)

        return graph_features.replace(nodes=aggregated_attended_values)
Ejemplo n.º 10
0
    print(previous_graphs.nodes[0])
# Keep the final state of the recurrence above (previous_graphs comes from the
# preceding code, outside this fragment).
output_graphs = previous_graphs

tvars = graph_network.trainable_variables
print('')

###############
# broadcast
# Each result below is a copy of the single-graph tuple built from
# data_dict_0 with one field replaced by the output of a broadcast helper.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
# Same single graph, rebuilt from data_dict_0; each result replaces one field
# with the output of an aggregator using the reducer below.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

reducer = tf.math.unsorted_segment_sum  # reducer shared by the aggregators below
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
Ejemplo n.º 11
0
def model_fn(features, labels, mode, params):
    """Estimator `model_fn` that scores answer choices by propagating signals on a graph.

    Token ids are embedded and encoded with an LSTM. Each token's encoding is
    scatter-added onto the graph node nearest to it, inside a per-row copy of
    the template graph whose node features start at zero. Node features are
    then repeatedly transformed by an MLP, sent along edges, mean-aggregated at
    receivers and multiplied elementwise by the template node features. A
    mean-pooled global vector per graph is projected to one logit, and the
    logits are grouped into rows of `num_choices`.

    Args:
      features: dict with at least "input_ids" (token ids) and "answer_id"
        (index of the correct choice, used for the loss in every mode).
      labels: unused; targets are read from `features["answer_id"]`.
      mode: a `tf.estimator.ModeKeys` value.
      params: dict with "word_embedding" (embedding matrix), "graph_nodes"
        (numpy template node features, [num_nodes, depth]) and "graph_edges"
        (numpy adjacency matrix whose nonzero entries define senders/receivers).

    Returns:
      A `tf.estimator.EstimatorSpec`. In PREDICT mode it carries predictions
      and export outputs; otherwise predictions, the mean loss and, in TRAIN
      mode, an Adam `train_op`.

    Notes:
      Relies on module-level `FLAGS`, `num_choices` and `SIGNATURE_NAME`.
      NOTE(review): several shapes are hard-coded — 512 graph nodes, 5551 edges,
      4 choices and depth 300 (presumably the dataset this was written for);
      confirm before reuse.
    """
    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    # Node feature width; the LSTM and the global Dense layer use it as well.
    depth = graph_nodes.shape[1]
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Token id 1 is treated as the padding id.
    padding_mask = tf.cast(tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])),
                           tf.int32)  # 0 means the token needs to be masked. 1 means it is not masked.
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])
    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    # Assumes the embedding width equals `depth` — TODO confirm.
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])
    # print("sentences: " + str(sentences))
    # print("padding_mask: " + str(padding_mask))
    question_encoder = tf.keras.layers.LSTM(depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    # Zero the encodings at padded positions.
    encoded_question = tf.cast(padding_mask, tf.float32) * tf.cast(encoded_question, tf.float32)
    # One row per token: [num_rows * seq_len, depth].
    encoded_question = tf.reshape(tf.cast(encoded_question, tf.float32), [-1, depth])

    # The template graph
    nodes = graph_nodes.astype(np.float32)
    # One constant feature per edge. np.sum only equals the number of nonzero
    # entries (= len(senders)) if graph_edges is a 0/1 matrix.
    edges = np.ones([int(np.sum(graph_edges)), 1]).astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    # NOTE: this local shadows the `globals` builtin inside this function.
    globals = np.zeros(FLAGS.global_size).astype(np.float32)

    graph_dict = {"globals": globals,
                  "nodes": nodes,
                  "edges": edges,
                  "senders": senders,
                  "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    # The per-row graphs start from zeroed node features; the template
    # features stay in original_graph (nodes * 0 builds a new array).
    graph_dict["nodes"] = nodes * 0
    # print("encoded_question.shape[0]: " + str(encoded_question.shape[0]))
    # NOTE(review): range() needs a static first dimension, so the number of
    # rows must be known when the TF graph is built.
    batch_of_tensor_data_dicts = [graph_dict for i in range(sentences.shape[0])]

    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes
    # print("batch_of_nodes: " + str(batch_of_nodes))

    # Euclidean distance to identify closest nodes
    # NOTE(review): na/nb are squared norms of the L2-normalised vectors (so ~1,
    # or 0 for all-zero rows), while the cross term below uses the raw vectors;
    # the result is not exactly the Euclidean distance of either the raw or the
    # normalised vectors — confirm this is intended.
    na = tf.reduce_sum(tf.square(tf.math.l2_normalize(encoded_question, -1)), 1)
    nb = tf.reduce_sum(tf.square(tf.math.l2_normalize(nodes, -1)), 1)

    # na as a column vector and nb as a row vector, so they broadcast to a
    # [num_tokens, num_nodes] matrix.
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # Pairwise distance matrix [num_tokens, num_nodes]; tf.maximum clamps
    # negative values (e.g. from rounding) to 0 before the sqrt.
    distance = tf.sqrt(tf.maximum(na - 2 * tf.matmul(encoded_question, nodes, False, True) + nb, 0.0))

    # Index of the nearest graph node for every token.
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)
    # print("closest_nodes: " + str(closest_nodes))

    # # Write the signals onto these nodes
    # The condition is true everywhere as long as the graph has fewer than
    # 99999 nodes, so tf.where returns every [row, token] coordinate of the
    # [num_rows, seq_len] matrix, in row-major order.
    positions = tf.where(tf.not_equal(tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    # print("positions: " + str(positions))
    positions = tf.slice(positions, [0, 0], [-1, 1])  # keep only column 0 (the row index); the token column is replaced by the node index below
    # print("positions: " + str(positions))
    positions = tf.cast(positions, tf.int32)
    # print("positions: " + str(positions))
    # [num_tokens, 2] scatter indices: (row, nearest node index).
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)
    # print("positions: " + str(positions))
    # print("compressed: " + str(compressed1))
    # print("norm_duplicate: " + str(tf.reshape(norm_duplicate, [-1, 1])))
    projection_signal = tf.reshape(encoded_question, [-1, depth])
    # print("projection_signal: " + str(projection_signal))
    # Add each token's encoding onto its nearest node in its own row's graph.
    # NOTE(review): 512 is hard-coded and must equal the number of graph nodes.
    batch_of_nodes = tf.tensor_scatter_nd_add(tf.reshape(batch_of_nodes, [-1, 512, depth]), positions, projection_signal)
    # print("batch_of_nodes: " + str(batch_of_nodes))
    batch_of_graphs = batch_of_graphs.replace(nodes=tf.reshape(batch_of_nodes, [-1, depth]))

    # Global vector per graph: mean over its nodes -> Dense(relu) -> LayerNorm.
    global_block = blocks.NodesToGlobalsAggregator(tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')

    num_recurrent_passes = FLAGS.recurrences
    previous_graphs = batch_of_graphs
    # Template node features, broadcast over rows when gating below.
    original_nodes = tf.reshape(original_graph.nodes, [1, 512, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals

    # NOTE: this local shadows the name of the enclosing function.
    model_fn = snt.nets.MLP(output_sizes=[depth])

    # NOTE(review): graph_sum0..2 and old_global are only bound inside this
    # loop; FLAGS.recurrences == 0 makes building `predictions` raise NameError.
    for unused_pass in range(num_recurrent_passes):
        # Update the node features with the function
        updated_nodes = model_fn(previous_graphs.nodes)
        updated_nodes = layernorm_node(updated_nodes)
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        # Debug statistic (sum of |x| per group of 4 graphs); presumably assumes
        # 4 choices, 512 nodes and depth 300.
        graph_sum0 = tf.reduce_sum(tf.reshape(tf.math.abs(temporary_graph.nodes), [-1, 4 * 512 * 300]), -1)

        # Send the node features to the edges that are being sent by that node.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        # Debug statistic; presumably assumes 5551 edges per graph.
        graph_sum1 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_at_edges), [-1, 4 * 5551 * 300]), -1)

        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

        # Aggregate the all of the edges received by every node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_mean)(
            temporary_graph)
        graph_sum2 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_with_aggregated_edges), [-1, 4 * 512 * 300]), -1)
        previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)

        # Dropout, then gate elementwise with the template node features.
        current_nodes = previous_graphs.nodes
        current_nodes = tf.reshape(current_nodes, [-1, 512, depth])
        current_nodes = dropout(current_nodes, training=training)
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(nodes=tf.reshape(new_nodes, [-1, depth]))
        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))

    # One logit per graph, grouped into rows of `num_choices` (module scope).
    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = dense_layer(output_global)
    logits = tf.reshape(logits, [-1, num_choices])

    def loss_function(real, pred):
        return tf.nn.sparse_softmax_cross_entropy_with_logits(tf.reshape(real, [-1]), pred)

    # Calculate the loss
    loss = loss_function(features["answer_id"], logits)

    # The [-1, 4, ...] reshapes below presumably assume 4 choices and depth 300.
    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global, [-1, 4, 300]),
        'initial_global': tf.reshape(initial_global, [-1, 4, 300]),
        'old_global': tf.reshape(old_global, [-1, 4, 300]),
        'new_global': tf.reshape(new_global, [-1, 4, 300]),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes, [-1, 4, FLAGS.seq_len]),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask, [-1, 4, FLAGS.seq_len]),
        'encoded_question': tf.reshape(encoded_question, [-1, 4, FLAGS.seq_len, depth])
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)

        # Run any ops registered in UPDATE_OPS (e.g. batch-norm moving averages)
        # before the train step; this model itself only uses LayerNormalization.
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.reduce_mean(loss),
        train_op=train_op)