コード例 #1
0
  def test_output_values(
      self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Checks EdgeBlock edge outputs against a hand-built expectation.

    For every enabled input flag, the matching per-edge feature tensor is
    collected. Their concatenation along the last axis, multiplied by
    `self._scale`, is what the block's output edges are compared to. The
    output graph's nodes and globals are compared to the input's with
    `assertEqual`.
    """
    graph = self._get_input_graph()
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    result = block(graph)

    # (flag, thunk) pairs, listed in the order the test assumes EdgeBlock
    # concatenates its inputs. The thunks ensure that only enabled sources
    # are ever built.
    sources = (
        (use_edges, lambda: graph.edges),
        (use_receiver_nodes,
         lambda: blocks.broadcast_receiver_nodes_to_edges(graph)),
        (use_sender_nodes,
         lambda: blocks.broadcast_sender_nodes_to_edges(graph)),
        (use_globals, lambda: blocks.broadcast_globals_to_edges(graph)),
    )
    expected_inputs = tf.concat(
        [make() for enabled, make in sources if enabled], axis=-1)

    # An edge block is not expected to alter nodes or globals.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as sess:
      result_np, inputs_np = sess.run((result, expected_inputs))

    self.assertNDArrayNear(inputs_np * self._scale, result_np.edges, err=1e-4)
コード例 #2
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """Verifies EdgeBlock still works when a field it does not read is None.

    `none_field` is passed (wrapped in a list) to `self._get_input_graph`.
    Presumably it names the graph field that helper leaves as None (TODO
    confirm against the helper). `use_nodes` toggles the receiver-node and
    sender-node inputs together.
    """
    graph = self._get_input_graph([none_field])
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    result = block(graph)

    # Gather only the enabled inputs, in the order the test expects them to
    # be concatenated: edges, receiver nodes, sender nodes, globals.
    features = [graph.edges] if use_edges else []
    if use_nodes:
      features.extend((blocks.broadcast_receiver_nodes_to_edges(graph),
                       blocks.broadcast_sender_nodes_to_edges(graph)))
    if use_globals:
      features.append(blocks.broadcast_globals_to_edges(graph))
    expected_inputs = tf.concat(features, axis=-1)

    # Nodes and globals should come out of the edge block unchanged.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as sess:
      edges_np, inputs_np = sess.run((result.edges, expected_inputs))

    self.assertNDArrayNear(inputs_np * self._scale, edges_np, err=1e-4)
コード例 #3
0
# Apply `graph_network` `num_recurrent_passes` times, feeding each pass's
# output back in as the next pass's input. After every pass, print the first
# entry of the `nodes` field so its evolution can be inspected.
for unused_pass in range(num_recurrent_passes):
    previous_graphs = graph_network(previous_graphs)
    print(previous_graphs.nodes[0])
# Graphs after the last pass. If num_recurrent_passes is 0, the loop body
# never runs and this is just the initial `previous_graphs`.
output_graphs = previous_graphs

# Trainable variables of the network. They are not used in this excerpt;
# presumably they are consumed further down (verify).
tvars = graph_network.trainable_variables
print('')  # Emits an empty line, apparently as a visual separator.

###############
# broadcast
#
# Convert `[data_dict_0]` into a graphs tuple with utils_tf, then build one
# variant per `blocks.broadcast_*` helper. Each variant has a single field
# (nodes or edges) replaced by that helper's output. Judging by the helper
# names, each one spreads globals or sender/receiver node features onto
# nodes/edges. Confirm against the graph_nets `blocks` documentation.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

############
# aggregate
#
# NOTE(review): this repeats the exact conversion from the broadcast section.
# Above, `graphs_tuple` is only passed to helpers and `.replace` and is never
# rebound. The rebuild therefore looks redundant unless the conversion or
# those helpers have side effects. Confirm before removing.

graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])

# Reducer handed to the aggregators below. tf.math.unsorted_segment_sum sums
# all values that share a segment id.
reducer = tf.math.unsorted_segment_sum
# Replace `globals` with the aggregator's output over edges (resp. nodes).
# Presumably this is a per-graph sum given the reducer above; verify.
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))