Example #1
0
 def test_empty(self):
     """Check that tracing a no-argument function returning None gives an empty graph.

     The traced call must report no positional args and a None output. The
     resulting graph must contain no nodes, no edges and no subgraphs.
     """
     empty_graph, traced_args, result = dot.to_graph(lambda: None)()
     self.assertEmpty(traced_args)
     self.assertIsNone(result)
     # Every structural collection of the graph must be empty, checked in
     # the order nodes -> edges -> subgraphs.
     for collection in (empty_graph.nodes, empty_graph.edges,
                        empty_graph.subgraphs):
         self.assertEmpty(collection)
Example #2
0
    def test_call(self):
        """Check that jitting an identity function traces to one `xla_call` subgraph.

        The top level of the graph is expected to have no nodes and no edges.
        Its single subgraph is titled with the wrapped function's name.
        """

        # The name matters: it appears in the expected subgraph title below.
        def my_function(x):
            return x

        jitted = jax.jit(my_function)
        scalar_input = jnp.ones([])
        top_level, _, _ = dot.to_graph(jitted)(scalar_input)
        for collection in (top_level.nodes, top_level.edges):
            self.assertEmpty(collection)
        (call_graph,) = top_level.subgraphs
        self.assertEqual(call_graph.title, "xla_call (my_function)")
Example #3
0
    def test_pmap(self):
        """Check that pmapping an identity function traces to one `xla_pmap` subgraph.

        The input has one leading entry per local device. The top level of the
        graph is expected to have no nodes and no edges, and its single
        subgraph is titled with the wrapped function's name.
        """

        # The name matters: it appears in the expected subgraph title below.
        def my_function(x):
            return x

        device_count = jax.local_device_count()
        mapped = jax.pmap(my_function)
        top_level, _, _ = dot.to_graph(mapped)(jnp.ones([device_count]))
        for collection in (top_level.nodes, top_level.edges):
            self.assertEmpty(collection)
        (pmap_graph,) = top_level.subgraphs
        self.assertEqual(pmap_graph.title, "xla_pmap (my_function)")
Example #4
0
 def test_add_module(self):
     """Check the graph traced from calling `AddModule` on two scalar ones.

     The traced args must echo the inputs, and the output must equal their
     sum. The top level must have no edges and exactly one `add_module`
     subgraph. That subgraph must have no nested subgraphs, one edge per
     input leading to the output, and a single `add` node whose only output
     is the result.
     """
     module = AddModule()
     lhs = rhs = jnp.ones([])
     top_level, traced_args, total = dot.to_graph(module)(lhs, rhs)
     self.assertEqual(traced_args, (lhs, rhs))
     self.assertEqual(total, lhs + rhs)
     self.assertEmpty(top_level.edges)

     (module_graph,) = top_level.subgraphs
     self.assertEqual(module_graph.title, "add_module")
     self.assertEmpty(module_graph.subgraphs)

     # Exactly two edges, one from each input to the output.
     # NOTE(review): lhs and rhs are the same array object, so these two
     # checks cannot detect the edges being swapped.
     first_edge, second_edge = module_graph.edges
     self.assertEqual(first_edge, (lhs, total))
     self.assertEqual(second_edge, (rhs, total))

     (add_node,) = module_graph.nodes
     self.assertEqual(add_node.title, "add")
     (node_output,) = add_node.outputs
     self.assertEqual(node_output, total)