コード例 #1
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Verifies a NodeBlock tolerates an input whose unused field is None.

    The field named by `none_field` is set to None on the input graph
    (presumably the parameterization only pairs it with flags that never read
    it). The node outputs are compared against the concatenation of the
    enabled model inputs multiplied by `self._scale`, and edges/globals must be
    forwarded unchanged.
    """
    graph = self._get_input_graph([none_field])
    block = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    result = block(graph)

    # Assemble the reference model inputs explicitly, one enabled piece at a
    # time: received-edge sum, sent-edge sum, nodes, broadcast globals.
    pieces = []
    if use_edges:
      for aggregator_cls in (blocks.ReceivedEdgesToNodesAggregator,
                             blocks.SentEdgesToNodesAggregator):
        pieces.append(aggregator_cls(tf.unsorted_segment_sum)(graph))
    if use_nodes:
      pieces.append(graph.nodes)
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_nodes(graph))
    expected_inputs = tf.concat(pieces, axis=-1)

    # A NodeBlock must not touch edges or globals.
    self.assertEqual(graph.edges, result.edges)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as sess:
      nodes_value, inputs_value = sess.run((result.nodes, expected_inputs))

    self.assertNDArrayNear(inputs_value * self._scale, nodes_value, err=1e-4)
コード例 #2
0
ファイル: GNN.py プロジェクト: YannickCharles/RQ1
    def __init__(self, edge_model_fn, node_model_fn, name="fwd_gnn"):
        """Builds the sub-modules of the forward transition network.

        Args:
          edge_model_fn: Passed unchanged to `blocks.EdgeBlock` as its
            `edge_model_fn`.
          node_model_fn: Callable taking no arguments; it is called once here
            and its result is stored as the node model.
          name: Name of this module.
        """
        super(GraphNeuralNetwork_transition, self).__init__(name=name)

        with self._enter_variable_scope():
            # Edge update sees only the sender and receiver node features:
            # previous edge features and globals are not fed to the model.
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn,
                use_edges=False,
                use_receiver_nodes=True,
                use_sender_nodes=True,
                use_globals=False,
                name='edge_block')

            # This constructor creates a sent-edges aggregator (sum reducer)
            # and the node model, rather than a full `blocks.NodeBlock`.
            summed = tf.math.unsorted_segment_sum
            self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
                reducer=summed)
            self._node_model = node_model_fn()
コード例 #3
0
 def __init__(self,
              node_model_fn,
              received_edges_reducer=tf.math.unsorted_segment_sum,
              sent_edges_reducer=tf.math.unsorted_segment_sum,
              name='dist_node_block'):
     """Creates the two edge-to-node aggregators and the node model.

     Args:
       node_model_fn: Callable taking no arguments; called once here to build
         the node model.
       received_edges_reducer: Reducer handed to
         `blocks.ReceivedEdgesToNodesAggregator`.
       sent_edges_reducer: Reducer handed to
         `blocks.SentEdgesToNodesAggregator`.
       name: Name of this module.
     """
     # NOTE(review): `super(NodeBlock, self)` requires `self` to be an
     # instance of `NodeBlock`; confirm this matches the enclosing class name.
     super(NodeBlock, self).__init__(name=name)
     with self._enter_variable_scope():
         received = blocks.ReceivedEdgesToNodesAggregator(
             reducer=received_edges_reducer)
         sent = blocks.SentEdgesToNodesAggregator(reducer=sent_edges_reducer)
         self._received_edges_aggregator = received
         self._sent_edges_aggregator = sent
         self._node_model = node_model_fn()
コード例 #4
0
ファイル: blocks_test.py プロジェクト: Tubbz-alt/Graph_Nets-2
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Checks NodeBlock output against a hand-assembled expectation.

        Expected node features are the enabled model inputs (aggregated
        received edges, aggregated sent edges, current nodes, broadcast
        globals; each included only when its flag is set), concatenated on the
        last axis and multiplied by `self._scale`. Edges and globals must be
        forwarded unchanged.
        """
        graph = self._get_input_graph()
        block = blocks.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
        result = block(graph)

        pieces = []
        if use_received_edges:
            received = blocks.ReceivedEdgesToNodesAggregator(
                received_edges_reducer)
            pieces.append(received(graph))
        if use_sent_edges:
            sent = blocks.SentEdgesToNodesAggregator(sent_edges_reducer)
            pieces.append(sent(graph))
        if use_nodes:
            pieces.append(graph.nodes)
        if use_globals:
            pieces.append(blocks.broadcast_globals_to_nodes(graph))
        expected_inputs = tf.concat(pieces, axis=-1)

        # A NodeBlock must leave edges and globals untouched.
        self.assertEqual(graph.edges, result.edges)
        self.assertEqual(graph.globals, result.globals)

        with self.test_session() as sess:
            result_value, inputs_value = sess.run((result, expected_inputs))

        self.assertNDArrayNear(inputs_value * self._scale,
                               result_value.nodes,
                               err=1e-4)
コード例 #5
0
class FieldAggregatorsTest(GraphModuleTest):
  """Tests for the edge/node -> node/global field aggregators in `blocks`."""

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    """Checks each sum aggregator's output against precomputed values."""
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  # NOTE(review): this parameter list duplicates the one above but builds its
  # own aggregator instances; presumably kept separate so each test connects
  # distinct module objects. Confirm before merging the two lists.
  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    """Like `test_output_values`, but with rank-3 feature tensors.

    Every field visited by `input_graph.map` is reshaped from
    [leading_dim, ...] to [leading_dim, 2, -1] before aggregation, and the
    expected values are reshaped the same way.
    """
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      # The field set to None here is "edges", so the case is named after it.
      ("edges to globals missing edges",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregators raise ValueError if a required field is None.

    The error message is expected to mention the name of the missing field.
    """
    input_graph = self._get_input_graph([none_field])
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregators succeed when only fields they do not use are None.

    There is no explicit assertion: the test passes as long as building the
    aggregator output does not raise.
    """
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
コード例 #6
0
tvars = graph_network.trainable_variables
print('')  # blank line as a visual separator in the console output

###############
# broadcast
# Each `replace` call returns a new GraphsTuple in which one field is swapped
# for the broadcast result; `graphs_tuple` itself is left as it was.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
# `replace` does not modify `graphs_tuple`, so the tuple built above is reused
# here rather than converting `data_dict_0` a second time.

reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))