Exemplo n.º 1
0
    def test_do_not_store(self):
        """Values returned as ``DO_NOT_STORE`` never land in the context.

        The custom creator and setter both return ``base.DO_NOT_STORE``; the
        custom getter therefore receives that sentinel and builds the value
        itself from ``context.original_init`` / ``original_shape`` /
        ``original_dtype``. After ``new_context`` exits, neither params nor
        state should have been collected.
        """
        def my_creator(next_creator, shape, dtype, init, context):
            del next_creator, shape, dtype, init, context  # Unused.
            # Skip creation so nothing is written to the context.
            return base.DO_NOT_STORE

        def my_getter(next_getter, value, context):
            # `self.assertIs` (unlike a bare `assert`) is not stripped under
            # `python -O`, so this check always runs.
            self.assertIs(value, base.DO_NOT_STORE)
            # Nothing was stored, so materialise the value on demand.
            return next_getter(
                context.original_init(context.original_shape,
                                      context.original_dtype))

        def my_setter(next_setter, value, context):
            del next_setter, value, context  # Unused.
            # Drop the write so `set_state` leaves nothing in the context.
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(my_creator, state=True), \
                 base.custom_getter(my_getter, state=True), \
                 base.custom_setter(my_setter):
                self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
                self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
                base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
Exemplo n.º 2
0
    def test_getter_requires_context(self):
        """Using ``custom_getter`` without an enclosing context raises ValueError."""
        def passthrough_getter(next_getter, value, context):
            del context  # Unused.
            return next_getter(value)

        expected_message = "must be used as part of an `hk.transform`"
        with self.assertRaisesRegex(ValueError, expected_message):
            with base.custom_getter(passthrough_getter):
                pass
Exemplo n.º 3
0
    def test_nested_getters(self):
        """Nested custom getters run outermost-first and chain their casts."""
        call_order = []

        def make_getter(tag, expected_dtype, target_dtype):
            # Builds a getter that records `tag`, checks the incoming dtype and
            # hands a value cast to `target_dtype` to the next getter.
            def getter(next_getter, value, context):
                del context  # Unused.
                call_order.append(tag)
                self.assertEqual(value.dtype, expected_dtype)
                return next_getter(value.astype(target_dtype))
            return getter

        with base.new_context():
            with base.custom_getter(make_getter("a", jnp.float32, jnp.bfloat16)):
                with base.custom_getter(make_getter("b", jnp.bfloat16, jnp.int32)):
                    with base.custom_getter(make_getter("c", jnp.int32, jnp.int8)):
                        w = base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(w.dtype, jnp.int8)
        self.assertEqual(call_order, ["a", "b", "c"])
Exemplo n.º 4
0
  def test_getter_types(self, params, state):
    """A custom getter only intercepts the collections it is enabled for."""
    # NOTE(review): `params` / `state` presumably come from a parameterized
    # decorator not visible here and are booleans — confirm.
    seen_names = []

    def recording_getter(next_getter, value, context):
      seen_names.append(context.full_name)
      return next_getter(value)

    with base.new_context():
      with base.custom_getter(recording_getter, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    expected_calls = int(params) + int(state)
    self.assertLen(seen_names, expected_calls)
    if params:
      self.assertIn("~/params", seen_names)
    if state:
      self.assertIn("~/state", seen_names)
Exemplo n.º 5
0
    def test_custom_getter_bf16(self):
        """A getter can hand out bfloat16 views of float32 params.

        Stored params keep their created dtype; only the value returned to the
        caller is cast, and non-float32 params pass through untouched.
        """
        def to_bf16_getter(next_getter, value, context):
            del context  # Unused.
            is_f32 = value.dtype == jnp.float32
            return next_getter(value.astype(jnp.bfloat16) if is_f32 else value)

        with base.new_context() as ctx:
            with base.custom_getter(to_bf16_getter):
                float_param = base.get_parameter(
                    "f", [], jnp.float32, init=jnp.ones)
                int_param = base.get_parameter(
                    "i", [], jnp.int32, init=jnp.ones)

        stored = ctx.collect_params()["~"]
        self.assertEqual(stored["f"].dtype, jnp.float32)
        self.assertEqual(float_param.dtype, jnp.bfloat16)
        self.assertEqual(stored["i"].dtype, jnp.int32)
        self.assertEqual(int_param.dtype, jnp.int32)
Exemplo n.º 6
0
    def test_original_dtype(self):
        """Creators and getters can see the dtype the caller asked for.

        The creator stores a bfloat16 request as float32, and the getter casts
        it back to bfloat16. The caller sees bfloat16 while the context holds
        float32.
        """
        def dtype_cast_creator(next_creator, shape, dtype, init, context):
            # Store bfloat16 requests in float32.
            if context.original_dtype == jnp.bfloat16:
                dtype = jnp.float32
            return next_creator(shape, dtype, init)

        def dtype_recast_getter(next_getter, value, context):
            if context.original_dtype == jnp.bfloat16:
                # The creator above stored this value as float32. Use a
                # unittest assertion so the check survives `python -O`.
                self.assertEqual(value.dtype, jnp.float32)
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with base.custom_creator(dtype_cast_creator), \
                 base.custom_getter(dtype_recast_getter):
                param = base.get_parameter("w", [], jnp.bfloat16, init=jnp.ones)
                # `jax.tree_leaves` is deprecated/removed in recent JAX; the
                # `jax.tree_util` spelling works across versions.
                orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]

        self.assertEqual(param.dtype, jnp.bfloat16)
        self.assertEqual(orig_param.dtype, jnp.float32)
Exemplo n.º 7
0
 def forward_eval_bfloat16(x):
     """Run the batch-norm module in eval mode under `_bfloat16_getter`.

     The getter is applied to state too (`state=True`). What
     `_bfloat16_getter` does is defined elsewhere — presumably a bfloat16
     cast; confirm.
     """
     with base.custom_getter(_bfloat16_getter, state=True):
         # Build and call the module inside the getter scope so every
         # parameter/state read goes through the custom getter.
         batch_norm = get_batch_norm()
         outputs = batch_norm(x, is_training=False)
     return outputs