예제 #1
파일: dot.py 프로젝트: vinid/dm-haiku
 def wrapped_fun(*args) -> str:
     """Builds a graph via ``graph_fun`` and renders it with ``_graph_to_dot``.

     Profiler name scopes are switched off while the graph is produced so
     that they do not show up in the generated dot output. The caller's
     setting is captured beforehand and reinstated in ``finally``, so it is
     restored even if graph construction or rendering raises.
     """
     previous_setting = module.modules_with_named_call
     try:
         module.profiler_name_scopes(enabled=False)
         graph_parts = graph_fun(*args)
         return _graph_to_dot(*graph_parts)
     finally:
         # Put back whatever setting was active before this call.
         module.profiler_name_scopes(enabled=previous_setting)
예제 #2
 def test_module_namescope_setting_unchanged(self, flag):
     """Rendering with ``dot.to_dot`` leaves the named-call setting at ``flag``."""
     saved_setting = module.modules_with_named_call
     try:
         module.profiler_name_scopes(enabled=flag)
         render = dot.to_dot(lambda x: x)
         render(jnp.ones((1, 1)))
         # to_dot switches name scopes off internally; the value set above
         # must be back in place once rendering returns.
         self.assertEqual(module.modules_with_named_call, flag)
     finally:
         module.profiler_name_scopes(enabled=saved_setting)
예제 #3
 def test_no_namescopes_inside_dot(self):
   """With name scopes enabled, ``dot.to_dot`` must not invoke ``named_call``."""
   add_module = AddModule()
   saved_setting = module.modules_with_named_call
   try:
     module.profiler_name_scopes(enabled=True)
     with mock.patch.object(stateful, "named_call") as named_call_spy:
       dot.to_dot(add_module)(1, 1)
       named_call_spy.assert_not_called()
   finally:
     # Restore the setting that was active before the test started.
     module.profiler_name_scopes(enabled=saved_setting)
예제 #4
 def test_no_namescopes_inside_abstract_dot(self):
   """With name scopes enabled, ``dot.abstract_to_dot`` must not call ``named_call``."""
   add_module = AddModule()
   saved_setting = module.modules_with_named_call
   # Abstract scalar float32 input, shared by both operands.
   scalar_spec = jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
   try:
     module.profiler_name_scopes(enabled=True)
     with mock.patch.object(stateful, "named_call") as named_call_spy:
       dot.abstract_to_dot(add_module)(scalar_spec, scalar_spec)
       named_call_spy.assert_not_called()
   finally:
     # Restore the setting that was active before the test started.
     module.profiler_name_scopes(enabled=saved_setting)
예제 #5
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for haiku._src.dot."""

from absl.testing import absltest
from absl.testing import parameterized
from haiku._src import dot
from haiku._src import module
from haiku._src import named_call
from haiku._src import test_utils
import jax
import jax.numpy as jnp
import mock

# Module-import side effect: disable profiler name scopes for every test in
# this file unless a test explicitly turns them back on.
module.profiler_name_scopes(enabled=False)


class DotTest(parameterized.TestCase):
    def test_empty(self):
        """Graphing a zero-argument function that returns None is all-empty."""
        graph, args, out = dot.to_graph(lambda: None)()
        self.assertEmpty(args)
        self.assertIsNone(out)
        # Nothing was traced, so every graph collection must be empty.
        for part in (graph.nodes, graph.edges, graph.subgraphs):
            self.assertEmpty(part)

    @test_utils.transform_and_run
    def test_add_module(self):
        mod = AddModule()
        a = b = jnp.ones([])