def test_fill_node_state(self, node_size):
  """Checks the static shape of node features built by set_zero_node_features.

  Each input dict has its "nodes" entry removed, keeping only the node count
  under "n_node". The node features are then zero-filled with width
  `node_size`, and their static shape must be (total node count, node_size).
  """
  for graph_dict in self.graphs_dicts_in:
    # Record the node count before discarding the features it comes from.
    graph_dict["n_node"] = graph_dict["nodes"].shape[0]
    graph_dict.pop("nodes")
  without_nodes = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  total_nodes = np.sum(self.reference_graph.n_node)
  zero_filled = utils_tf.set_zero_node_features(without_nodes, node_size)
  static_shape = zero_filled.nodes.get_shape().as_list()
  self.assertAllEqual((total_nodes, node_size), static_shape)
def test_fill_state_user_specified_types(self, dtype):
  """Tests that the features are created with the user-specified type.

  The "nodes", "globals" and "edges" entries are removed from every input
  dict. The corresponding fields are then zero-filled (size 1) with an
  explicit `dtype`, and each resulting field must carry exactly that dtype.
  """
  for g in self.graphs_dicts_in:
    g.pop("nodes")
    g.pop("globals")
    g.pop("edges")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  # `dtype` is passed positionally as the third argument of each setter.
  graphs_tuple = utils_tf.set_zero_edge_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, 1, dtype)
  graphs_tuple = utils_tf.set_zero_global_features(graphs_tuple, 1, dtype)
  self.assertEqual(dtype, graphs_tuple.edges.dtype)
  self.assertEqual(dtype, graphs_tuple.nodes.dtype)
  self.assertEqual(dtype, graphs_tuple.globals.dtype)
def test_fill_state_default_types(self):
  """Checks the dtype used when no dtype is passed to the set_zero_* helpers.

  All three feature entries are dropped from the input dicts. They are then
  re-created with size 1 and no explicit dtype, and each resulting field
  must be tf.float32.
  """
  for graph_dict in self.graphs_dicts_in:
    for field in ("nodes", "globals", "edges"):
      graph_dict.pop(field)
  filled = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  filled = utils_tf.set_zero_edge_features(filled, edge_size=1)
  filled = utils_tf.set_zero_node_features(filled, node_size=1)
  filled = utils_tf.set_zero_global_features(filled, global_size=1)
  # Checked in the same order as the fields were filled: edges, nodes, globals.
  for feature in (filled.edges, filled.nodes, filled.globals):
    self.assertEqual(tf.float32, feature.dtype)
def test_fill_node_state_dynamic(self, node_size):
  """Checks zero-filled node values after `n_node` is rebuilt via tf.constant.

  The "nodes" entry is dropped and "n_node" is kept. The graph's `n_node`
  field is then replaced by a tf.constant copy with the same shape. The node
  features from set_zero_node_features must be all zeros of shape
  (total node count, node_size).
  """
  for graph_dict in self.graphs_dicts_in:
    # Record the node count before discarding the features it comes from.
    graph_dict["n_node"] = graph_dict["nodes"].shape[0]
    graph_dict.pop("nodes")
  graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
  counts = graphs_tuple.n_node
  # NOTE(review): the shape given here is the tensor's existing static shape,
  # so the node count does not appear to become dynamic — confirm intent.
  graphs_tuple = graphs_tuple._replace(
      n_node=tf.constant(counts, shape=counts.get_shape()))
  total_nodes = np.sum(self.reference_graph.n_node)
  graphs_tuple = utils_tf.set_zero_node_features(graphs_tuple, node_size)
  expected = np.zeros((total_nodes, node_size))
  self.assertNDArrayNear(expected, graphs_tuple.nodes.numpy(), err=1e-4)
def zeros_graph(sample_graph, edge_size, node_size, global_size):
  """Builds a graph shaped like `sample_graph` whose features are all zeros.

  The "nodes", "edges" and "globals" fields are first cleared through
  `sample_graph.replace(...)`. They are then re-created by the
  utils_tf.set_zero_*_features helpers, in the order edges, nodes, globals.

  Args:
    sample_graph: graph whose other fields are reused; must provide a
      `replace` method accepting `nodes`, `edges` and `globals` keywords.
    edge_size: feature size for the zero edge features.
    node_size: feature size for the zero node features.
    global_size: feature size for the zero global features.

  Returns:
    The graph returned by set_zero_global_features after all three setters.
  """
  stripped = sample_graph.replace(nodes=None, edges=None, globals=None)
  with_edges = utils_tf.set_zero_edge_features(stripped, edge_size)
  with_nodes = utils_tf.set_zero_node_features(with_edges, node_size)
  return utils_tf.set_zero_global_features(with_nodes, global_size)