def wrapped_fun(*args) -> str:
    """Render ``graph_fun(*args)`` as a dot-format string.

    The graph produced by ``graph_fun`` is handed to ``_graph_to_dot``. Profiler
    name scopes are switched off for the duration so they do not show up in the
    generated dot output; the caller's previous setting is put back afterwards,
    including when tracing or rendering raises.
    """
    saved_setting = module.modules_with_named_call
    try:
        module.profiler_name_scopes(enabled=False)
        traced = graph_fun(*args)
        dot_source = _graph_to_dot(*traced)
    finally:
        # Always restore the global flag so callers see no lasting change.
        module.profiler_name_scopes(enabled=saved_setting)
    return dot_source
def test_module_namescope_setting_unchanged(self, flag):
    """``to_dot`` must leave the global name-scope setting as it found it.

    Sets profiler name scopes according to ``flag``, renders an identity
    function to dot, then checks the global setting still equals ``flag``.
    The setting that was active before the test is restored on exit.
    """
    saved = module.modules_with_named_call
    try:
        module.profiler_name_scopes(enabled=flag)
        render = dot.to_dot(lambda x: x)
        render(jnp.ones((1, 1)))
        self.assertEqual(module.modules_with_named_call, flag)
    finally:
        module.profiler_name_scopes(enabled=saved)
def test_no_namescopes_inside_dot(self):
    """Rendering a module with ``to_dot`` must not produce named calls.

    Name scopes are switched on globally first. The test expects ``to_dot`` to
    turn them off while rendering, so the patched ``stateful.named_call`` should
    never be invoked. The pre-test setting is restored afterwards.
    """
    add_module = AddModule()
    saved = module.modules_with_named_call
    try:
        module.profiler_name_scopes(enabled=True)
        # NOTE(review): the visible import block of this file imports
        # `named_call` but not `stateful`; confirm `stateful` is in scope here,
        # otherwise this raises NameError (or patches the wrong target).
        with mock.patch.object(stateful, "named_call") as named_call_spy:
            dot.to_dot(add_module)(1, 1)
            named_call_spy.assert_not_called()
    finally:
        module.profiler_name_scopes(enabled=saved)
def test_no_namescopes_inside_abstract_dot(self):
    """Rendering with ``abstract_to_dot`` must not produce named calls either.

    Both inputs are the same scalar float32 ``jax.ShapeDtypeStruct``. Name
    scopes are enabled globally before rendering, and the patched
    ``stateful.named_call`` is expected never to be called. The pre-test
    setting is restored afterwards.
    """
    add_module = AddModule()
    saved = module.modules_with_named_call
    scalar_spec = jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
    try:
        module.profiler_name_scopes(enabled=True)
        # NOTE(review): the visible import block of this file imports
        # `named_call` but not `stateful`; confirm `stateful` is in scope here.
        with mock.patch.object(stateful, "named_call") as named_call_spy:
            dot.abstract_to_dot(add_module)(scalar_spec, scalar_spec)
            named_call_spy.assert_not_called()
    finally:
        module.profiler_name_scopes(enabled=saved)
# See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for haiku._src.dot.""" from absl.testing import absltest from absl.testing import parameterized from haiku._src import dot from haiku._src import module from haiku._src import named_call from haiku._src import test_utils import jax import jax.numpy as jnp import mock module.profiler_name_scopes(False) class DotTest(parameterized.TestCase): def test_empty(self): graph, args, out = dot.to_graph(lambda: None)() self.assertEmpty(args) self.assertIsNone(out) self.assertEmpty(graph.nodes) self.assertEmpty(graph.edges) self.assertEmpty(graph.subgraphs) @test_utils.transform_and_run def test_add_module(self): mod = AddModule() a = b = jnp.ones([])