def wrapped(module, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Runs ``unbound_method`` bound to ``module`` inside the current frame's
  module scope (``frame.module(state)``) and the ``_module_method_call``
  context, routed through ``run_interceptors``. When named calls are enabled
  and the method is ``__call__``, the call is additionally wrapped in
  ``named_call.stateful_named_call``. The label is the last ``/``-separated
  component of the module name. After the call, the module name is added to
  the ``_submodules`` set of every module on the frame's module stack.

  Args:
    module: The module instance the method is invoked on.
    *args: Positional arguments forwarded to the method.
    **kwargs: Keyword arguments forwarded to the method.

  Returns:
    The value returned by the (intercepted, possibly named) method call.

  Raises:
    ValueError: If ``base.frame_stack`` is empty, i.e. the module is used
      outside of an ``hk.transform``.
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  frame = base.current_frame()
  state = base.ModuleState(module=module, method_name=method_name)
  with frame.module(state), _module_method_call(module, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(module, "module_name", None)
    f = functools.partial(unbound_method, module)
    f = functools.partial(run_interceptors, f, method_name, module)
    # Named call stages out return values eagerly (omnistaging). Methods that
    # return non-Array values (e.g. `def is_batched(self, x) -> bool`) would
    # receive a tracer back and may raise a concretization error. Only
    # `__call__` is wrapped, on the assumption that it is `f(*) -> Tree[Array]`.
    if modules_with_named_call and module_name and method_name == "__call__":
      local_name = module_name.split("/")[-1]
      f = named_call.stateful_named_call(f, name=local_name)
    out = f(*args, **kwargs)

    # Notify parent modules about our existence.
    if module_name is not None:
      for module_state in frame.module_stack:
        module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def wrapped(self, *args, **kwargs):
  """Invokes the original method on ``self`` within the module's scopes.

  The call runs inside ``frame.module(...)`` and ``_module_method_call``, is
  routed through ``run_interceptors``, and, for ``__call__`` only when named
  calls are enabled, is wrapped in a stateful named call. Afterwards the
  module name is recorded in ``_submodules`` of each module on the frame's
  module stack.

  Raises:
    ValueError: If there is no active frame (used outside ``hk.transform``).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  active_frame = base.current_frame()
  call_state = base.ModuleState(module=self, method_name=method_name)
  with active_frame.module(call_state), _module_method_call(self, method_name):
    # Every hk.Module method runs inside the module's name scope.
    full_name = getattr(self, "module_name", None)
    bound_method = functools.partial(unbound_method, self)
    invoke = functools.partial(run_interceptors, bound_method, method_name,
                               self)

    # TODO(tomhennigan): Under omnistaging, named call stages out return
    # values eagerly, so a method returning a non-Array (e.g. a bool) would
    # get a tracer back and could fail with a concretization error. Only
    # `__call__` (assumed `f(*) -> Tree[Array]`) is wrapped for now; splitting
    # static and dynamic results would allow other methods too.
    use_named_call = (
        modules_with_named_call and full_name and method_name == "__call__")
    if use_named_call:
      invoke = named_call.stateful_named_call(
          invoke, name=full_name.split("/")[-1])

    result = invoke(*args, **kwargs)

    # Record this module's name on every module state currently on the
    # frame's module stack (NOTE(review): presumably this includes the entry
    # pushed for `self` by `active_frame.module(...)` — confirm).
    if full_name is not None:
      for stacked in active_frame.module_stack:
        stacked.module._submodules.add(full_name)  # pylint: disable=protected-access
  return result
def test_jax_transforms(self, transform):
  """`transform` yields the same result with or without a named call."""
  inputs = jnp.array([1.])
  baseline = transform(jnp.sum)(inputs)
  named_sum = named_call.stateful_named_call(jnp.sum, name='test')
  self.assertEqual(baseline, transform(named_sum)(inputs))
def test_partial_eval(self):
  """A jitted named call with a partially-applied Python bool returns `y`."""
  if not hasattr(xla.xb, 'parameter'):
    self.skipTest('Need Jaxlib version > 0.1.45')

  def select(flag, value):
    return value if flag else None

  named_select = named_call.stateful_named_call(select, name='test')
  jitted = jax.jit(functools.partial(named_select, True))
  self.assertEqual(jitted(5), 5)
def test_static_argnums(self):
  """A named call jitted with a static Python bool argument returns `y`."""
  if not hasattr(xla.xb, 'parameter'):
    self.skipTest('Need Jaxlib version > 0.1.45')

  def select(flag, value):
    return value if flag else None

  named_select = named_call.stateful_named_call(select, name='test')
  jitted = jax.jit(named_select, static_argnums=(0,))
  self.assertEqual(jitted(True, 5), 5)
def test_jax_transforms(self, transform):
  """`transform` yields the same result with or without a named call."""
  if not hasattr(xla.xb, 'parameter'):
    self.skipTest('Need Jaxlib version > 0.1.45')

  reduce_sum = jax.numpy.sum
  inputs = jax.numpy.array([1.])
  baseline = transform(reduce_sum)(inputs)
  named_sum = named_call.stateful_named_call(reduce_sum, name='test')
  self.assertEqual(baseline, transform(named_sum)(inputs))
def test_non_jaxtype_arg(self):
  """A static non-JaxType argument survives a jitted named call.

  A valid JaxType is passed alongside the string. Without the named-call
  filter for invalid JaxTypes, this would force the string to be raised to
  an abstract value, and the truthiness check below would then fail.
  """
  def pick(flag, value):
    # Python-level truthiness check on `flag`: requires a concrete value,
    # not an abstractified one.
    return value if flag else 0

  named_pick = named_call.stateful_named_call(pick, name='test')
  result = jax.jit(named_pick, static_argnums=(0,))('not a Jaxtype', 1)
  self.assertEqual(result, 1)
def test_partial_eval(self):
  """A jitted named call with a partially-applied Python bool returns `y`."""
  def select(flag, value):
    return value if flag else None

  named_select = named_call.stateful_named_call(select, name='test')
  jitted = jax.jit(functools.partial(named_select, True))
  self.assertEqual(jitted(5), 5)
def test_static_argnums(self):
  """A named call jitted with a static Python bool argument returns `y`."""
  def select(flag, value):
    return value if flag else None

  named_select = named_call.stateful_named_call(select, name='test')
  jitted = jax.jit(named_select, static_argnums=(0,))
  self.assertEqual(jitted(True, 5), 5)