def test_output_values(
    self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """Checks EdgeBlock output against a hand-assembled model input.

    The reference input is built by concatenating, along the last axis and in
    a fixed order, whichever of these the block is configured to use: edge
    features, receiver-node features broadcast to edges, sender-node features
    broadcast to edges, and globals broadcast to edges. The block's output
    edges must equal that concatenation multiplied by ``self._scale``
    (within 1e-4). Node and global fields must pass through unchanged.
    """
    graph_in = self._get_input_graph()
    block_under_test = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    graph_out = block_under_test(graph_in)

    # (enabled?, producer) pairs. Producers are thunks, so a broadcast op is
    # only created when its flag is set. The order presumably mirrors
    # EdgeBlock's own concatenation order -- the element-wise comparison below
    # only passes if it does.
    candidate_parts = (
        (use_edges, lambda: graph_in.edges),
        (use_receiver_nodes,
         lambda: blocks.broadcast_receiver_nodes_to_edges(graph_in)),
        (use_sender_nodes,
         lambda: blocks.broadcast_sender_nodes_to_edges(graph_in)),
        (use_globals, lambda: blocks.broadcast_globals_to_edges(graph_in)),
    )
    concatenated_inputs = tf.concat(
        [produce() for enabled, produce in candidate_parts if enabled],
        axis=-1)

    # EdgeBlock only rewrites edges; the other fields are expected to be the
    # very same objects.
    self.assertEqual(graph_in.nodes, graph_out.nodes)
    self.assertEqual(graph_in.globals, graph_out.globals)

    with self.test_session() as session:
        evaluated_graph, evaluated_inputs = session.run(
            (graph_out, concatenated_inputs))
    # Assumes self._edge_model_fn multiplies its input by self._scale
    # (defined outside this view) -- TODO confirm.
    self.assertNDArrayNear(
        evaluated_inputs * self._scale, evaluated_graph.edges, err=1e-4)
def test_unused_field_can_be_none(
    self, use_edges, use_nodes, use_globals, none_field):
    """Verifies EdgeBlock tolerates a graph field it does not consume being None.

    ``none_field`` names the input-graph field that is left as None. Per the
    test's intent, it is presumably always a field the given flags do not
    read (the parametrization lives outside this view). A single
    ``use_nodes`` flag drives both receiver- and sender-node inputs. The
    output edges must equal the concatenated model inputs scaled by
    ``self._scale`` (within 1e-4).
    """
    graph_in = self._get_input_graph([none_field])
    block_under_test = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    graph_out = block_under_test(graph_in)

    parts = []
    if use_edges:
        parts.append(graph_in.edges)
    if use_nodes:
        # Receiver first, then sender -- the same order as the block inputs.
        parts.extend([
            blocks.broadcast_receiver_nodes_to_edges(graph_in),
            blocks.broadcast_sender_nodes_to_edges(graph_in),
        ])
    if use_globals:
        parts.append(blocks.broadcast_globals_to_edges(graph_in))
    concatenated_inputs = tf.concat(parts, axis=-1)

    # Non-edge fields (possibly None) must be passed through untouched.
    self.assertEqual(graph_in.nodes, graph_out.nodes)
    self.assertEqual(graph_in.globals, graph_out.globals)

    with self.test_session() as session:
        edges_value, inputs_value = session.run(
            (graph_out.edges, concatenated_inputs))
    self.assertNDArrayNear(
        inputs_value * self._scale, edges_value, err=1e-4)
# Apply `graph_network` `num_recurrent_passes` times, feeding each pass's
# output back in as the next pass's input.
for unused_pass in range(num_recurrent_passes):
    previous_graphs = graph_network(previous_graphs)
    # NOTE(review): the original line breaks/indentation were lost. This debug
    # print is assumed to sit inside the loop (one print per pass) -- confirm.
    print(previous_graphs.nodes[0])
# Result after the final pass.
output_graphs = previous_graphs
# Trainable variables of the network (a property read off `graph_network`).
tvars = graph_network.trainable_variables
print('')

# ---------------------------------------------------------------------------
# Broadcast demo: build a GraphsTuple from `data_dict_0`, then make copies in
# which a single field is replaced by the output of one `blocks.broadcast_*`
# helper.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
# Globals broadcast onto the nodes field.
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
    nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
# Globals broadcast onto the edges field.
updated_broadcast_globals_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_globals_to_edges(graphs_tuple))
# Sender-node features broadcast onto the edges field.
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
# Receiver-node features broadcast onto the edges field.
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
    edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))

# ---------------------------------------------------------------------------
# Aggregate demo: rebuild the GraphsTuple from `data_dict_0` and replace its
# globals field with edge- and node-level features aggregated using a
# segment-sum reducer.
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.math.unsorted_segment_sum
# Edges -> globals aggregation.
updated_edges_to_globals = graphs_tuple.replace(
    globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
# Nodes -> globals aggregation.
updated_nodes_to_globals = graphs_tuple.replace(
    globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))