def test_empty(self):
    """Tracing a no-argument function that returns None gives an empty graph."""
    trace = dot.to_graph(lambda: None)
    graph, args, out = trace()
    # No inputs were supplied and nothing was returned.
    self.assertEmpty(args)
    self.assertIsNone(out)
    # The test expects no nodes, edges or nested subgraphs at all.
    for field in ("nodes", "edges", "subgraphs"):
        self.assertEmpty(getattr(graph, field))
def test_call(self):
    """A jitted function is rendered as one titled subgraph."""
    # The function name is part of the asserted subgraph title below.
    def my_function(x):
        return x

    jitted = jax.jit(my_function)
    graph, _, _ = dot.to_graph(jitted)(jnp.ones([]))
    # The test expects nothing at the top level of the graph...
    self.assertEmpty(graph.nodes)
    self.assertEmpty(graph.edges)
    # ...and exactly one subgraph, named after the jit call.
    (jit_graph,) = graph.subgraphs
    self.assertEqual(jit_graph.title, "xla_call (my_function)")
def test_pmap(self):
    """A pmapped function is rendered as one titled subgraph."""
    # The function name is part of the asserted subgraph title below.
    def my_function(x):
        return x

    # One leading-axis entry per local device, as pmap maps over that axis.
    num_devices = jax.local_device_count()
    mapped = jax.pmap(my_function)
    graph, _, _ = dot.to_graph(mapped)(jnp.ones([num_devices]))
    # The test expects nothing at the top level of the graph...
    self.assertEmpty(graph.nodes)
    self.assertEmpty(graph.edges)
    # ...and exactly one subgraph, named after the pmap call.
    (pmap_graph,) = graph.subgraphs
    self.assertEqual(pmap_graph.title, "xla_pmap (my_function)")
def test_add_module(self):
    """Tracing AddModule yields one module subgraph containing a single add."""
    module = AddModule()
    a = jnp.ones([])
    b = a  # Deliberately the same object as `a`, as in `a = b = ...`.
    graph, args, c = dot.to_graph(module)(a, b)

    # The traced call echoes its inputs and returns their sum.
    self.assertEqual(args, (a, b))
    self.assertEqual(c, a + b)

    # At the top level, there are no edges and exactly one module subgraph.
    self.assertEmpty(graph.edges)
    (module_graph,) = graph.subgraphs
    self.assertEqual(module_graph.title, "add_module")
    self.assertEmpty(module_graph.subgraphs)

    # Inside the module, each input feeds the output.
    edge_from_a, edge_from_b = module_graph.edges
    self.assertEqual(edge_from_a, (a, c))
    self.assertEqual(edge_from_b, (b, c))

    # The single node inside the module is an "add" producing `c`.
    (add_node,) = module_graph.nodes
    self.assertEqual(add_node.title, "add")
    (node_output,) = add_node.outputs
    self.assertEqual(node_output, c)