Example #1
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Checks that computation can handle non-necessary fields left None."""
    input_graph = self._get_input_graph([none_field])
    global_block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = global_block(input_graph)

    model_inputs = []
    if use_edges:
      model_inputs.append(
          blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum)(input_graph))
    if use_nodes:
      model_inputs.append(
          blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum)(input_graph))
    if use_globals:
      model_inputs.append(input_graph.globals)

    model_inputs = tf.concat(model_inputs, axis=-1)
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.nodes, output_graph.nodes)

    with self.test_session() as sess:
      actual_globals, model_inputs_out = sess.run(
          (output_graph.globals, model_inputs))

    expected_output_globals = model_inputs_out * self._scale
    self.assertNDArrayNear(expected_output_globals, actual_globals, err=1e-4)
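For context, a minimal standalone use of GlobalBlock might look like the following sketch; the data-dict values and the Sonnet linear model are hypothetical stand-ins for the test harness's _get_input_graph and _global_model_fn.

import numpy as np
import sonnet as snt
import tensorflow as tf
from graph_nets import blocks, utils_tf

# Hypothetical two-node, one-edge graph with 3-d features everywhere.
data_dict = {
    "nodes": np.ones([2, 3], dtype=np.float32),
    "edges": np.ones([1, 3], dtype=np.float32),
    "senders": np.array([0]),
    "receivers": np.array([1]),
    "globals": np.ones([3], dtype=np.float32),
}
graph = utils_tf.data_dicts_to_graphs_tuple([data_dict])

# Aggregate edges and nodes into the globals through a linear model.
global_block = blocks.GlobalBlock(global_model_fn=lambda: snt.Linear(3))
updated_graph = global_block(graph)  # globals updated; nodes and edges pass through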
Example #2
  def test_output_values(
      self, use_edges, use_nodes, use_globals, edges_reducer, nodes_reducer):
    """Compares the output of a GlobalBlock to an explicit computation."""
    input_graph = self._get_input_graph()
    global_block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals,
        edges_reducer=edges_reducer,
        nodes_reducer=nodes_reducer)
    output_graph = global_block(input_graph)

    model_inputs = []
    if use_edges:
      model_inputs.append(
          blocks.EdgesToGlobalsAggregator(edges_reducer)(input_graph))
    if use_nodes:
      model_inputs.append(
          blocks.NodesToGlobalsAggregator(nodes_reducer)(input_graph))
    if use_globals:
      model_inputs.append(input_graph.globals)

    model_inputs = tf.concat(model_inputs, axis=-1)
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.nodes, output_graph.nodes)

    with self.test_session() as sess:
      output_graph_out, model_inputs_out = sess.run(
          (output_graph, model_inputs))

    expected_output_globals = model_inputs_out * self._scale
    self.assertNDArrayNear(
        expected_output_globals, output_graph_out.globals, err=1e-4)
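The edges_reducer and nodes_reducer arguments accept any unsorted-segment reducer. A minimal sketch, reusing the hypothetical graph built after Example #1:

# Sum vs. max aggregation of edge features into per-graph globals.
sum_globals = blocks.EdgesToGlobalsAggregator(tf.math.unsorted_segment_sum)(graph)
max_globals = blocks.EdgesToGlobalsAggregator(
    blocks.unsorted_segment_max_or_zero)(graph)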
Example #3
    def _build(self, inputs):
        graph = inputs['graph']
        nodes = self._input_dense(graph.nodes)

        q = self._q_dense(nodes)
        k = self._k_dense(nodes)
        v = self._v_dense(nodes)

        q = self.split_heads(q, self.FLAGS)
        k = self.split_heads(k, self.FLAGS)
        v = self.split_heads(v, self.FLAGS)

        attention_graph = self._sa(node_values=v,
                                   node_keys=k,
                                   node_queries=q,
                                   attention_graph=graph)

        attention_output = self.combine_heads(attention_graph.nodes,
                                              self.FLAGS)
        attention_output = self._output_dense(attention_output)

        sa_skip = nodes + attention_output  # residual/skip connection
        sa_normed = self._sa_laynorm(sa_skip)  # apply layer norm

        ff_skip = sa_normed + self._ff(sa_normed)  # residual/skip connection
        ff_normed = self._ff_laynorm(ff_skip)  # apply layer norm

        # Optionally gate the node features before the nodes-to-globals aggregation.
        if self.FLAGS['tf_gate']:
            ff_normed_doub = self._doub_dense(ff_normed)
            # TODO: try raw activation to make sure this is not just helping because of more weights
            weights = tf.tanh(ff_normed_doub[:, :self.hidden_size])
            vals = tf.sigmoid(ff_normed_doub[:, self.hidden_size:])
            gated = weights * vals
            out_graph = attention_graph.replace(nodes=gated)
        else:
            out_graph = attention_graph.replace(nodes=ff_normed)

        reducer = {
            'sum': tf.unsorted_segment_sum,
            'max': blocks.unsorted_segment_max_or_zero
        }[self.FLAGS['tf_reducer']]
        agg = blocks.NodesToGlobalsAggregator(reducer=reducer)
        agged = agg(out_graph)
        agged = agged / tf.cast(out_graph.n_node, tf.float32)[:, None]
        return agged
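When FLAGS['tf_reducer'] is 'sum', the sum-then-divide-by-n_node step above is just a per-graph mean; a sketch of the equivalent pooling with a mean reducer (out_graph as in _build above):

# Equivalent to the 'sum' branch followed by division by n_node.
agg = blocks.NodesToGlobalsAggregator(reducer=tf.math.unsorted_segment_mean)
agged = agg(out_graph)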
Example #4
class FieldAggregatorsTest(GraphModuleTest):

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      ("edges to globals missing nodes",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregator fail if a required field is missing."""
    input_graph = self._get_input_graph([none_field])
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregator fail if a required field is missing."""
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
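Each aggregator exercised above is a segment reduction keyed by graph (or node) membership. A hand-rolled sketch of what EdgesToGlobalsAggregator(tf.unsorted_segment_sum) computes, assuming a GraphsTuple named graph and TF2-style ops:

num_graphs = tf.shape(graph.n_edge)[0]
# Graph id of each edge, e.g. [0, 0, 1, ...] when graph 0 has 2 edges.
edge_graph_ids = tf.repeat(tf.range(num_graphs), graph.n_edge)
edges_to_globals = tf.math.unsorted_segment_sum(
    graph.edges, edge_graph_ids, num_graphs)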
Example #5
tvars = graph_network.trainable_variables

###############
# broadcast

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
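Both snippets above assume a data_dict_0 defined earlier; a hypothetical minimal definition (numpy imported as np) that satisfies every broadcast and aggregation call:

data_dict_0 = {
    "globals": np.zeros([2], dtype=np.float32),
    "nodes": np.zeros([3, 4], dtype=np.float32),
    "edges": np.zeros([2, 5], dtype=np.float32),
    "senders": np.array([0, 1]),
    "receivers": np.array([1, 2]),
}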
def model_fn(features, labels, mode, params):
    sentences = features["input_ids"]
    word_embedding = tf.constant(params['word_embedding'])
    graph_nodes = params['graph_nodes']
    graph_edges = params['graph_edges']
    depth = graph_nodes.shape[1]
    training = mode == tf.estimator.ModeKeys.TRAIN

    padding_mask = tf.cast(tf.not_equal(tf.cast(sentences, tf.int32), tf.constant([[1]])),
                           tf.int32)  # 0 means the token needs to be masked. 1 means it is not masked.
    padding_mask = tf.reshape(padding_mask, [-1, FLAGS.seq_len, 1])
    sentences = tf.nn.embedding_lookup(word_embedding, sentences)
    sentences = tf.reshape(sentences, [-1, FLAGS.seq_len, depth])
    # print("sentences: " + str(sentences))
    # print("padding_mask: " + str(padding_mask))
    question_encoder = tf.keras.layers.LSTM(depth, dropout=FLAGS.dropout, return_sequences=True)
    encoded_question = question_encoder(sentences, training=training)
    encoded_question = tf.cast(padding_mask, tf.float32) * tf.cast(encoded_question, tf.float32)
    encoded_question = tf.reshape(tf.cast(encoded_question, tf.float32), [-1, depth])

    # The template graph
    nodes = graph_nodes.astype(np.float32)
    edges = np.ones([int(np.sum(graph_edges)), 1]).astype(np.float32)
    senders, receivers = np.nonzero(graph_edges)
    globals = np.zeros(FLAGS.global_size).astype(np.float32)

    graph_dict = {"globals": globals,
                  "nodes": nodes,
                  "edges": edges,
                  "senders": senders,
                  "receivers": receivers}
    original_graph = utils_tf.data_dicts_to_graphs_tuple([graph_dict])
    graph_dict["nodes"] = nodes * 0
    # print("encoded_question.shape[0]: " + str(encoded_question.shape[0]))
    batch_of_tensor_data_dicts = [graph_dict for i in range(sentences.shape[0])]

    batch_of_graphs = utils_tf.data_dicts_to_graphs_tuple(batch_of_tensor_data_dicts)
    batch_of_nodes = batch_of_graphs.nodes
    # print("batch_of_nodes: " + str(batch_of_nodes))

    # Euclidean distance to identify the closest node for each token, via
    # ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2 on l2-normalized vectors.
    normed_question = tf.math.l2_normalize(encoded_question, -1)
    normed_nodes = tf.math.l2_normalize(nodes, -1)
    na = tf.reduce_sum(tf.square(normed_question), 1)
    nb = tf.reduce_sum(tf.square(normed_nodes), 1)

    # na as a column vector and nb as a row vector.
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])

    # Pairwise Euclidean distance matrix.
    distance = tf.sqrt(tf.maximum(
        na - 2 * tf.matmul(normed_question, normed_nodes, False, True) + nb, 0.0))

    # Pick the closest node for each token.
    closest_nodes = tf.cast(tf.argmin(distance, -1), tf.int32)

    # Write the signals onto these nodes.
    # tf.where on an always-true mask yields every (row, col) index pair.
    positions = tf.where(tf.not_equal(tf.reshape(closest_nodes, [-1, FLAGS.seq_len]), 99999))
    positions = tf.slice(positions, [0, 0], [-1, 1])  # keep only the batch (row) index
    positions = tf.cast(positions, tf.int32)
    # Pair each batch index with the id of the closest node for each token.
    positions = tf.concat([positions, tf.reshape(closest_nodes, [-1, 1])], -1)
    projection_signal = tf.reshape(encoded_question, [-1, depth])
    # Scatter-add each token's encoding onto its closest node.
    batch_of_nodes = tf.tensor_scatter_nd_add(
        tf.reshape(batch_of_nodes, [-1, 512, depth]), positions, projection_signal)
    batch_of_graphs = batch_of_graphs.replace(nodes=tf.reshape(batch_of_nodes, [-1, depth]))

    global_block = blocks.NodesToGlobalsAggregator(tf.math.unsorted_segment_mean)
    global_dense = tf.keras.layers.Dense(depth, activation='relu')

    num_recurrent_passes = FLAGS.recurrences
    previous_graphs = batch_of_graphs
    original_nodes = tf.reshape(original_graph.nodes, [1, 512, depth])
    dropout = tf.keras.layers.Dropout(FLAGS.dropout)
    layernorm_global = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    layernorm_node = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    new_global = global_block(previous_graphs)
    previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
    previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))
    initial_global = previous_graphs.globals

    node_model = snt.nets.MLP(output_sizes=[depth])

    for unused_pass in range(num_recurrent_passes):
        # Update the node features with the node-update MLP.
        updated_nodes = node_model(previous_graphs.nodes)
        updated_nodes = layernorm_node(updated_nodes)
        temporary_graph = previous_graphs.replace(nodes=updated_nodes)
        graph_sum0 = tf.reduce_sum(tf.reshape(tf.math.abs(temporary_graph.nodes), [-1, 4 * 512 * 300]), -1)

        # Broadcast each node's features onto the edges that node sends.
        nodes_at_edges = blocks.broadcast_sender_nodes_to_edges(temporary_graph)
        graph_sum1 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_at_edges), [-1, 4 * 5551 * 300]), -1)

        temporary_graph = temporary_graph.replace(edges=nodes_at_edges)

        # Aggregate all of the edges received by each node.
        nodes_with_aggregated_edges = blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_mean)(
            temporary_graph)
        graph_sum2 = tf.reduce_sum(tf.reshape(tf.math.abs(nodes_with_aggregated_edges), [-1, 4 * 512 * 300]), -1)
        previous_graphs = previous_graphs.replace(nodes=nodes_with_aggregated_edges)

        current_nodes = previous_graphs.nodes
        current_nodes = tf.reshape(current_nodes, [-1, 512, depth])
        current_nodes = dropout(current_nodes, training=training)
        new_nodes = current_nodes * original_nodes
        previous_graphs = previous_graphs.replace(nodes=tf.reshape(new_nodes, [-1, depth]))
        old_global = previous_graphs.globals
        new_global = global_block(previous_graphs)
        previous_graphs = previous_graphs.replace(globals=global_dense(new_global))
        previous_graphs = previous_graphs.replace(globals=layernorm_global(previous_graphs.globals))

    output_global = tf.keras.layers.Dropout(FLAGS.dropout)(previous_graphs.globals, training=training)
    dense_layer = tf.keras.layers.Dense(1)
    logits = dense_layer(output_global)
    logits = tf.reshape(logits, [-1, num_choices])

    def loss_function(real, pred):
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(real, [-1]), logits=pred)

    # Calculate the loss
    loss = loss_function(features["answer_id"], logits)

    predictions = {
        'original': features["input_ids"],
        'prediction': tf.argmax(logits, -1),
        'correct': features["answer_id"],
        'logits': logits,
        'loss': loss,
        'output_global': tf.reshape(output_global, [-1, 4, 300]),
        'initial_global': tf.reshape(initial_global, [-1, 4, 300]),
        'old_global': tf.reshape(old_global, [-1, 4, 300]),
        'new_global': tf.reshape(new_global, [-1, 4, 300]),
        'graph_sum0': graph_sum0,
        'graph_sum1': graph_sum1,
        'graph_sum2': graph_sum2,
        'closest_nodes': tf.reshape(closest_nodes, [-1, 4, FLAGS.seq_len]),
        'input_id': features["input_ids"],
        'mask': tf.reshape(padding_mask, [-1, 4, FLAGS.seq_len]),
        'encoded_question': tf.reshape(encoded_question, [-1, 4, FLAGS.seq_len, depth])
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, beta2=0.98, epsilon=1e-9)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(tf.reduce_mean(loss), global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=tf.reduce_mean(loss),
        train_op=train_op)
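A sketch of wiring this model_fn into an Estimator; every params value and the input_fn here are hypothetical and must match the shapes the function expects (notably the hard-coded 512 nodes):

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        "word_embedding": word_embedding,  # hypothetical [vocab_size, depth] float32 array
        "graph_nodes": graph_nodes,        # hypothetical [512, depth] float32 array
        "graph_edges": graph_edges,        # hypothetical [512, 512] 0/1 adjacency matrix
    })
estimator.train(input_fn=train_input_fn)  # train_input_fn is hypothetical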