示例#1
0
  def test_context_reuse_same_instance(self):
    """Module instances keep resolving their params across context reuse.

    Two ``ParentModule`` instances are created in a context seeded with
    ``params``. Their children must read the expected ``w`` values in three
    places:

    * inside that context;
    * after re-entering the same context;
    * in a fresh context seeded from ``ctx.collect_params()``.
    """
    params = {
        "parent_module/~/child_module": {"w": jnp.array(2.)},
        "parent_module/~/child_module_1": {"w": jnp.array(3.)},
        "parent_module_1/~/child_module": {"w": jnp.array(4.)},
        "parent_module_1/~/child_module_1": {"w": jnp.array(5.)},
    }
    expected_children = ((2., 3.), (4., 5.))

    def build_parents():
      # Creation order determines the auto-assigned names asserted here.
      first, second = ParentModule(), ParentModule()
      self.assertEqual(first.module_name, "parent_module")
      self.assertEqual(second.module_name, "parent_module_1")
      return first, second

    def check_children(parents):
      for parent, (want1, want2) in zip(parents, expected_children):
        self.assertEqual(parent.child1(), want1)
        self.assertEqual(parent.child2(), want2)

    with base.new_context(params=params) as ctx:
      parents = build_parents()
      check_children(parents)

    # Re-entering the same context must still serve the same params.
    with ctx:
      check_children(parents)

    # Creating a new context should not be a problem.
    with base.new_context(params=ctx.collect_params()) as ctx:
      check_children(build_parents())
示例#2
0
    def test_get_state_no_init_raises(self):
        """``get_state`` without ``init`` raises a ValueError.

        The check is made in a bare context and again in a context seeded
        with an empty ``"~"`` state dict.
        """
        for context_kwargs in ({}, {"state": {"~": {}}}):
            with base.new_context(**context_kwargs):
                with self.assertRaisesRegex(ValueError,
                                            "set an init function"):
                    base.get_state("i")
示例#3
0
    def test_get_state_no_shape_raises(self):
        """``get_state`` given only ``init`` (no shape/dtype) raises ValueError.

        The check is made in a bare context and again in a context seeded
        with an empty ``"~"`` state dict.
        """
        for context_kwargs in ({}, {"state": {"~": {}}}):
            with base.new_context(**context_kwargs):
                with self.assertRaisesRegex(ValueError,
                                            "provide shape and dtype"):
                    base.get_state("i", init=jnp.zeros)
示例#4
0
  def test_with_rng(self, seed):
    """``with_rng`` makes ``next_rng_key`` draw from an overriding key.

    The reference sample uses the second half of ``split(override_key)``.
    Sampling under ``with_rng(override_key)`` must reproduce it; sampling
    from the context's own key must not.
    """
    context_key = jax.random.PRNGKey(seed * 2 + 1)
    override_key = jax.random.PRNGKey(seed)
    expected_output = jax.random.uniform(jax.random.split(override_key)[1], ())

    with base.new_context(rng=context_key):
      plain_out = jax.random.uniform(base.next_rng_key(), ()).item()

    with base.new_context(rng=context_key), base.with_rng(override_key):
      overridden_out = jax.random.uniform(base.next_rng_key(), ()).item()

    self.assertNotEqual(plain_out, expected_output)
    self.assertEqual(overridden_out, expected_output)
示例#5
0
    def test_do_not_store(self):
        """Interceptors that return ``DO_NOT_STORE`` leave the context empty.

        The creator and the setter both opt out of storage. The getter then
        receives the sentinel and recomputes the value from the context's
        original init, shape and dtype, so reads still yield ones.
        """
        def skip_creator(next_creator, shape, dtype, init, context):
            del next_creator, shape, dtype, init, context
            return base.DO_NOT_STORE

        def recompute_getter(next_getter, value, context):
            assert value is base.DO_NOT_STORE
            fresh = context.original_init(context.original_shape,
                                          context.original_dtype)
            return next_getter(fresh)

        def skip_setter(next_setter, value, context):
            del next_setter, value, context
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(skip_creator, state=True), \
                 base.custom_getter(recompute_getter, state=True), \
                 base.custom_setter(skip_setter):
                self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
                self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
                base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
示例#6
0
    def test_setter_tree(self):
        """A custom setter receives the whole pytree and can replace it.

        The setter checks the identity of the incoming value and the
        per-leaf shape/dtype metadata in its context, then substitutes ``y``.
        Reading the state back must return that substituted tree.
        """
        witness = []
        x = {"a": jnp.ones([]), "b": jnp.zeros([123])}
        # `jax.tree_map` is a deprecated top-level alias; use jax.tree_util.
        y = jax.tree_util.tree_map(lambda leaf: leaf + 1, x)

        def my_setter(next_setter, value, ctx):
            self.assertIs(value, x)
            self.assertEqual(ctx.original_shape, {"a": (), "b": (123, )})
            self.assertEqual(ctx.original_dtype, {
                "a": jnp.float32,
                "b": jnp.float32
            })
            self.assertEqual(ctx.full_name, "~/x")
            self.assertEqual(ctx.name, "x")
            self.assertIsNone(ctx.module)
            witness.append(None)
            del next_setter
            return y

        with base.new_context():
            with base.custom_setter(my_setter):
                base.set_state("x", x)
                # Use a new name: rebinding `x` would silently change what the
                # `my_setter` closure compares against on any later call.
                stored = base.get_state("x")
                self.assertIs(stored, y)

        self.assertNotEmpty(witness)
示例#7
0
    def test_unable_to_mutate_name(self):
        """A creator that renames the value it creates must raise ValueError."""
        def mutates_name(next_creator, name, shape, dtype, init):
            # Return the chained result so the creator is well-formed. If the
            # rename were ever accepted, the test then fails with a clear "not
            # raised" error instead of a confusing None-valued parameter.
            return next_creator(name + "_foo", shape, dtype, init)

        with base.new_context(), base.custom_creator(mutates_name):
            with self.assertRaisesRegex(ValueError,
                                        "Modifying .*name.* not supported"):
                base.get_parameter("w", [], init=jnp.ones)
示例#8
0
    def test_stateful(self):
        """Repeated get/set updates state while the initial state stays 0."""
        num_steps = 10
        with base.new_context() as ctx:
            for _ in range(num_steps):
                current = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", current + 1)

        self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 0}})
        self.assertEqual(ctx.collect_state(), {"~": {"count": num_steps}})
示例#9
0
 def test_context_copies_input(self):
     """``new_context`` must neither alias nor mutate the dicts it is given.

     The same dict seeds both params and state, and the state is then
     updated. Afterwards the params, the initial state, and the caller's
     dict must all still hold 1.
     """
     def scoped(value):
         return {"~": {"w": jnp.array(value)}}

     seed = scoped(1.)
     with base.new_context(params=seed, state=seed) as ctx:
         base.get_parameter("w", [], init=jnp.ones)
         base.set_state("w", jnp.array(2.))
     self.assertEqual(ctx.collect_params(), scoped(1.))
     self.assertIsNot(ctx.collect_initial_state(), seed)
     self.assertEqual(ctx.collect_initial_state(), seed)
     self.assertEqual(ctx.collect_state(), scoped(2.))
     # The set_state above must not leak back into the caller's dict.
     self.assertEqual(seed, scoped(1.))
示例#10
0
 def init_fn(
     rng: Optional[Union[PRNGKey, PRNGSeed]],
     *args,
     **kwargs,
 ) -> Tuple[Params, State]:
     """Runs ``f`` once and returns the parameters and state it created.

     Args:
       rng: key or seed, normalised by ``to_prng_sequence`` with
         ``err_msg=INIT_RNG_ERROR``.
       *args: forwarded to ``f``.
       **kwargs: forwarded to ``f``.

     Returns:
       ``(params, initial_state)`` collected from the context. The output of
       ``f`` is discarded.
     """
     rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
     with base.new_context(rng=rng_sequence) as ctx:
         f(*args, **kwargs)
     params = ctx.collect_params()
     initial_state = ctx.collect_initial_state()
     return params, initial_state
示例#11
0
 def init(
     self,
     rng: tp.Optional[tp.Union[jnp.ndarray, int]],
     *args,
     **kwargs,
 ) -> tp.Tuple[tp.Any, haiku.Params, haiku.State]:
     """Runs ``self.f`` once, returning its output plus what it created.

     Args:
       rng: key or integer seed, normalised by ``to_prng_sequence`` with
         ``err_msg=INIT_RNG_ERROR``.
       *args: forwarded to ``self.f``.
       **kwargs: forwarded to ``self.f``.

     Returns:
       ``(output, params, initial_state)``, where ``output`` is whatever
       ``self.f`` returned.
     """
     rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
     with new_context(rng=rng_sequence) as ctx:
         output = self.f(*args, **kwargs)
     params = ctx.collect_params()
     initial_state = ctx.collect_initial_state()
     return output, params, initial_state
示例#12
0
    def test_assert_no_new_parameters(self):
        """``assert_no_new_parameters`` only trips when a parameter is new."""
        with base.new_context():
            base.get_parameter("w", [], init=jnp.zeros)

            # Re-reading the already-created "w" inside the guard is fine.
            with base.assert_no_new_parameters():
                base.get_parameter("w", [], init=jnp.zeros)

            # Creating "x" inside the guard must raise.
            expected_error = self.assertRaisesRegex(
                AssertionError, "New parameters were created: .*x")
            with expected_error, base.assert_no_new_parameters():
                base.get_parameter("x", [], init=jnp.zeros)
示例#13
0
 def apply_fn(
     params: Params,
     state: State,
     rng: Optional[Union[PRNGKey, PRNGSeed]],
     *args,
     **kwargs,
 ) -> Tuple[Any, State]:
     """Runs ``f`` with ``params`` and ``state`` injected into a new context.

     Args:
       params: parameters made available to ``f``.
       state: state made available to ``f``. Its truthiness also selects the
         ``err_msg`` passed to ``to_prng_sequence``: ``APPLY_RNG_STATE_ERROR``
         when truthy, ``APPLY_RNG_ERROR`` otherwise.
       rng: key or seed, normalised by ``to_prng_sequence``.
       *args: forwarded to ``f``.
       **kwargs: forwarded to ``f``.

     Returns:
       ``(output of f, ctx.collect_state())``.
     """
     if state:
         rng_error = APPLY_RNG_STATE_ERROR
     else:
         rng_error = APPLY_RNG_ERROR
     rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
     with base.new_context(params=params, state=state,
                           rng=rng_sequence) as ctx:
         output = f(*args, **kwargs)
     return output, ctx.collect_state()
示例#14
0
  def test_init_custom_creator(self, get_x, custom_x, collect_x):
    """A creator can swap the initializer used for a newly created value."""
    def zeros_creator(next_creator, shape, dtype, init, context):
      # Verify everything the creator is handed before substituting zeros.
      checks = ((context.full_name, "~/w"), (shape, []),
                (dtype, jnp.float32), (init, jnp.ones))
      for actual, wanted in checks:
        self.assertEqual(actual, wanted)
      return next_creator(shape, dtype, jnp.zeros)

    with base.new_context() as ctx, custom_x(zeros_creator):
      get_x("w", [], init=jnp.ones)

    collected = getattr(ctx, collect_x)()
    self.assertEqual(collected, {"~": {"w": jnp.zeros([])}})
示例#15
0
    def test_set_then_get(self):
        """Setting state before reading it fixes the initial state.

        After the first run the initial state is ``{"~": {"i": 1}}``.
        Re-entering the same context ten times and setting the same value
        must leave it unchanged.
        """
        expected_initial = {"~": {"i": 1}}

        with base.new_context() as ctx:
            base.set_state("i", 1)
            base.get_state("i")
        self.assertEqual(ctx.collect_initial_state(), expected_initial)

        for _ in range(10):
            with ctx:
                base.set_state("i", 1)
                self.assertEqual(base.get_state("i"), 1)
            self.assertEqual(ctx.collect_initial_state(), expected_initial)
示例#16
0
    def test_init_custom_creator(self):
        """A creator that receives ``name`` can substitute a zeros init."""
        def zeros_creator(next_creator, name, shape, dtype, init):
            expectations = [(name, "~/w"), (shape, []), (dtype, jnp.float32),
                            (init, jnp.ones)]
            for got, want in expectations:
                self.assertEqual(got, want)
            return next_creator(name, shape, dtype, jnp.zeros)

        with base.new_context() as ctx, base.custom_creator(zeros_creator):
            base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.zeros([])}})
示例#17
0
 def init_fn(
     rng: Optional[Union[PRNGKey, int]],
     *args,
     **kwargs,
 ) -> Tuple[hk.Params, hk.State]:
     """Runs ``f`` once and returns the parameters and state it created.

     Args:
       rng: key or integer seed, normalised by ``to_prng_sequence`` with
         ``err_msg=INIT_RNG_ERROR``.
       *args: forwarded to ``f``.
       **kwargs: forwarded to ``f``.

     Returns:
       ``(params, initial_state)`` collected from the context.

     Raises:
       jax.errors.UnexpectedTracerError: re-raised with
         ``unexpected_tracer_hint`` as its message, chained to the original.
     """
     rng_sequence = to_prng_sequence(rng, err_msg=INIT_RNG_ERROR)
     with base.new_context(rng=rng_sequence) as ctx:
         try:
             f(*args, **kwargs)
         except jax.errors.UnexpectedTracerError as err:
             # Explicit chaining keeps the original error in the traceback.
             hinted = jax.errors.UnexpectedTracerError(unexpected_tracer_hint)
             raise hinted from err
     return ctx.collect_params(), ctx.collect_initial_state()
示例#18
0
 def apply_fn(
     params: Optional[hk.Params],
     state: Optional[hk.State],
     rng: Optional[Union[PRNGKey, int]],
     *args,
     **kwargs,
 ) -> Tuple[Any, hk.State]:
     """Runs ``f`` with ``params`` and ``state`` injected into a new context.

     Args:
       params: passed through ``check_mapping("params", ...)`` before use.
       state: passed through ``check_mapping("state", ...)`` before use. The
         truthiness of the checked value selects the ``err_msg`` given to
         ``to_prng_sequence``: ``APPLY_RNG_STATE_ERROR`` when truthy,
         ``APPLY_RNG_ERROR`` otherwise.
       rng: key or integer seed, normalised by ``to_prng_sequence``.
       *args: forwarded to ``f``.
       **kwargs: forwarded to ``f``.

     Returns:
       ``(output of f, ctx.collect_state())``.
     """
     checked_params = check_mapping("params", params)
     checked_state = check_mapping("state", state)
     rng_error = APPLY_RNG_STATE_ERROR if checked_state else APPLY_RNG_ERROR
     rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
     with base.new_context(params=checked_params, state=checked_state,
                           rng=rng_sequence) as ctx:
         output = f(*args, **kwargs)
     return output, ctx.collect_state()
示例#19
0
    def test_original_shape(self, get_x, custom_x):
        """An inner creator can restore the shape the caller requested.

        ``new_shape_creator`` rewrites the shape to ``(1, 2, 3)``.
        ``original_shape_restorer`` checks that it receives that rewritten
        shape, then forwards ``context.original_shape`` instead. The created
        value therefore has the requested shape ``(5,)``.
        """
        def new_shape_creator(next_creator, shape, dtype, init, context):
            del shape
            del context
            new_shape = (1, 2, 3)
            return next_creator(new_shape, dtype, init)

        def original_shape_restorer(next_creator, shape, dtype, init, context):
            # Use unittest assertions instead of bare `assert`: they survive
            # `python -O` and report both values on failure.
            self.assertEqual(shape, (1, 2, 3))
            return next_creator(context.original_shape, dtype, init)

        with base.new_context():
            with custom_x(new_shape_creator):
                with custom_x(original_shape_restorer):
                    value = get_x("w", [5], jnp.bfloat16, jnp.ones)
                    self.assertEqual(value.shape, (5,))
示例#20
0
  def test_getter_types(self, params, state):
    """A getter only sees the collections it was registered for."""
    seen_names = []

    def logging_getter(next_getter, value, context):
      seen_names.append(context.full_name)
      return next_getter(value)

    with base.new_context(), \
         base.custom_getter(logging_getter, params=params, state=state):
      base.get_parameter("params", [], init=jnp.zeros)
      base.get_state("state", [], init=jnp.zeros)

    self.assertLen(seen_names, sum(map(int, (params, state))))
    for enabled, full_name in ((params, "~/params"), (state, "~/state")):
      if enabled:
        self.assertIn(full_name, seen_names)
示例#21
0
    def test_custom_getter_bf16(self, get_x, custom_x, collect_x):
        """A getter casts float32 reads to bfloat16 but not stored values.

        The stored float32 value keeps its dtype. The int32 value passes
        through the getter unchanged.
        """
        def bf16_getter(next_getter, value, context):
            del context
            is_f32 = value.dtype == jnp.float32
            return next_getter(value.astype(jnp.bfloat16) if is_f32 else value)

        with base.new_context() as ctx, custom_x(bf16_getter):
            f = get_x("f", [], jnp.float32, init=jnp.ones)
            i = get_x("i", [], jnp.int32, init=jnp.ones)

        stored = getattr(ctx, collect_x)()["~"]
        self.assertEqual(stored["f"].dtype, jnp.float32)
        self.assertEqual(f.dtype, jnp.bfloat16)
        self.assertEqual(stored["i"].dtype, jnp.int32)
        self.assertEqual(i.dtype, jnp.int32)
示例#22
0
    def test_nested_creators(self):
        """Creators entered in the order a, b, c are invoked in that order."""
        calls = []

        def make_creator(label):
            def creator(next_creator, name, shape, dtype, init):
                calls.append(label)
                return next_creator(name, shape, dtype, init)
            return creator

        first, second, third = (make_creator(label) for label in "abc")
        with base.new_context():
            with base.custom_creator(first), base.custom_creator(second), \
                 base.custom_creator(third):
                base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(calls, list("abc"))
示例#23
0
    def test_original_shape(self):
        """An inner creator can restore the shape the caller requested.

        ``new_shape_creator`` rewrites the shape to ``(1, 2, 3)``.
        ``original_shape_restorer`` checks that it receives that rewritten
        shape, then forwards ``context.original_shape`` instead. The
        parameter therefore ends up with the requested shape ``(5,)``.
        """
        def new_shape_creator(next_creator, shape, dtype, init, context):
            del shape
            del context
            new_shape = (1, 2, 3)
            return next_creator(new_shape, dtype, init)

        def original_shape_restorer(next_creator, shape, dtype, init, context):
            # Use unittest assertions instead of bare `assert`: they survive
            # `python -O` and report both values on failure.
            self.assertEqual(shape, (1, 2, 3))
            return next_creator(context.original_shape, dtype, init)

        with base.new_context():
            with base.custom_creator(new_shape_creator):
                with base.custom_creator(original_shape_restorer):
                    param = base.get_parameter("w", [5], jnp.bfloat16,
                                               jnp.ones)
                    self.assertEqual(param.shape, (5,))
示例#24
0
 def apply(
     self,
     params: tp.Optional[haiku.Params],
     state: tp.Optional[haiku.State],
     rng: tp.Optional[tp.Union[jnp.ndarray, int]],
     *args,
     **kwargs,
 ) -> tp.Tuple[tp.Any, haiku.State]:
     """Runs ``self.f`` with ``params`` and ``state`` injected.

     Args:
       params: passed through ``check_mapping("params", ...)`` before use.
       state: passed through ``check_mapping("state", ...)`` before use. The
         truthiness of the checked value selects the ``err_msg`` given to
         ``to_prng_sequence``: ``APPLY_RNG_STATE_ERROR`` when truthy,
         ``APPLY_RNG_ERROR`` otherwise.
       rng: key or integer seed, normalised by ``to_prng_sequence``.
       *args: forwarded to ``self.f``.
       **kwargs: forwarded to ``self.f``.

     Returns:
       ``(output of self.f, ctx.collect_state())``.
     """
     checked_params = check_mapping("params", params)
     checked_state = check_mapping("state", state)
     if checked_state:
         rng_error = APPLY_RNG_STATE_ERROR
     else:
         rng_error = APPLY_RNG_ERROR
     rng_sequence = to_prng_sequence(rng, err_msg=rng_error)
     with new_context(params=checked_params, state=checked_state,
                      rng=rng_sequence) as ctx:
         output = self.f(*args, **kwargs)
     return output, ctx.collect_state()
示例#25
0
  def test_nested_creators(self, get_x, custom_x):
    """Creators entered in the order a, b, c are invoked in that order."""
    calls = []

    def make_creator(label):
      def creator(next_creator, shape, dtype, init, context):
        del context
        calls.append(label)
        return next_creator(shape, dtype, init)
      return creator

    outer, middle, inner = (make_creator(label) for label in "abc")
    with base.new_context(), custom_x(outer), custom_x(middle), \
         custom_x(inner):
      get_x("w", [], init=jnp.ones)

    self.assertEqual(calls, ["a", "b", "c"])
示例#26
0
    def test_custom_getter_bf16(self):
        """A getter casts float32 params to bfloat16 on read only.

        The stored float32 param keeps its dtype. The int32 param passes
        through the getter unchanged.
        """
        def bf16_getter(next_getter, value, context):
            del context
            if value.dtype != jnp.float32:
                return next_getter(value)
            return next_getter(value.astype(jnp.bfloat16))

        with base.new_context() as ctx, base.custom_getter(bf16_getter):
            f = base.get_parameter("f", [], jnp.float32, init=jnp.ones)
            i = base.get_parameter("i", [], jnp.int32, init=jnp.ones)

        stored = ctx.collect_params()["~"]
        self.assertEqual(stored["f"].dtype, jnp.float32)
        self.assertEqual(f.dtype, jnp.bfloat16)
        self.assertEqual(stored["i"].dtype, jnp.int32)
        self.assertEqual(i.dtype, jnp.int32)
示例#27
0
 def test_dataclass(self, name):
   """``DataMLP`` names its ``linear`` children under its own module name.

   With ``name=None`` the prefix is ``data_mlp``. Otherwise the module is
   built with ``name="mlp"`` and the prefix is ``mlp``. Every child's ``w``
   and ``b`` must have a trailing dimension equal to its output size.
   """
   with base.new_context() as ctx:
     output_sizes = [300, 100, 10]
     if name is None:
       mlp, prefix = DataMLP(output_sizes), "data_mlp"
     else:
       mlp, prefix = DataMLP(output_sizes, name="mlp"), "mlp"
     mlp(jnp.ones([1, 28 * 28]))
     params = ctx.collect_params()
     module_names = [f"{prefix}/{suffix}"
                     for suffix in ("linear", "linear_1", "linear_2")]
     self.assertEqual(list(params.keys()), module_names)
     for module_name, output_size in zip(module_names, output_sizes):
       module_params = params[module_name]
       for param_name in ("w", "b"):
         self.assertEqual(module_params[param_name].shape[-1], output_size)
示例#28
0
  def test_nested_getters(self, get_x, custom_x):
    """Stacked getters run in entry order, each seeing the previous cast.

    The dtype chain is float32 -> bfloat16 -> int32 -> int8.
    """
    log = []

    def make_getter(label, dtype_in, dtype_out):
      def getter(next_getter, value, context):
        del context
        log.append(label)
        self.assertEqual(value.dtype, dtype_in)
        return next_getter(value.astype(dtype_out))
      return getter

    specs = [("a", jnp.float32, jnp.bfloat16),
             ("b", jnp.bfloat16, jnp.int32),
             ("c", jnp.int32, jnp.int8)]
    getter_a, getter_b, getter_c = (make_getter(*spec) for spec in specs)

    with base.new_context(), custom_x(getter_a), custom_x(getter_b), \
         custom_x(getter_c):
      w = get_x("w", [], init=jnp.ones)

    self.assertEqual(w.dtype, jnp.int8)
    self.assertEqual(log, ["a", "b", "c"])
示例#29
0
    def test_original_dtype(self):
        """A creator/getter pair can store bf16 params as f32 but read bf16.

        The creator upgrades bfloat16 requests to float32 storage. The getter
        checks that the stored value is float32 and casts it back to the
        originally requested bfloat16.
        """
        def dtype_cast_creator(next_creator, shape, dtype, init, context):
            if context.original_dtype == jnp.bfloat16:
                dtype = jnp.float32
            return next_creator(shape, dtype, init)

        def dtype_recast_getter(next_getter, value, context):
            if context.original_dtype == jnp.bfloat16:
                # unittest assertion: survives `python -O`, unlike `assert`.
                self.assertEqual(value.dtype, jnp.float32)
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with base.custom_creator(dtype_cast_creator), \
                 base.custom_getter(dtype_recast_getter):
                param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)
                # `jax.tree_leaves` is a deprecated top-level alias.
                orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]

                self.assertEqual(param.dtype, jnp.bfloat16)
                self.assertEqual(orig_param.dtype, jnp.float32)
示例#30
0
    def test_original_dtype(self, get_x, custom_create_x, custom_get_x,
                            collect_x):
        """A creator/getter pair can store bf16 values as f32 but read bf16.

        The creator upgrades bfloat16 requests to float32 storage. The getter
        checks that the stored value is float32 and casts it back to the
        originally requested bfloat16.
        """
        def dtype_cast_creator(next_creator, shape, dtype, init, context):
            if context.original_dtype == jnp.bfloat16:
                dtype = jnp.float32
            return next_creator(shape, dtype, init)

        def dtype_recast_getter(next_getter, value, context):
            if context.original_dtype == jnp.bfloat16:
                # unittest assertion: survives `python -O`, unlike `assert`.
                self.assertEqual(value.dtype, jnp.float32)
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with custom_create_x(dtype_cast_creator), \
                 custom_get_x(dtype_recast_getter):
                value = get_x("w", [], jnp.bfloat16, jnp.ones)
                # `jax.tree_leaves` is a deprecated top-level alias.
                stored = getattr(ctx, collect_x)()
                orig_value = jax.tree_util.tree_leaves(stored)[0]

                self.assertEqual(value.dtype, jnp.bfloat16)
                self.assertEqual(orig_value.dtype, jnp.float32)