Ejemplo n.º 1
0
    def test_do_not_store(self):
        """Creators/setters that return DO_NOT_STORE leave the context empty.

        The creator returns ``base.DO_NOT_STORE`` instead of creating a value.
        The getter rebuilds a value on demand from the original init, shape
        and dtype. The setter drops writes without calling ``next_setter``.
        Afterwards neither params nor state should have been recorded.
        """
        def my_creator(next_creator, shape, dtype, init, context):
            # Skip creation entirely; nothing should be stored.
            del next_creator, shape, dtype, init, context
            return base.DO_NOT_STORE

        def my_getter(next_getter, value, context):
            # assertIs instead of a bare assert so the check is not stripped
            # under `python -O` and failures carry a useful message.
            self.assertIs(value, base.DO_NOT_STORE)
            # Materialise the value from the pre-creator init/shape/dtype.
            return next_getter(
                context.original_init(context.original_shape,
                                      context.original_dtype))

        def my_setter(next_setter, value, context):
            # Drop the write: next_setter is never called.
            del next_setter, value, context
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(my_creator, state=True), \
                 base.custom_getter(my_getter, state=True), \
                 base.custom_setter(my_setter):
                self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
                self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
                base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
Ejemplo n.º 2
0
    def test_nested_creators(self):
        """Nested custom creators are invoked outermost-first."""
        calls = []

        def logging_creator(log_msg):
            def _logging_creator(next_creator, name, shape, dtype, init):
                # Record which creator ran, then forward unchanged.
                calls.append(log_msg)
                return next_creator(name, shape, dtype, init)
            return _logging_creator

        with base.new_context():
            with base.custom_creator(logging_creator("a")):
                with base.custom_creator(logging_creator("b")):
                    with base.custom_creator(logging_creator("c")):
                        base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(calls, ["a", "b", "c"])
Ejemplo n.º 3
0
    def test_unable_to_mutate_name(self):
        """A creator that renames the parameter is rejected with ValueError."""
        def mutates_name(next_creator, name, shape, dtype, init):
            # Return the downstream result so this is a well-formed creator;
            # the changed name is what the framework is expected to reject.
            return next_creator(name + "_foo", shape, dtype, init)

        with base.new_context(), base.custom_creator(mutates_name):
            with self.assertRaisesRegex(ValueError,
                                        "Modifying .*name.* not supported"):
                base.get_parameter("w", [], init=jnp.ones)
Ejemplo n.º 4
0
    def test_original_shape(self):
        """`context.original_shape` survives an outer creator changing shape."""
        def new_shape_creator(next_creator, shape, dtype, init, context):
            # Outer creator: discard the requested shape and force (1, 2, 3).
            del shape, context
            new_shape = (1, 2, 3)
            return next_creator(new_shape, dtype, init)

        def original_shape_restorer(next_creator, shape, dtype, init, context):
            # Inner creator: receives the outer creator's shape, then restores
            # the shape originally passed to get_parameter.
            self.assertEqual(shape, (1, 2, 3))
            return next_creator(context.original_shape, dtype, init)

        with base.new_context():
            with base.custom_creator(new_shape_creator):
                with base.custom_creator(original_shape_restorer):
                    param = base.get_parameter("w", [5], jnp.bfloat16,
                                               jnp.ones)
                    # unittest assertion instead of bare assert (kept under -O).
                    self.assertEqual(param.shape, (5,))
Ejemplo n.º 5
0
  def test_nested_creators(self):
    """Creators stacked around init_fn are invoked outermost-first."""
    calls = []

    def logging_creator(log_msg):
      def _logging_creator(next_creator, name, shape, dtype, init):
        # Record the creator's tag, then forward everything unchanged.
        calls.append(log_msg)
        return next_creator(name, shape, dtype, init)
      return _logging_creator

    init_fn, _unused_apply_fn = base.transform(
        lambda: base.get_parameter("w", [], init=jnp.ones))

    first, second, third = [logging_creator(tag) for tag in ("a", "b", "c")]
    with base.custom_creator(first):
      with base.custom_creator(second):
        with base.custom_creator(third):
          init_fn(None)

    self.assertEqual(calls, ["a", "b", "c"])
Ejemplo n.º 6
0
    def test_creator_requires_context(self):
        """Entering custom_creator with no active transform raises ValueError."""
        def my_creator(next_creator, shape, dtype, init, context):
            del context  # Unused: this creator only forwards its arguments.
            return next_creator(shape, dtype, init)

        expected_message = "must be used as part of an `hk.transform`"
        with self.assertRaisesRegex(ValueError, expected_message):
            with base.custom_creator(my_creator):
                pass
Ejemplo n.º 7
0
  def test_unable_to_mutate_name(self):
    """A creator that renames the parameter is rejected with ValueError."""
    def mutates_name(next_creator, name, shape, dtype, init):
      # Return the downstream result so this is a well-formed creator; the
      # changed name is what the framework is expected to reject.
      return next_creator(name + "_foo", shape, dtype, init)

    init_fn, _ = base.transform(
        lambda: base.get_parameter("w", [], init=jnp.ones))

    with self.assertRaisesRegex(ValueError, "Modifying .*name.* not supported"):
      with base.custom_creator(mutates_name):
        init_fn(None)
Ejemplo n.º 8
0
    def test_init_custom_creator(self):
        """A creator that swaps the initializer changes the stored value."""
        def zeros_creator(next_creator, name, shape, dtype, init):
            # Check the arguments produced by get_parameter("w", [], init=ones).
            checks = ((name, "~/w"), (shape, []),
                      (dtype, jnp.float32), (init, jnp.ones))
            for got, want in checks:
                self.assertEqual(got, want)
            # Forward everything unchanged except the initializer.
            return next_creator(name, shape, dtype, jnp.zeros)

        with base.new_context() as ctx:
            with base.custom_creator(zeros_creator):
                base.get_parameter("w", [], init=jnp.ones)

        collected = ctx.collect_params()
        self.assertEqual(collected, {"~": {"w": jnp.zeros([])}})
Ejemplo n.º 9
0
    def test_init_custom_creator(self):
        """Under init_fn, a creator swapping ones for zeros changes params."""
        def zeros_creator(next_creator, name, shape, dtype, init):
            expectations = [(name, "~/w"), (shape, []),
                            (dtype, jnp.float32), (init, jnp.ones)]
            for actual, expected in expectations:
                self.assertEqual(actual, expected)
            # Replace only the initializer; everything else passes through.
            return next_creator(name, shape, dtype, jnp.zeros)

        init_fn, _unused_apply_fn = base.transform(
            lambda: base.get_parameter("w", [], init=jnp.ones))

        with base.custom_creator(zeros_creator):
            params = init_fn(None)

        expected_params = {"~": {"w": jnp.zeros([])}}
        self.assertEqual(params, expected_params)
Ejemplo n.º 10
0
  def test_creator_types(self, params, state):
    """custom_creator intercepts only the kinds enabled by params/state."""
    seen = []

    def logging_creator(next_creator, shape, dtype, init, context):
      # Record the fully-qualified name of everything this creator sees.
      seen.append(context.full_name)
      return next_creator(shape, dtype, init)

    with base.new_context():
      with base.custom_creator(logging_creator, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    self.assertLen(seen, sum(map(int, (params, state))))
    for enabled, full_name in ((params, "~/params"), (state, "~/state")):
      if enabled:
        self.assertIn(full_name, seen)
Ejemplo n.º 11
0
    def test_original_dtype(self):
        """A creator/getter pair stores bfloat16 params as float32, casts back."""
        def dtype_cast_creator(next_creator, shape, dtype, init, context):
            # Store parameters requested as bfloat16 in float32.
            if context.original_dtype == jnp.bfloat16:
                dtype = jnp.float32
            return next_creator(shape, dtype, init)

        def dtype_recast_getter(next_getter, value, context):
            # Cast the stored float32 value back to the requested bfloat16.
            if context.original_dtype == jnp.bfloat16:
                self.assertEqual(value.dtype, jnp.float32)
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with base.custom_creator(dtype_cast_creator), \
                 base.custom_getter(dtype_recast_getter):
                param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)
                # The top-level `jax.tree_leaves` alias is deprecated and
                # removed in recent JAX releases; use jax.tree_util instead.
                orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]

                # The getter hands back bfloat16; storage stays float32.
                self.assertEqual(param.dtype, jnp.bfloat16)
                self.assertEqual(orig_param.dtype, jnp.float32)
Ejemplo n.º 12
0
 def net():
     # Apply MultipleForwardMethods with counting_creator active.
     with base.custom_creator(counting_creator):
         module = MultipleForwardMethods()
         result = module()
     return result
Ejemplo n.º 13
0
 def f():
     # Build one logging creator per tag and nest them in a, b, c order.
     first, second, third = [logging_creator(tag) for tag in ("a", "b", "c")]
     with base.custom_creator(first):
         with base.custom_creator(second):
             with base.custom_creator(third):
                 return base.get_parameter("w", [], init=jnp.ones)
Ejemplo n.º 14
0
 def net():
     # Create zero-initialised parameters w0..w3 with counting_creator active.
     param_names = ["w{}".format(idx) for idx in range(4)]
     with base.custom_creator(counting_creator):
         for param_name in param_names:
             base.get_parameter(param_name, [], init=jnp.zeros)
Ejemplo n.º 15
0
 def f():
     # Request "w" (init=ones) while zeros_creator is active.
     with base.custom_creator(zeros_creator):
         weight = base.get_parameter("w", [], init=jnp.ones)
     return weight