Ejemplo n.º 1
0
    def wrapped(module, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        A `base.ModuleState` for `module` is pushed onto the current frame and
        the call runs inside `_module_method_call(module, method_name)`. The
        method is routed through `run_interceptors` and, for `__call__` only
        (when `modules_with_named_call` is set and the module has a name),
        through a stateful named call. After the call, the module's name is
        added to the `_submodules` set of every module on the frame's module
        stack.

        Raises:
          ValueError: if `base.frame_stack` is empty.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        frame = base.current_frame()
        state = base.ModuleState(module=module, method_name=method_name)
        with frame.module(state), _module_method_call(module, method_name):
            # hk.Module enters the module name scope for all methods.
            module_name = getattr(module, "module_name", None)
            f = functools.partial(unbound_method, module)
            f = functools.partial(run_interceptors, f, method_name, module)
            # Named call stages out return values eagerly. A method that returns
            # non-Array values (e.g. `def is_batched(self, x) -> bool`) would get
            # a tracer back and could hit a concretization error. Only
            # `__call__` is assumed to return a tree of arrays, so named call is
            # restricted to it.
            if modules_with_named_call and module_name and method_name == "__call__":
                local_name = module_name.split("/")[-1]
                f = named_call.stateful_named_call(f, name=local_name)

            out = f(*args, **kwargs)

            # Notify modules on the stack about our existence.
            if module_name is not None:
                for module_state in frame.module_stack:
                    module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
        return out
Ejemplo n.º 2
0
    def wrapped(self, *args, **kwargs):
        """Invokes the wrapped method inside this module's frame and call scope.

        A `base.ModuleState` for this module is pushed onto the current frame
        and the method runs inside `_module_method_call`. The call is routed
        through `run_interceptors` and, for `__call__` only (when
        `modules_with_named_call` is enabled and the module has a name),
        through a stateful named call. Afterwards the module's name is added to
        the `_submodules` set of every module on the frame's module stack.

        Raises:
          ValueError: if `base.frame_stack` is empty.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        active_frame = base.current_frame()
        entry = base.ModuleState(module=self, method_name=method_name)
        with active_frame.module(entry), _module_method_call(self, method_name):
            # hk.Module enters the module name scope for all methods.
            name = getattr(self, "module_name", None)
            call = functools.partial(
                run_interceptors,
                functools.partial(unbound_method, self),
                method_name,
                self,
            )
            # TODO(tomhennigan): Under omnistaging, primitives such as named call
            # stage out return values eagerly. A method that produces non-Array
            # values (e.g. `def is_batched(self, x) -> bool`) would get a tracer
            # back and may hit a concretization error. For now named call is
            # only applied to __call__ (covering the vast majority of
            # interesting usages), assuming __call__ is `f(*) -> Tree[Array]`.
            # Longer term, named call could split static and dynamic results to
            # support other methods.
            use_named_call = (
                modules_with_named_call and name and method_name == "__call__")
            if use_named_call:
                call = named_call.stateful_named_call(
                    call, name=name.rsplit("/", 1)[-1])

            result = call(*args, **kwargs)

            # Register this module with every module on the frame's stack.
            if name is not None:
                for parent_state in active_frame.module_stack:
                    parent_state.module._submodules.add(name)  # pylint: disable=protected-access
        return result
Ejemplo n.º 3
0
  def test_jax_transforms(self, transform):
    """`transform` gives the same result on `jnp.sum` with and without a name."""
    summed = jnp.sum
    inputs = jnp.array([1.])

    reference = transform(summed)(inputs)
    labelled = named_call.stateful_named_call(summed, name='test')
    self.assertEqual(reference, transform(labelled)(inputs))
Ejemplo n.º 4
0
 def test_partial_eval(self):
     """A jitted named call with its predicate bound via `partial` returns `y`."""
     if not hasattr(xla.xb, 'parameter'):
         self.skipTest('Need Jaxlib version > 0.1.45')
     select = named_call.stateful_named_call(
         lambda pred, value: value if pred else None, name='test')
     jitted = jax.jit(functools.partial(select, True))
     self.assertEqual(jitted(5), 5)
Ejemplo n.º 5
0
 def test_static_argnums(self):
     """A named call jitted with a static first argument returns `y`."""
     if not hasattr(xla.xb, 'parameter'):
         self.skipTest('Need Jaxlib version > 0.1.45')
     select = named_call.stateful_named_call(
         lambda pred, value: value if pred else None, name='test')
     result = jax.jit(select, static_argnums=(0,))(True, 5)
     self.assertEqual(result, 5)
Ejemplo n.º 6
0
    def test_jax_transforms(self, transform):
        """Named and unnamed `jax.numpy.sum` agree under `transform`."""
        if not hasattr(xla.xb, 'parameter'):
            self.skipTest('Need Jaxlib version > 0.1.45')
        summed = jax.numpy.sum
        data = jax.numpy.array([1.])

        expected = transform(summed)(data)
        labelled = named_call.stateful_named_call(summed, name='test')
        actual = transform(labelled)(data)

        self.assertEqual(expected, actual)
Ejemplo n.º 7
0
  def test_non_jaxtype_arg(self):
    # Without the invalid-JaxType filter, this test only fails if a valid
    # JaxType is passed alongside the invalid one. The valid argument forces
    # the invalid one to be raised to an abstract value.
    def pick(not_a_jaxtype, a_jaxtype):
      # Branching on the first argument makes JAX evaluate the abstractified
      # non-JaxType.
      return a_jaxtype if not_a_jaxtype else 0

    labelled = named_call.stateful_named_call(pick, name='test')
    jitted = jax.jit(labelled, static_argnums=(0,))
    self.assertEqual(jitted('not a Jaxtype', 1), 1)
Ejemplo n.º 8
0
 def test_partial_eval(self):
   """A jitted named call with its predicate bound via `partial` returns `y`."""
   select = named_call.stateful_named_call(
       lambda pred, value: value if pred else None, name='test')
   bound = functools.partial(select, True)
   self.assertEqual(jax.jit(bound)(5), 5)
Ejemplo n.º 9
0
 def test_static_argnums(self):
   """A named call jitted with a static first argument returns `y`."""
   select = named_call.stateful_named_call(
       lambda pred, value: value if pred else None, name='test')
   result = jax.jit(select, static_argnums=(0,))(True, 5)
   self.assertEqual(result, 5)