def test_dynamic_batch_sizes(self, block_constructor):
  """Checks that all batch sizes are as expected through a GraphNetwork."""
  # No placeholders are needed here; they are unnecessary in TF2.
  input_graph = utils_np.data_dicts_to_graphs_tuple(
      [SMALL_GRAPH_1, SMALL_GRAPH_2])
  input_graph = input_graph.map(tf.constant, fields=graphs.ALL_FIELDS)
  model = block_constructor(
      functools.partial(snt.nets.MLP, output_sizes=[10]))
  output = model(input_graph)
  actual = utils_tf.nest_to_numpy(output)
  for k, v in input_graph._asdict().items():
    self.assertEqual(v.shape[0], getattr(actual, k).shape[0])
def test_getitem_one(self, use_tensor_index):
  index = 2
  expected = self.graphs_dicts_out[index]

  if use_tensor_index:
    index = tf.constant(index)

  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graph = utils_tf.get_graph(graphs_tuple, index)
  graph = utils_tf.nest_to_numpy(graph)
  actual, = utils_np.graphs_tuple_to_data_dicts(graph)

  for k, v in expected.items():
    self.assertAllClose(v, actual[k])
  self.assertEqual(expected["nodes"].shape[0], actual["n_node"])
  self.assertEqual(expected["edges"].shape[0], actual["n_edge"])
def test_getitem(self, use_tensor_slice):
  index = slice(1, 3)
  expected = self.graphs_dicts_out[index]

  if use_tensor_slice:
    index = slice(tf.constant(index.start), tf.constant(index.stop))

  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  graphs2 = utils_tf.get_graph(graphs_tuple, index)
  graphs2 = utils_tf.nest_to_numpy(graphs2)
  actual = utils_np.graphs_tuple_to_data_dicts(graphs2)

  for ex, ac in zip(expected, actual):
    for k, v in ex.items():
      self.assertAllClose(v, ac[k])
    self.assertEqual(ex["nodes"].shape[0], ac["n_node"])
    self.assertEqual(ex["edges"].shape[0], ac["n_edge"])
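# A minimal sketch (not part of the test suite; the helper name is
# hypothetical) of the two indexing modes of `utils_tf.get_graph` exercised
# by the tests above: a Python int selects a single graph, and a `slice`
# (whose endpoints may be `tf.Tensor`s) selects a range of graphs. It relies
# on the same `utils_tf`/`tf` imports as the tests.
def _get_graph_indexing_sketch(graphs_tuple):
  # Single graph at position 0.
  first = utils_tf.get_graph(graphs_tuple, 0)
  # Graphs 0 and 1, selected with tensor-valued slice endpoints.
  first_two = utils_tf.get_graph(
      graphs_tuple, slice(tf.constant(0), tf.constant(2)))
  return first, first_two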
def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                       use_globals, received_edges_reducer,
                       sent_edges_reducer):
  """Compares the output of a NodeBlock to an explicit computation."""
  input_graph = self._get_input_graph()
  node_block = blocks.NodeBlock(
      node_model_fn=self._node_model_fn,
      use_received_edges=use_received_edges,
      use_sent_edges=use_sent_edges,
      use_nodes=use_nodes,
      use_globals=use_globals,
      received_edges_reducer=received_edges_reducer,
      sent_edges_reducer=sent_edges_reducer)
  output_graph = node_block(input_graph)

  model_inputs = []
  if use_received_edges:
    model_inputs.append(
        blocks.ReceivedEdgesToNodesAggregator(received_edges_reducer)(
            input_graph))
  if use_sent_edges:
    model_inputs.append(
        blocks.SentEdgesToNodesAggregator(sent_edges_reducer)(input_graph))
  if use_nodes:
    model_inputs.append(input_graph.nodes)
  if use_globals:
    model_inputs.append(blocks.broadcast_globals_to_nodes(input_graph))

  model_inputs = tf.concat(model_inputs, axis=-1)
  self.assertIs(input_graph.edges, output_graph.edges)
  self.assertIs(input_graph.globals, output_graph.globals)

  output_graph_out = utils_tf.nest_to_numpy(output_graph)
  model_inputs_out = model_inputs

  expected_output_nodes = model_inputs_out * self._scale
  self.assertNDArrayNear(
      expected_output_nodes, output_graph_out.nodes, err=1e-4)
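# A minimal sketch (not part of the test suite; the helper name is
# hypothetical) of the explicit computation that `test_output_values`
# verifies: a NodeBlock concatenates its selected inputs along the last axis
# before calling the node model. `tf.math.unsorted_segment_sum` is used here
# purely as an example reducer; the test is parameterized over reducers.
def _node_block_inputs_sketch(input_graph):
  model_inputs = [
      # Per-node aggregation of incoming edge features.
      blocks.ReceivedEdgesToNodesAggregator(tf.math.unsorted_segment_sum)(
          input_graph),
      # Per-node aggregation of outgoing edge features.
      blocks.SentEdgesToNodesAggregator(tf.math.unsorted_segment_sum)(
          input_graph),
      # The nodes' own features.
      input_graph.nodes,
      # The graph-level features, broadcast to every node.
      blocks.broadcast_globals_to_nodes(input_graph),
  ]
  return tf.concat(model_inputs, axis=-1)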
# nd = len(graph.nodes())
# probs = np.zeros((nd, nd))
# for edge in graph.edges(data=True):
#   probs[edge[0], edge[1]] = edge[2]["features"][0]
# ax.matshow(probs[sort_indices][:, sort_indices], cmap="viridis")
# ax.grid(False)

num_elements_min_max = (5, 10)
inputs, targets, sort_indices, ranks = create_data(1, num_elements_min_max)

inputs_nodes = inputs.nodes.numpy()  # E.g., shape [7, 1].
targets = utils_tf.nest_to_numpy(targets)
sort_indices_nodes = sort_indices.nodes.numpy()
ranks_nodes = ranks.nodes.numpy()
sort_indices = np.squeeze(sort_indices_nodes).astype(int)

# # Plot sort linked lists.
# # The matrix plots show each element from the sorted list (rows), and which
# # element they link to as next largest (columns). Ground truth is a diagonal
# # offset toward the upper-right by one.
# fig = plt.figure(1, figsize=(4, 4))
# fig.clf()
# ax = fig.add_subplot(1, 1, 1)
# plot_linked_list(ax,
#                  utils_np.graphs_tuple_to_networkxs(targets)[0], sort_indices)
# ax.set_title("Element-to-element links for sorted elements")
def test_add_remove_padding(
    self, experimental_unconnected_padding_edges, nested_features):
  data_dict = test_utils_tf1.generate_random_data_dict(
      (7,), (8,), (9,),
      num_nodes_range=(10, 15),
      num_edges_range=(20, 25))
  node_size_np = data_dict["nodes"].shape[0]
  edge_size_np = data_dict["edges"].shape[0]
  unpadded_batch_size = 2
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
      unpadded_batch_size * [data_dict])

  if nested_features:
    graphs_tuple = graphs_tuple.replace(
        edges=[graphs_tuple.edges, {}],
        nodes=({"tensor": graphs_tuple.nodes},),
        globals=([], graphs_tuple.globals))

  num_padding_nodes = 3
  num_padding_edges = 4
  num_padding_graphs = 5
  pad_nodes_to = unpadded_batch_size * node_size_np + num_padding_nodes
  pad_edges_to = unpadded_batch_size * edge_size_np + num_padding_edges
  pad_graphs_to = unpadded_batch_size + num_padding_graphs

  def _get_padded_and_recovered_graphs_tuple(graphs_tuple):
    padded_graphs_tuple = utils_tf.pad_graphs_tuple(
        graphs_tuple, pad_nodes_to, pad_edges_to, pad_graphs_to,
        experimental_unconnected_padding_edges)

    # Check that we have statically defined shapes after padding.
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.nodes), pad_nodes_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.edges), pad_edges_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.senders), pad_edges_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.receivers), pad_edges_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.globals), pad_graphs_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.n_node), pad_graphs_to)
    self.assertEqual(
        _leading_static_shape(padded_graphs_tuple.n_edge), pad_graphs_to)

    # Check that we can remove the padding.
    graphs_tuple_size = utils_tf.get_graphs_tuple_size(graphs_tuple)
    recovered_graphs_tuple = utils_tf.remove_graphs_tuple_padding(
        padded_graphs_tuple, graphs_tuple_size)
    return padded_graphs_tuple, recovered_graphs_tuple

  # Put it into a tf.function so the shapes are unknown statically.
  compiled_fn = _compile_with_tf_function(
      _get_padded_and_recovered_graphs_tuple, graphs_tuple)
  padded_graphs_tuple, recovered_graphs_tuple = compiled_fn(graphs_tuple)

  if nested_features:
    # Check that the whole structure of the outputs is the same.
    tree.assert_same_structure(padded_graphs_tuple, graphs_tuple)
    tree.assert_same_structure(recovered_graphs_tuple, graphs_tuple)

    # Undo the nesting for the rest of the test.
    def remove_nesting(this_graphs_tuple):
      return this_graphs_tuple.replace(
          edges=this_graphs_tuple.edges[0],
          nodes=this_graphs_tuple.nodes[0]["tensor"],
          globals=this_graphs_tuple.globals[1])

    graphs_tuple = remove_nesting(graphs_tuple)
    padded_graphs_tuple = remove_nesting(padded_graphs_tuple)
    recovered_graphs_tuple = remove_nesting(recovered_graphs_tuple)

  # Inspect the padded_graphs_tuple.
  padded_graphs_tuple_data_dicts = utils_np.graphs_tuple_to_data_dicts(
      utils_tf.nest_to_numpy(padded_graphs_tuple))
  graphs_tuple_data_dicts = utils_np.graphs_tuple_to_data_dicts(
      utils_tf.nest_to_numpy(graphs_tuple))

  self.assertLen(padded_graphs_tuple_data_dicts, pad_graphs_to)

  # Check that the first 2 graphs from the padded_graphs_tuple are the same
  # as the originals.
  for example_i in range(unpadded_batch_size):
    tree.map_structure(
        self.assertAllEqual,
        graphs_tuple_data_dicts[example_i],
        padded_graphs_tuple_data_dicts[example_i])

  padding_data_dicts = padded_graphs_tuple_data_dicts[unpadded_batch_size:]

  # Check that the third graph (the first padding graph) contains all of the
  # padding nodes and edges.
  for i, padding_data_dict in enumerate(padding_data_dicts):
    # Only the first padding graph has nodes and edges.
    num_nodes = num_padding_nodes if i == 0 else 0
    num_edges = num_padding_edges if i == 0 else 0
    self.assertAllEqual(
        padding_data_dict["globals"], np.zeros([9], dtype=np.float32))
    self.assertEqual(padding_data_dict["n_node"], num_nodes)
    self.assertAllEqual(
        padding_data_dict["nodes"],
        np.zeros([num_nodes, 7], dtype=np.float32))
    self.assertEqual(padding_data_dict["n_edge"], num_edges)
    self.assertAllEqual(
        padding_data_dict["edges"],
        np.zeros([num_edges, 8], dtype=np.float32))

    if experimental_unconnected_padding_edges:
      # Unconnected padding edges point one index past the padding nodes.
      self.assertAllEqual(
          padding_data_dict["receivers"],
          np.zeros([num_edges], dtype=np.int32) + num_nodes)
      self.assertAllEqual(
          padding_data_dict["senders"],
          np.zeros([num_edges], dtype=np.int32) + num_nodes)
    else:
      self.assertAllEqual(
          padding_data_dict["receivers"],
          np.zeros([num_edges], dtype=np.int32))
      self.assertAllEqual(
          padding_data_dict["senders"],
          np.zeros([num_edges], dtype=np.int32))

  # Check that the recovered_graphs_tuple after removing padding is identical.
  tree.map_structure(
      self.assertAllEqual, graphs_tuple._asdict(),
      recovered_graphs_tuple._asdict())
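# A minimal sketch (not part of the test suite; the helper name and default
# argument are hypothetical) of the pad/unpad round trip exercised above:
# pad a GraphsTuple to fixed node/edge/graph counts so its leading dimensions
# are static inside a tf.function, then strip the padding to recover the
# original batch.
def _pad_unpad_round_trip_sketch(graphs_tuple, pad_nodes_to, pad_edges_to,
                                 pad_graphs_to,
                                 experimental_unconnected_padding_edges=False):
  padded = utils_tf.pad_graphs_tuple(
      graphs_tuple, pad_nodes_to, pad_edges_to, pad_graphs_to,
      experimental_unconnected_padding_edges)
  # Record the original sizes so the padding can be removed later.
  original_size = utils_tf.get_graphs_tuple_size(graphs_tuple)
  return utils_tf.remove_graphs_tuple_padding(padded, original_size)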