Пример #1
0
 def test_perturb_graphs_tuple(self, n_samples, sigma):
   """Perturbation yields n_samples times as many graphs, same feature widths."""
   original, _, _ = self._setup_graphs()
   perturbed = graph_utils.perturb_graphs_tuple(original, n_samples, sigma)
   self.assertEqual(
       n_samples * graph_utils.get_num_graphs(original),
       graph_utils.get_num_graphs(perturbed))
   # Node and edge feature dimensions must survive the perturbation.
   for field in ('nodes', 'edges'):
     self.assertEqual(getattr(original, field).shape[-1],
                      getattr(perturbed, field).shape[-1])
Пример #2
0
 def assertAttributionShape(self, graphs, att):
   """Assert that an attribution graph mirrors the structure of `graphs`.

   Compares the number of graphs and the leading node dimension. It compares
   the leading edge dimension only when `att` carries edges. It then checks
   that the per-graph counts and the sender/receiver arrays agree.
   """
   self.assertEqual(graph_utils.get_num_graphs(graphs),
                    graph_utils.get_num_graphs(att))
   self.assertEqual(graphs.nodes.shape[0], att.nodes.shape[0])
   if att.edges is not None:
     self.assertEqual(graphs.edges.shape[0], att.edges.shape[0])
   for field in ('n_node', 'n_edge', 'senders', 'receivers'):
     np.testing.assert_allclose(getattr(graphs, field), getattr(att, field))
Пример #3
0
 def test_smiles_to_graphs_tuple(self):
   """Each molecule's atom count must equal the node count of its graph."""
   graphs, smiles_list, _ = self._setup_graphs()
   molecules = list(map(featurization.smiles_to_mol, smiles_list))
   atom_counts = [molecule.GetNumAtoms() for molecule in molecules]
   self.assertLen(molecules, graph_utils.get_num_graphs(graphs))
   node_counts = graphs.n_node.numpy().tolist()
   self.assertEqual(atom_counts, node_counts)
Пример #4
0
 def test_split_graphs_tuple(self):
   """Splitting a batched graphs tuple yields one entry per input graph."""
   graphs, _, _ = self._setup_graphs()
   pieces = list(graph_utils.split_graphs_tuple(graphs))
   self.assertLen(pieces, graph_utils.get_num_graphs(graphs))
   # Each piece must equal the graph extracted at the same position.
   for position, actual_graph in enumerate(pieces):
     expected_graph = graph_nets.utils_np.get_graph(graphs, position)
     self.assertEqualGraphsTuple(actual_graph, expected_graph)
 def test_ReadoutGAP(self, globals_size, act):
   """ReadoutGAP sets the globals width and keeps node/edge feature widths."""
   # NOTE(review): uses the return value directly as a graphs tuple, so this
   # presumably relies on a `_setup_graphs` that returns only the graphs.
   graphs = self._setup_graphs()
   readout = models.ReadoutGAP(globals_size, act)
   result = readout(graphs)
   expected = dict(
       num_graphs=graph_utils.get_num_graphs(graphs),
       globals_dim=globals_size,
       node_dim=graphs.nodes.shape[-1],
       edge_dim=graphs.edges.shape[-1])
   self.assertGraphShape(result, **expected)
 def test_node_layer(self, node_layer, node_size):
   """A node layer changes the node width and leaves the edge width alone."""
   # NOTE(review): assumes this class's `_setup_graphs` returns a bare
   # graphs tuple rather than a tuple of values.
   graphs = self._setup_graphs()
   layer = node_layer(models.get_mlp_fn([node_size]))
   result = layer(graphs)
   expected = dict(
       num_graphs=graph_utils.get_num_graphs(graphs),
       node_dim=node_size,
       edge_dim=graphs.edges.shape[-1])
   self.assertGraphShape(result, **expected)
Пример #7
0
 def test_perturb_graphs_tuple_zero(self, n_samples):
   """With sigma of zero, every perturbed copy must match the input graphs."""
   graphs, _, _ = self._setup_graphs()
   sigma = 0.0
   n_graphs = graph_utils.get_num_graphs(graphs)
   noisy_graphs = graph_utils.perturb_graphs_tuple(graphs, n_samples, sigma)
   # Copy number `sample` occupies the contiguous index range
   # [sample * n_graphs, (sample + 1) * n_graphs).
   for sample in range(n_samples):
     first = sample * n_graphs
     block_indices = np.arange(first, first + n_graphs)
     block = graph_utils.get_graphs_tf(noisy_graphs, block_indices)
     np.testing.assert_allclose(graphs.nodes, block.nodes)
     np.testing.assert_allclose(graphs.edges, block.edges)
 def test_NodeEdgeLayer(self, node_size, edge_dim):
   """NodeEdgeLayer changes only the node and edge feature widths."""
   # NOTE(review): assumes this class's `_setup_graphs` returns a bare
   # graphs tuple rather than a tuple of values.
   graphs = self._setup_graphs()
   layer = models.NodeEdgeLayer(models.get_mlp_fn([node_size]),
                                models.get_mlp_fn([edge_dim]))
   result = layer(graphs)
   expected = dict(
       num_graphs=graph_utils.get_num_graphs(graphs),
       node_dim=node_size,
       edge_dim=edge_dim)
   self.assertGraphShape(result, **expected)
Пример #9
0
 def test_interpolate_graphs_tuple_batch(self, n_steps):
   """Batched interpolation matches interpolating each graph on its own."""
   end, _, tensorizer = self._setup_graphs()
   n_graphs = graph_utils.get_num_graphs(end)
   null_vectors = tensorizer.get_null_vectors()
   start = graph_utils.make_constant_like(end, *null_vectors)
   interp, _, _ = graph_utils.interpolate_graphs_tuple(start, end, n_steps)
   pairs = zip(graph_utils.split_graphs_tuple(start),
               graph_utils.split_graphs_tuple(end))
   for graph_pos, (single_start, single_end) in enumerate(pairs):
     # In the batched result, step k of graph `graph_pos` is at
     # k * n_graphs + graph_pos.
     step_indices = np.arange(0, n_steps * n_graphs, n_graphs) + graph_pos
     batched_slice = graph_utils.get_graphs_tf(interp, step_indices)
     standalone, _, _ = graph_utils.interpolate_graphs_tuple(
         single_start, single_end, n_steps)
     self.assertEqualGraphsTuple(standalone, batched_slice)
 def assertGraphShape(self,
                      graphs,
                      num_graphs=None,
                      node_dim=None,
                      edge_dim=None,
                      globals_dim=None):
   """Assert selected shape properties of `graphs`.

   Any expectation left as None is skipped, and the corresponding field of
   `graphs` is not read.

   Args:
     graphs: the graphs tuple under test.
     num_graphs: expected number of graphs, via `graph_utils.get_num_graphs`.
     node_dim: expected size of the last axis of `graphs.nodes`.
     edge_dim: expected size of the last axis of `graphs.edges`.
     globals_dim: expected size of the last axis of `graphs.globals`.
   """
   if num_graphs is not None:
     self.assertEqual(graph_utils.get_num_graphs(graphs), num_graphs)
   expected_widths = (
       ('nodes', node_dim),
       ('edges', edge_dim),
       ('globals', globals_dim),
   )
   for field, width in expected_widths:
     if width is None:
       continue
     self.assertEqual(getattr(graphs, field).shape[-1], width)
 def _setup_graphs(self):
   """Build a graphs tuple from a small fixed list of SMILES strings.

   The SMILES are featurized with a fresh `featurization.MolTensorizer`.

   Returns:
     A `(graphs, num_graphs)` pair: the result of
     `graph_utils.smiles_to_graphs_tuple` for the three SMILES below, and the
     count reported for it by `graph_utils.get_num_graphs`. The SMILES list
     itself is not returned.
   """
   # NOTE(review): other tests in this listing either unpack three values
   # (graphs, smiles, tensorizer) or use the return value directly as a graphs
   # tuple. They presumably belong to test classes with their own
   # `_setup_graphs`. Confirm which callers this variant actually serves.
   tensorizer = featurization.MolTensorizer()
   smiles = ['CO', 'CCC', 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C']
   graphs = graph_utils.smiles_to_graphs_tuple(smiles, tensorizer)
   return graphs, graph_utils.get_num_graphs(graphs)