Example #1
0
    def test_intercept_methods_calling_underlying_optional(self):
        """An interceptor may return without ever invoking the wrapped method.

        Calling ``RaisesModule()`` with no interceptor active raises
        ``AssertionError``. A clean call while the do-nothing interceptor is
        installed therefore shows the underlying method was skipped. The test
        also checks that the failing un-intercepted call does not prevent the
        interceptor from being installed again afterwards.
        """
        def do_nothing_interceptor(f, args, kwargs, context):
            # Deliberately never calls `f`; only checks the call signature.
            del f, context
            self.assertEmpty(args)
            self.assertEmpty(kwargs)

        mod = RaisesModule()

        def call_intercepted():
            # Invoke the module once with the do-nothing interceptor active.
            with module.intercept_methods(do_nothing_interceptor):
                mod()

        call_intercepted()

        # Without the interceptor we expect an error.
        with self.assertRaises(AssertionError):
            mod()

        # The previous error should not stop us from re-applying.
        call_intercepted()
Example #2
0
    def test_intercept_method(self):
        """An interceptor sees the call context and can rewrite the result.

        The interceptor only applies inside the ``intercept_methods`` block.
        Calls before and after it return the un-intercepted value; the call
        inside it is incremented by one. ``IdentityModule`` appears to return
        its input unchanged, since ``orig_method(2) == 2`` is asserted below.
        """
        mod = IdentityModule()
        x = jnp.ones([])
        invocations = []

        def add_one_interceptor(f, args, kwargs, context):
            invocations.append(None)
            # Check the context fields and call arguments forwarded to us.
            self.assertLen(context, 3)
            self.assertIs(context.module, mod)
            self.assertEqual(context.method_name, "__call__")
            self.assertEqual(context.orig_method(2), 2)
            self.assertEqual(args, (x, ))
            self.assertEmpty(kwargs)
            return f(*args, **kwargs) + 1

        before = mod(x)
        with module.intercept_methods(add_one_interceptor):
            during = mod(x)
        after = mod(x)

        # Only the call made while the interceptor was installed went through it.
        self.assertLen(invocations, 1)
        for actual, expected in ((before, 1), (during, 2), (after, 1)):
            self.assertEqual(actual, expected)
Example #3
0
  def test_intercept_methods_run_in_lifo_order(self):
    """Interceptors installed later sit closer to the wrapped method.

    Each interceptor applies its op to the result of the call it wraps. The
    most recently installed interceptor therefore applies its op first, and
    the first-installed one applies its op last.
    """
    def op_interceptor(op):
      def _interceptor(f, args, kwargs, context):
        del context
        return op(f(*args, **kwargs))
      return _interceptor

    def add_one(a):
      return a + 1

    def square(a):
      return a ** 2

    mod = IdentityModule()
    x = 7

    # Nested `with` blocks enter managers in the same order as one
    # comma-separated `with` statement.
    with module.intercept_methods(op_interceptor(add_one)):
      with module.intercept_methods(op_interceptor(square)):
        squared_then_incremented = mod(x)
    self.assertEqual(squared_then_incremented, (x ** 2) + 1)

    with module.intercept_methods(op_interceptor(square)):
      with module.intercept_methods(op_interceptor(add_one)):
        incremented_then_squared = mod(x)
    self.assertEqual(incremented_then_squared, (x + 1) ** 2)
Example #4
0
    def test_name_like_interceptor_method_names_unchanged(self):
        """The interceptor context reports each method's real name.

        ``ModuleWithCustomName`` has methods that mimic other method names
        (per the inline comments below). The logged ``method_name`` values
        must still be the actual names of the methods that were called.
        """
        log = []

        def log_parent_methods(f, args, kwargs, context: module.MethodContext):
            # Only record calls made on the module under test.
            is_parent = isinstance(context.module, ModuleWithCustomName)
            if is_parent:
                log.append(context.method_name)
            return f(*args, **kwargs)

        with module.intercept_methods(log_parent_methods):
            # Constructing inside the block means "__init__" is logged too.
            parent = ModuleWithCustomName(name="parent")
            parent.foo()  # foo pretends to be __call__.
            parent.bar()  # bar pretends to be baz.
            # baz and call are happy to be themselves.
            parent.baz()
            parent()

        expected_names = ["__init__", "foo", "bar", "baz", "__call__"]
        self.assertEqual(log, expected_names)
Example #5
0
  def test_policy_with_interceptor(self):
    """A user interceptor still sees every call when a policy is active.

    A ``p=f16,c=f32,o=f16`` policy is set for ``OuterModule``. The test checks
    that the module output is float16 and that a user-installed interceptor
    observed four method calls.
    """
    sidechannel = []

    def my_interceptor(next_f, args, kwargs, context):
      sidechannel.append(context)
      return next_f(*args, **kwargs)

    # We need this to make sure that the mixed precision interceptor is
    # installed when we call set_policy (this only happens the first call).
    mixed_precision.reset_thread_local_state_for_test()

    policy = jmp.get_policy('p=f16,c=f32,o=f16')
    with module.intercept_methods(my_interceptor):
      mixed_precision.set_policy(OuterModule, policy)
      # The policy outlives this test body, so remove it afterwards. Otherwise
      # later tests that use OuterModule would silently inherit the f16 policy.
      self.addCleanup(mixed_precision.clear_policy, OuterModule)
      x = OuterModule()()
      self.assertEqual(x.dtype, jnp.float16)
    # Outer.init, Outer.call, Inner.init, Inner.call
    # (presumably OuterModule creates and calls an inner module; confirm
    # against the OuterModule definition).
    self.assertLen(sidechannel, 4)