def test_interpolate_graphs_tuple_endpoints(self, n_steps):
    """Verify the interpolation reproduces its start and end graphs.

    Every graph from the fixture is interpolated on its own, starting from
    a constant graph built by `make_constant_like` from the tensorizer's
    null vectors. The graphs at index 0 and index `n_steps - 1` of the
    interpolation must equal the start and the end graph respectively.

    Args:
        n_steps: Step count forwarded to
            `graph_utils.interpolate_graphs_tuple`.
    """
    batched_targets, _, tensorizer = self._setup_graphs()
    for target in graph_utils.split_graphs_tuple(batched_targets):
        null_vectors = tensorizer.get_null_vectors()
        baseline = graph_utils.make_constant_like(target, *null_vectors)
        interpolated, _, _ = graph_utils.interpolate_graphs_tuple(
            baseline, target, n_steps)

        # Pull out the two extreme steps before asserting anything.
        first_index = np.array([0])
        first_step = graph_utils.get_graphs_tf(interpolated, first_index)
        last_index = np.array([n_steps - 1])
        last_step = graph_utils.get_graphs_tf(interpolated, last_index)

        self.assertEqualGraphsTuple(baseline, first_step)
        self.assertEqualGraphsTuple(target, last_step)
def test_get_graphs_tf(self, indices):
    """Verify `get_graphs_tf` agrees with per-index extraction.

    Selecting `indices` in a single `graph_utils.get_graphs_tf` call must
    give the same graphs tuple as fetching each index with
    `graph_nets.utils_tf.get_graph` and concatenating the results along
    axis 0.

    Args:
        indices: Iterable of graph indices to select from the fixture.
    """
    batch, _, _ = self._setup_graphs()
    selected = graph_utils.get_graphs_tf(batch, np.array(indices))

    # Reference result: extract one graph at a time, then concatenate.
    singles = []
    for position in indices:
        singles.append(graph_nets.utils_tf.get_graph(batch, position))
    reference = graph_nets.utils_tf.concat(singles, axis=0)

    self.assertEqualGraphsTuple(selected, reference)
def test_get_graphs_tf_noedges(self):
    """Verify `get_graphs_tf` also works when `edges` is None.

    The fixture's `edges` field is replaced with None. Selecting graphs
    1, 2, 5 and 6 with `graph_utils.get_graphs_tf` must then match fetching
    each of them with `graph_nets.utils_tf.get_graph` and concatenating
    along axis 0.
    """
    wanted = [1, 2, 5, 6]
    batch, _, _ = self._setup_graphs()
    edgeless = batch.replace(edges=None)
    selected = graph_utils.get_graphs_tf(edgeless, np.array(wanted))

    # Reference result: extract one graph at a time, then concatenate.
    singles = []
    for position in wanted:
        singles.append(graph_nets.utils_tf.get_graph(edgeless, position))
    reference = graph_nets.utils_tf.concat(singles, axis=0)

    self.assertEqualGraphsTuple(selected, reference)
def test_perturb_graphs_tuple_zero(self, n_samples):
    """Verify that perturbing with sigma 0.0 leaves nodes and edges as-is.

    The perturbed output is read back in `n_samples` consecutive slices of
    `n_graphs` graphs each. The nodes and edges of every slice must be
    close to those of the unperturbed fixture.

    Args:
        n_samples: Sample count forwarded to
            `graph_utils.perturb_graphs_tuple`.
    """
    originals, _, _ = self._setup_graphs()
    zero_sigma = 0.0
    n_graphs = graph_utils.get_num_graphs(originals)
    perturbed = graph_utils.perturb_graphs_tuple(
        originals, n_samples, zero_sigma)

    for sample in range(n_samples):
        # Slice `sample` covers [sample * n_graphs, (sample + 1) * n_graphs).
        lower = sample * n_graphs
        upper = (sample + 1) * n_graphs
        sample_graphs = graph_utils.get_graphs_tf(
            perturbed, np.arange(lower, upper))
        for field in ("nodes", "edges"):
            np.testing.assert_allclose(
                getattr(originals, field), getattr(sample_graphs, field))
def test_interpolate_graphs_tuple_batch(self, n_steps):
    """Verify interpolation is the same irrespective of batching.

    The whole fixture is interpolated at once, starting from a constant
    graph built by `make_constant_like` from the tensorizer's null vectors.
    For each graph `i`, the entries at indices `i`, `i + n_graphs`,
    `i + 2 * n_graphs`, ... of that result must equal the interpolation of
    graph `i` computed on its own.

    Args:
        n_steps: Step count forwarded to
            `graph_utils.interpolate_graphs_tuple`.
    """
    targets, _, tensorizer = self._setup_graphs()
    n_graphs = graph_utils.get_num_graphs(targets)
    null_vectors = tensorizer.get_null_vectors()
    baselines = graph_utils.make_constant_like(targets, *null_vectors)
    batched, _, _ = graph_utils.interpolate_graphs_tuple(
        baselines, targets, n_steps)

    per_graph_pairs = zip(
        graph_utils.split_graphs_tuple(baselines),
        graph_utils.split_graphs_tuple(targets),
    )
    for offset, pair in enumerate(per_graph_pairs):
        single_start, single_end = pair
        # Stride through the batched result to collect this graph's steps.
        strided = np.arange(0, n_steps * n_graphs, n_graphs) + offset
        from_batch = graph_utils.get_graphs_tf(batched, strided)
        standalone, _, _ = graph_utils.interpolate_graphs_tuple(
            single_start, single_end, n_steps)
        self.assertEqualGraphsTuple(standalone, from_batch)