Ejemplo n.º 1
0
 def test_interpolate_graphs_tuple_batch(self, n_steps):
   """Check that interpolation is the same whether or not graphs are batched.

   Interpolates the whole batch at once, then interpolates each graph on its
   own, and compares the two results graph by graph.

   Args:
     n_steps: number of interpolation steps passed to
       `graph_utils.interpolate_graphs_tuple`.
   """
   batch_end, _, tensorizer = self._setup_graphs()
   batch_size = graph_utils.get_num_graphs(batch_end)
   batch_start = graph_utils.make_constant_like(
       batch_end, *tensorizer.get_null_vectors())
   batch_interp, _, _ = graph_utils.interpolate_graphs_tuple(
       batch_start, batch_end, n_steps)
   single_starts = graph_utils.split_graphs_tuple(batch_start)
   single_ends = graph_utils.split_graphs_tuple(batch_end)
   graph_pairs = zip(single_starts, single_ends)
   for offset, (single_start, single_end) in enumerate(graph_pairs):
     # Pick every `batch_size`-th graph from the batched interpolation,
     # starting at this graph's position in the batch.
     total = n_steps * batch_size
     picked = np.arange(0, total, batch_size) + offset
     batched_result = graph_utils.get_graphs_tf(batch_interp, picked)
     single_result, _, _ = graph_utils.interpolate_graphs_tuple(
         single_start, single_end, n_steps)
     self.assertEqualGraphsTuple(single_result, batched_result)
Ejemplo n.º 2
0
 def test_get_true_attributions(self, task_enum):
     """Check that true attributions can be retrieved with consistent shape.

     Pairs each single graph split from the batch with the corresponding
     ground-truth attribution and runs the shape assertion on every pair.

     Args:
       task_enum: task identifier forwarded to `_setup_task_data`.
     """
     batch, mols = self._setup_graphs_mols()
     _, _, true_atts = self._setup_task_data(task_enum, mols)
     single_graphs = graph_utils.split_graphs_tuple(batch)
     for single_graph, true_att in zip(single_graphs, true_atts):
         self.assertAttributionShape(single_graph, true_att)
Ejemplo n.º 3
0
 def test_split_graphs_tuple(self):
   """Check that a batched graphs tuple splits into its individual graphs.

   The split must yield as many graphs as the batch holds, and each one must
   equal the graph that `graph_nets.utils_np.get_graph` extracts at the same
   position.
   """
   batch, _, _ = self._setup_graphs()
   pieces = [*graph_utils.split_graphs_tuple(batch)]
   self.assertLen(pieces, graph_utils.get_num_graphs(batch))
   for position, piece in enumerate(pieces):
     reference = graph_nets.utils_np.get_graph(batch, position)
     self.assertEqualGraphsTuple(piece, reference)
Ejemplo n.º 4
0
 def test_interpolate_graphs_tuple_differences(self):
   """Check that our interpolation has constant differences between steps.

   For every graph in the test batch, interpolates from a constant baseline
   (built from the tensorizer's null vectors) to the graph itself, then
   checks that each pair of consecutive interpolation steps differs by
   `(end - start) / (n_steps - 1)` in both node and edge features.
   """
   n_steps = 8
   several_ends, _, tensorizer = self._setup_graphs()
   for end in graph_utils.split_graphs_tuple(several_ends):
     start = graph_utils.make_constant_like(end,
                                            *tensorizer.get_null_vectors())
     interp, _, _ = graph_utils.interpolate_graphs_tuple(start, end, n_steps)
     steps = list(graph_utils.split_graphs_tuple(interp))
     # n_steps points span n_steps - 1 intervals, hence the divisor.
     expected_nodes_diff = tf.divide(end.nodes - start.nodes, n_steps - 1)
     expected_edges_diff = tf.divide(end.edges - start.edges, n_steps - 1)
     for x_cur, x_next in zip(steps[:-1], steps[1:]):
       actual_nodes_diff = x_next.nodes - x_cur.nodes
       actual_edges_diff = x_next.edges - x_cur.edges
       # assert_allclose takes (actual, desired): the relative tolerance is
       # scaled by `desired` and failure messages label the two arguments,
       # so the measured value goes first.
       np.testing.assert_allclose(
           actual_nodes_diff, expected_nodes_diff, atol=1e-7)
       np.testing.assert_allclose(
           actual_edges_diff, expected_edges_diff, atol=1e-7)
Ejemplo n.º 5
0
 def test_interpolate_graphs_tuple_endpoints(self, n_steps):
   """Check that the interpolation path starts and ends at the given graphs.

   For each graph, the first interpolated graph (index 0) must equal the
   constant baseline and the last one (index `n_steps - 1`) must equal the
   graph itself.

   Args:
     n_steps: number of interpolation steps passed to
       `graph_utils.interpolate_graphs_tuple`.
   """
   batch, _, tensorizer = self._setup_graphs()
   for target in graph_utils.split_graphs_tuple(batch):
     baseline = graph_utils.make_constant_like(
         target, *tensorizer.get_null_vectors())
     path, _, _ = graph_utils.interpolate_graphs_tuple(
         baseline, target, n_steps)
     first_index = np.array([0])
     last_index = np.array([n_steps - 1])
     path_first = graph_utils.get_graphs_tf(path, first_index)
     path_last = graph_utils.get_graphs_tf(path, last_index)
     self.assertEqualGraphsTuple(baseline, path_first)
     self.assertEqualGraphsTuple(target, path_last)
 def test_attribute_independence(self, method_name):
     """Check that attributions match between batched and single-graph runs.

     Attributes the whole batch once, then attributes each graph separately
     and compares node and edge attributions with a relative tolerance of
     1e-2. The single-graph result is also passed to `assertAttribution`.

     Args:
       method_name: name of the attribution technique, forwarded to
         `_setup_technique`.
     """
     graphs, model, tensorizer = self._setup_graphs_model()
     technique = self._setup_technique(method_name, tensorizer)
     batched_atts = technique.attribute(graphs, model)
     per_graph = graph_utils.split_graphs_tuple(graphs)
     for single_graph, batched_att in zip(per_graph, batched_atts):
         single_atts = technique.attribute(single_graph, model)
         # Only the first entry of the single-graph result is compared.
         single_att = single_atts[0]
         np.testing.assert_allclose(
             batched_att.nodes, single_att.nodes, rtol=1e-2)
         np.testing.assert_allclose(
             batched_att.edges, single_att.edges, rtol=1e-2)
         self.assertAttribution(single_graph, single_atts)