Example #1
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Verifies a NodeBlock runs when a field it does not read is None.

    The input graph is requested with `none_field` listed in the argument of
    `_get_input_graph`. The nodes produced by the block are then compared
    against the concatenation of the enabled inputs multiplied by
    `self._scale`.
    """
    graph_in = self._get_input_graph([none_field])
    graph_out = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)(graph_in)

    # Assemble by hand the tensors the node model is expected to receive,
    # keeping the order: received edges, sent edges, nodes, globals.
    pieces = []
    if use_edges:
      for aggregator_cls in (blocks.ReceivedEdgesToNodesAggregator,
                             blocks.SentEdgesToNodesAggregator):
        pieces.append(aggregator_cls(tf.unsorted_segment_sum)(graph_in))
    if use_nodes:
      pieces.append(graph_in.nodes)
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_nodes(graph_in))
    expected_inputs = tf.concat(pieces, axis=-1)

    # Only the nodes may change; edges and globals must pass through as-is.
    self.assertEqual(graph_in.edges, graph_out.edges)
    self.assertEqual(graph_in.globals, graph_out.globals)

    with self.test_session() as sess:
      nodes_out, inputs_out = sess.run((graph_out.nodes, expected_inputs))

    self.assertNDArrayNear(inputs_out * self._scale, nodes_out, err=1e-4)
Example #2
0
    def __init__(self, edge_model_fn, node_model_fn, name="fwd_gnn"):
        """Creates the sub-modules used by the transition model.

        Args:
            edge_model_fn: Forwarded to `blocks.EdgeBlock` as its
                `edge_model_fn` argument.
            node_model_fn: Called once with no arguments; the returned object
                is kept as the node model.
            name: Module name handed to the parent constructor.
        """
        super(GraphNeuralNetwork_transition, self).__init__(name=name)

        # The edge block is fed the receiver and sender nodes only; existing
        # edge features and globals are switched off.
        edge_block_flags = dict(
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
        )

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn,
                name='edge_block',
                **edge_block_flags)

            # Instead of a blocks.NodeBlock, the sent-edge aggregator
            # (summing per node) and the node model are held separately.
            self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
                reducer=tf.math.unsorted_segment_sum)
            self._node_model = node_model_fn()
Example #3
0
 def __init__(
         self,
         node_model_fn,
         received_edges_reducer=tf.math.unsorted_segment_sum,
         sent_edges_reducer=tf.math.unsorted_segment_sum,
         name='dist_node_block'):
     """Sets up the edge aggregators and the node model of the block.

     Args:
         node_model_fn: Called once with no arguments; the returned object
             is stored as the node model.
         received_edges_reducer: Reducer handed to
             `blocks.ReceivedEdgesToNodesAggregator`.
         sent_edges_reducer: Reducer handed to
             `blocks.SentEdgesToNodesAggregator`.
         name: Module name passed on to the parent constructor.
     """
     super(NodeBlock, self).__init__(name=name)
     with self._enter_variable_scope():
         received_aggregator = blocks.ReceivedEdgesToNodesAggregator(
             received_edges_reducer)
         sent_aggregator = blocks.SentEdgesToNodesAggregator(
             sent_edges_reducer)
         self._received_edges_aggregator = received_aggregator
         self._sent_edges_aggregator = sent_aggregator
         self._node_model = node_model_fn()
Example #4
0
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Checks NodeBlock output nodes against a hand-built expectation.

        The expectation is the concatenation (last axis) of every enabled
        input, multiplied by `self._scale`. Edges and globals are asserted to
        be passed through untouched.
        """
        graph_in = self._get_input_graph()
        graph_out = blocks.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)(graph_in)

        # (enabled?, aggregator class, reducer), in concatenation order.
        edge_aggregations = (
            (use_received_edges, blocks.ReceivedEdgesToNodesAggregator,
             received_edges_reducer),
            (use_sent_edges, blocks.SentEdgesToNodesAggregator,
             sent_edges_reducer),
        )
        pieces = []
        for enabled, aggregator_cls, reducer in edge_aggregations:
            if enabled:
                pieces.append(aggregator_cls(reducer)(graph_in))
        if use_nodes:
            pieces.append(graph_in.nodes)
        if use_globals:
            pieces.append(blocks.broadcast_globals_to_nodes(graph_in))
        expected_inputs = tf.concat(pieces, axis=-1)

        self.assertEqual(graph_in.edges, graph_out.edges)
        self.assertEqual(graph_in.globals, graph_out.globals)

        with self.test_session() as sess:
            graph_out_np, inputs_np = sess.run((graph_out, expected_inputs))

        self.assertNDArrayNear(
            inputs_np * self._scale, graph_out_np.nodes, err=1e-4)
Example #5
0
class FieldAggregatorsTest(GraphModuleTest):
  """Tests for the edge/node aggregator modules in `blocks`."""

  # NOTE: each parameterized case builds its own aggregator instance; the
  # two value tests below intentionally do not share instances.
  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    """Checks that each aggregator yields the expected summed values."""
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    """Checks the aggregators on inputs reshaped to a higher rank.

    The graph's tensors are reshaped through `input_graph.map` to
    `[leading_dim, 2, -1]`, and the expected values are reshaped the same way
    before comparison.
    """
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      # Label corrected: this case removes "edges", not "nodes".
      ("edges to globals missing edges",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregators fail if a required field is missing.

    The raised `ValueError` message is expected to mention `none_field`.
    """
    input_graph = self._get_input_graph([none_field])
    # NOTE(review): `assertRaisesRegexp` is a deprecated alias of
    # `assertRaisesRegex` (removed in Python 3.12); kept here so the test
    # still runs on older interpreters - switch once those are dropped.
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregators can be built when unused fields are None.

    The test passes as long as connecting the aggregator raises no exception.
    """
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
Example #6
0
# Trainable variables of `graph_network` (defined outside this excerpt).
tvars = graph_network.trainable_variables
print('')

###############
# broadcast
# Each statement below calls one `blocks.broadcast_*` helper on
# `graphs_tuple` and stores the result into the named field through
# `graphs_tuple.replace(...)`, binding the outcome to its own variable.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
# Same pattern with the aggregator modules: each one is constructed with the
# shared `reducer`, applied to `graphs_tuple`, and its output is stored into
# the `globals` or `nodes` field of a separate result.

# `graphs_tuple` is rebuilt from the same single data dict used above.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

reducer = tf.math.unsorted_segment_sum  # reducer shared by all four aggregators
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))