示例#1
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Verifies a NodeBlock copes with a None field it does not consume.

    `none_field` is set to None on the input graph and the block is
    configured from the `use_*` flags. The updated nodes are compared with
    the enabled inputs concatenated along the last axis and multiplied by
    `self._scale`.
    """
    graph = self._get_input_graph([none_field])
    block = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    updated = block(graph)

    # Rebuild the expected model input by hand, in the order:
    # received edges, sent edges, nodes, globals.
    pieces = []
    if use_edges:
      for aggregator_cls in (blocks.ReceivedEdgesToNodesAggregator,
                             blocks.SentEdgesToNodesAggregator):
        pieces.append(aggregator_cls(tf.unsorted_segment_sum)(graph))
    if use_nodes:
      pieces.append(graph.nodes)
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_nodes(graph))
    expected_inputs = tf.concat(pieces, axis=-1)

    # Edges and globals are expected to come out of the block unchanged.
    self.assertEqual(graph.edges, updated.edges)
    self.assertEqual(graph.globals, updated.globals)

    with self.test_session() as sess:
      nodes_value, inputs_value = sess.run((updated.nodes, expected_inputs))

    self.assertNDArrayNear(inputs_value * self._scale, nodes_value, err=1e-4)
示例#2
0
文件: GNN.py 项目: YannickCharles/RQ1
    def __init__(self, edge_model_fn, node_model_fn, name="fwd_gnn"):
        """Creates the sub-modules of the transition graph network.

        Args:
          edge_model_fn: callable forwarded to `blocks.EdgeBlock` as its
            `edge_model_fn`.
          node_model_fn: zero-argument callable; its return value is kept
            as the node model.
          name: name of this module.
        """
        super(GraphNeuralNetwork_transition, self).__init__(name=name)

        with self._enter_variable_scope():
            # Edges are updated from their sender and receiver node features
            # only: previous edge features and globals are not used.
            edge_block_kwargs = dict(
                use_edges=False,
                use_receiver_nodes=True,
                use_sender_nodes=True,
                use_globals=False,
                name='edge_block')
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn, **edge_block_kwargs)

            # Sum-reduces edge features onto their sender nodes.
            self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
                reducer=tf.math.unsorted_segment_sum)
            self._node_model = node_model_fn()
示例#3
0
 def __init__(self,
              node_model_fn,
              received_edges_reducer=tf.math.unsorted_segment_sum,
              sent_edges_reducer=tf.math.unsorted_segment_sum,
              name='dist_node_block'):
     """Creates the two edge-to-node aggregators and the node model.

     Args:
       node_model_fn: zero-argument callable whose result becomes the node
         model.
       received_edges_reducer: reducer handed to the received-edges
         aggregator.
       sent_edges_reducer: reducer handed to the sent-edges aggregator.
       name: module name.
     """
     # NOTE(review): super() names `NodeBlock` explicitly; confirm it matches
     # the enclosing class, otherwise an __init__ in the MRO may be skipped.
     super(NodeBlock, self).__init__(name=name)
     with self._enter_variable_scope():
         self._received_edges_aggregator = (
             blocks.ReceivedEdgesToNodesAggregator(
                 reducer=received_edges_reducer))
         self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
             reducer=sent_edges_reducer)
         self._node_model = node_model_fn()
示例#4
0
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Checks NodeBlock output nodes against a hand-built computation.

        The expected node-model input is the last-axis concatenation of the
        enabled features, in the order received edges, sent edges, nodes,
        globals. The expected output is that input times `self._scale`.
        """
        graph = self._get_input_graph()
        block = blocks.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
        result = block(graph)

        pieces = []
        if use_received_edges:
            received = blocks.ReceivedEdgesToNodesAggregator(
                received_edges_reducer)
            pieces.append(received(graph))
        if use_sent_edges:
            sent = blocks.SentEdgesToNodesAggregator(sent_edges_reducer)
            pieces.append(sent(graph))
        if use_nodes:
            pieces.append(graph.nodes)
        if use_globals:
            pieces.append(blocks.broadcast_globals_to_nodes(graph))
        expected_inputs = tf.concat(pieces, axis=-1)

        # Only nodes are updated; edges and globals must be left as they were.
        self.assertEqual(graph.edges, result.edges)
        self.assertEqual(graph.globals, result.globals)

        with self.test_session() as sess:
            result_value, inputs_value = sess.run((result, expected_inputs))

        self.assertNDArrayNear(inputs_value * self._scale,
                               result_value.nodes,
                               err=1e-4)
示例#5
0
class FieldAggregatorsTest(GraphModuleTest):
  """Tests for the edge/node field aggregators defined in `blocks`."""

  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values(self, aggregator, expected):
    """Compares each aggregator's output with precomputed segment sums."""
    input_graph = self._get_input_graph()
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.array(expected, dtype=np.float32), aggregated_out, err=1e-4)

  # Separate aggregator instances are built here on purpose: each test gets
  # its own modules rather than sharing the ones from `test_output_values`.
  @parameterized.named_parameters(
      ("edges_to_globals",
       blocks.EdgesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_EDGES_TO_GLOBALS,),
      ("nodes_to_globals",
       blocks.NodesToGlobalsAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_NODES_TO_GLOBALS,),
      ("sent_edges_to_nodes",
       blocks.SentEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_SENT_EDGES_TO_NODES,),
      ("received_edges_to_nodes",
       blocks.ReceivedEdgesToNodesAggregator(tf.unsorted_segment_sum),
       SEGMENT_SUM_RECEIVED_EDGES_TO_NODES),
  )
  def test_output_values_larger_rank(self, aggregator, expected):
    """Same as `test_output_values`, but on rank-3 fields.

    The fields processed by `input_graph.map` are reshaped to
    [leading_dim, 2, -1] before aggregation. The expected values are
    reshaped the same way.
    """
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(
        lambda v: tf.reshape(v, [v.get_shape().as_list()[0]] + [2, -1]))
    aggregated = aggregator(input_graph)
    with self.test_session() as sess:
      aggregated_out = sess.run(aggregated)
    self.assertNDArrayNear(
        np.reshape(np.array(expected, dtype=np.float32),
                   [len(expected)] + [2, -1]),
        aggregated_out,
        err=1e-4)

  @parameterized.named_parameters(
      ("received edges to nodes missing edges",
       blocks.ReceivedEdgesToNodesAggregator, "edges"),
      ("sent edges to nodes missing edges",
       blocks.SentEdgesToNodesAggregator, "edges"),
      ("nodes to globals missing nodes",
       blocks.NodesToGlobalsAggregator, "nodes"),
      # The case name previously said "missing nodes" although it nulls
      # the "edges" field.
      ("edges to globals missing edges",
       blocks.EdgesToGlobalsAggregator, "edges"),)
  def test_missing_field_raises_exception(self, constructor, none_field):
    """Tests that aggregators fail if a required field is missing.

    The aggregator must raise a ValueError whose message matches the name
    of the missing field.
    """
    input_graph = self._get_input_graph([none_field])
    with self.assertRaisesRegexp(ValueError, none_field):
      constructor(tf.unsorted_segment_sum)(input_graph)

  @parameterized.named_parameters(
      ("received edges to nodes missing nodes and globals",
       blocks.ReceivedEdgesToNodesAggregator, ["nodes", "globals"]),
      ("sent edges to nodes missing nodes and globals",
       blocks.SentEdgesToNodesAggregator, ["nodes", "globals"]),
      ("nodes to globals missing edges and globals",
       blocks.NodesToGlobalsAggregator,
       ["edges", "receivers", "senders", "globals"]),
      ("edges to globals missing globals",
       blocks.EdgesToGlobalsAggregator, ["globals"]),
  )
  def test_unused_field_can_be_none(self, constructor, none_fields):
    """Tests that aggregators run when the fields they do not use are None."""
    input_graph = self._get_input_graph(none_fields)
    constructor(tf.unsorted_segment_sum)(input_graph)
示例#6
0
tvars = graph_network.trainable_variables
print('')

###############
# broadcast

# One GraphsTuple built from `data_dict_0` feeds both sections below.
# `replace` returns a new tuple, so `graphs_tuple` itself is never modified
# and does not need to be rebuilt before the aggregate section.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

# Each `updated_broadcast_*` is a copy of `graphs_tuple` whose nodes or edges
# are replaced by the corresponding broadcast result.
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate

# Every aggregator below uses a segment sum as its reducer. Each
# `updated_*` is a copy of `graphs_tuple` with globals or nodes replaced by
# the aggregated result.
reducer = tf.math.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
    nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))