def test_perturb_graphs_tuple(self, n_samples, sigma):
    """Perturbation should multiply the graph count by n_samples.

    Node and edge feature widths must be preserved by the noise.
    """
    original, _, _ = self._setup_graphs()
    perturbed = graph_utils.perturb_graphs_tuple(original, n_samples, sigma)

    # The perturbed batch is expected to hold n_samples times as many graphs.
    self.assertEqual(
        n_samples * graph_utils.get_num_graphs(original),
        graph_utils.get_num_graphs(perturbed))

    # Noise must not change the trailing (feature) dimension.
    feature_pairs = ((original.nodes, perturbed.nodes),
                     (original.edges, perturbed.edges))
    for before, after in feature_pairs:
        self.assertEqual(before.shape[-1], after.shape[-1])
def assertAttributionShape(self, graphs, att):
    """Assert that attribution graphs `att` line up structurally with `graphs`.

    The following are compared:
      - the number of graphs;
      - the leading dimension of `nodes`;
      - the leading dimension of `edges`, only when `att.edges` is not None;
      - the `n_node`, `n_edge`, `senders` and `receivers` fields, element-wise.
    """
    self.assertEqual(graph_utils.get_num_graphs(graphs),
                     graph_utils.get_num_graphs(att))
    self.assertEqual(graphs.nodes.shape[0], att.nodes.shape[0])
    has_edge_attributions = att.edges is not None
    if has_edge_attributions:
        self.assertEqual(graphs.edges.shape[0], att.edges.shape[0])
    for field in ('n_node', 'n_edge', 'senders', 'receivers'):
        np.testing.assert_allclose(getattr(graphs, field), getattr(att, field))
def test_smiles_to_graphs_tuple(self):
    """Every graph should contain exactly one node per atom of its molecule."""
    graphs, smiles_list, _ = self._setup_graphs()
    molecules = list(map(featurization.smiles_to_mol, smiles_list))
    atom_counts = [molecule.GetNumAtoms() for molecule in molecules]
    self.assertLen(molecules, graph_utils.get_num_graphs(graphs))
    node_counts = graphs.n_node.numpy().tolist()
    self.assertEqual(atom_counts, node_counts)
def test_split_graphs_tuple(self):
    """Splitting a batch should yield one graph per batch entry, in order.

    Each piece is compared against indexing the batch at the same position.
    """
    graphs, _, _ = self._setup_graphs()
    pieces = list(graph_utils.split_graphs_tuple(graphs))
    self.assertLen(pieces, graph_utils.get_num_graphs(graphs))
    for position, piece in enumerate(pieces):
        self.assertEqualGraphsTuple(
            piece, graph_nets.utils_np.get_graph(graphs, position))
def test_ReadoutGAP(self, globals_size, act):
    """ReadoutGAP should set the globals width to `globals_size`.

    The graph count and the node/edge feature widths must stay unchanged.
    """
    graphs = self._setup_graphs()
    readout = models.ReadoutGAP(globals_size, act)
    result = readout(graphs)
    expected_shape = dict(
        num_graphs=graph_utils.get_num_graphs(graphs),
        globals_dim=globals_size,
        node_dim=graphs.nodes.shape[-1],
        edge_dim=graphs.edges.shape[-1],
    )
    self.assertGraphShape(result, **expected_shape)
def test_node_layer(self, node_layer, node_size):
    """A node layer should only change the node feature width."""
    graphs = self._setup_graphs()
    layer = node_layer(models.get_mlp_fn([node_size]))
    transformed = layer(graphs)
    self.assertGraphShape(
        transformed,
        num_graphs=graph_utils.get_num_graphs(graphs),
        node_dim=node_size,
        edge_dim=graphs.edges.shape[-1],
    )
def test_perturb_graphs_tuple_zero(self, n_samples):
    """With zero noise, each block of n_graphs samples should equal the input."""
    graphs, _, _ = self._setup_graphs()
    n_graphs = graph_utils.get_num_graphs(graphs)
    noisy_graphs = graph_utils.perturb_graphs_tuple(graphs, n_samples, 0.0)
    for sample in range(n_samples):
        # Sample `sample` is expected to occupy rows [offset, offset + n_graphs).
        offset = sample * n_graphs
        sample_graphs = graph_utils.get_graphs_tf(
            noisy_graphs, np.arange(offset, offset + n_graphs))
        np.testing.assert_allclose(graphs.nodes, sample_graphs.nodes)
        np.testing.assert_allclose(graphs.edges, sample_graphs.edges)
def test_NodeEdgeLayer(self, node_size, edge_dim):
    """NodeEdgeLayer should resize node and edge features.

    The number of graphs must be preserved.
    """
    graphs = self._setup_graphs()
    layer = models.NodeEdgeLayer(
        models.get_mlp_fn([node_size]), models.get_mlp_fn([edge_dim]))
    transformed = layer(graphs)
    expected_count = graph_utils.get_num_graphs(graphs)
    self.assertGraphShape(
        transformed,
        num_graphs=expected_count,
        node_dim=node_size,
        edge_dim=edge_dim,
    )
def test_interpolate_graphs_tuple_batch(self, n_steps):
    """Batched interpolation should match interpolating each graph on its own."""
    end, _, tensorizer = self._setup_graphs()
    n_graphs = graph_utils.get_num_graphs(end)
    null_vectors = tensorizer.get_null_vectors()
    start = graph_utils.make_constant_like(end, *null_vectors)
    batched, _, _ = graph_utils.interpolate_graphs_tuple(start, end, n_steps)
    pairs = zip(graph_utils.split_graphs_tuple(start),
                graph_utils.split_graphs_tuple(end))
    for graph_idx, (single_start, single_end) in enumerate(pairs):
        # The test expects a step-major layout: graph `graph_idx` at step k
        # sits at row k * n_graphs + graph_idx of the batched output.
        rows = np.arange(0, n_steps * n_graphs, n_graphs) + graph_idx
        observed = graph_utils.get_graphs_tf(batched, rows)
        reference, _, _ = graph_utils.interpolate_graphs_tuple(
            single_start, single_end, n_steps)
        self.assertEqualGraphsTuple(reference, observed)
def assertGraphShape(self, graphs, num_graphs=None, node_dim=None,
                     edge_dim=None, globals_dim=None):
    """Assert the graph count and trailing feature widths of `graphs`.

    Any expectation left as None is skipped. A field is only read when its
    expectation is given, so e.g. `graphs.edges` may be None if `edge_dim`
    is None.
    """
    if num_graphs is not None:
        self.assertEqual(graph_utils.get_num_graphs(graphs), num_graphs)
    expectations = (
        (node_dim, 'nodes'),
        (edge_dim, 'edges'),
        (globals_dim, 'globals'),
    )
    for expected_width, field in expectations:
        if expected_width is None:
            continue
        self.assertEqual(getattr(graphs, field).shape[-1], expected_width)
def _setup_graphs(self):
    """Build a small batch of molecular graphs shared by the tests.

    Returns:
      A 3-tuple ``(graphs, smiles, tensorizer)``:
        - graphs: the GraphsTuple produced by
          ``graph_utils.smiles_to_graphs_tuple`` for three SMILES strings;
        - smiles: the list of those SMILES strings;
        - tensorizer: the ``featurization.MolTensorizer`` used to build
          the graphs.

      Callers such as test_smiles_to_graphs_tuple (which reads the SMILES
      list) and test_interpolate_graphs_tuple_batch (which reads the
      tensorizer) unpack all three values. Previously this returned
      ``(graphs, num_graphs)``, which broke that unpacking.
    """
    tensorizer = featurization.MolTensorizer()
    smiles = ['CO', 'CCC', 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C']
    graphs = graph_utils.smiles_to_graphs_tuple(smiles, tensorizer)
    return graphs, smiles, tensorizer