def test_intercept_methods_calling_underlying_optional(self):
    """Checks that an interceptor is free to never call the wrapped method."""

    def swallow_call(wrapped, call_args, call_kwargs, ctx):
        # The wrapped method is intentionally discarded without being run.
        del wrapped, ctx
        self.assertEmpty(call_args)
        self.assertEmpty(call_kwargs)

    raising = RaisesModule()

    with module.intercept_methods(swallow_call):
        raising()

    # Without the interceptor we expect an error.
    self.assertRaises(AssertionError, raising)

    # The previous error should not stop us from re-applying.
    with module.intercept_methods(swallow_call):
        raising()
def test_intercept_method(self):
    """Checks what an interceptor receives and that it only acts while active.

    The interceptor validates the context and arguments it is handed, then
    adds one to the wrapped result. Calls made before and after the
    `intercept_methods` scope must be unaffected.
    """
    identity = IdentityModule()
    value = jnp.ones([])
    invocations = []

    def add_one_interceptor(wrapped, call_args, call_kwargs, ctx):
        invocations.append(None)
        # The context exposes three fields; the ones read here are checked.
        self.assertLen(ctx, 3)
        self.assertIs(ctx.module, identity)
        self.assertEqual(ctx.method_name, "__call__")
        self.assertEqual(ctx.orig_method(2), 2)
        self.assertEqual(call_args, (value,))
        self.assertEmpty(call_kwargs)
        return wrapped(*call_args, **call_kwargs) + 1

    before = identity(value)
    with module.intercept_methods(add_one_interceptor):
        during = identity(value)
    after = identity(value)

    # Only the call made inside the scope reached the interceptor.
    self.assertLen(invocations, 1)
    for actual, expected in ((before, 1), (during, 2), (after, 1)):
        self.assertEqual(actual, expected)
def test_intercept_methods_run_in_lifo_order(self):
    """Checks the order in which stacked interceptors transform the result.

    Each interceptor post-processes the value returned by whatever it wraps.
    The expected values show that the post-processing of the most recently
    installed interceptor is applied to the result first.
    """

    def post_process_with(op):
        def _interceptor(wrapped, call_args, call_kwargs, ctx):
            del ctx
            return op(wrapped(*call_args, **call_kwargs))

        return _interceptor

    def increment(a):
        return a + 1

    def square(a):
        return a ** 2

    identity = IdentityModule()
    x = 7

    # Installed as increment, then square: square is applied first.
    with module.intercept_methods(post_process_with(increment)):
        with module.intercept_methods(post_process_with(square)):
            result = identity(x)
    self.assertEqual(result, (x ** 2) + 1)

    # Installed as square, then increment: increment is applied first.
    with module.intercept_methods(post_process_with(square)):
        with module.intercept_methods(post_process_with(increment)):
            result = identity(x)
    self.assertEqual(result, (x + 1) ** 2)
def test_name_like_interceptor_method_names_unchanged(self):
    """Checks that interceptors are told the real name of each method called.

    Methods on `ModuleWithCustomName` that pretend to be other methods must
    still be reported to the interceptor under their own names.
    """
    seen_method_names = []

    def log_parent_methods(wrapped, call_args, call_kwargs,
                           ctx: module.MethodContext):
        # Only record calls made on the module type under test.
        if isinstance(ctx.module, ModuleWithCustomName):
            seen_method_names.append(ctx.method_name)
        return wrapped(*call_args, **call_kwargs)

    with module.intercept_methods(log_parent_methods):
        parent = ModuleWithCustomName(name="parent")
        # foo pretends to be __call__.
        parent.foo()
        # bar pretends to be baz.
        parent.bar()
        # baz and call are happy to be themselves.
        parent.baz()
        parent()

    expected_names = ["__init__", "foo", "bar", "baz", "__call__"]
    self.assertEqual(seen_method_names, expected_names)
def test_policy_with_interceptor(self):
    """Checks a mixed precision policy works alongside a user interceptor."""
    recorded_contexts = []

    def my_interceptor(wrapped, call_args, call_kwargs, ctx):
        # Pass-through interceptor that only records each context it sees.
        recorded_contexts.append(ctx)
        return wrapped(*call_args, **call_kwargs)

    # We need this to make sure that the mixed precision interceptor is
    # installed when we call set_policy (this only happens the first call).
    mixed_precision.reset_thread_local_state_for_test()

    half_output_policy = jmp.get_policy('p=f16,c=f32,o=f16')
    with module.intercept_methods(my_interceptor):
        mixed_precision.set_policy(OuterModule, half_output_policy)
        outer = OuterModule()
        output = outer()
        self.assertEqual(output.dtype, jnp.float16)
        # Outer.init, Outer.call, Inner.init, Inner.call
        self.assertLen(recorded_contexts, 4)