コード例 #1
0
ファイル: base_test.py プロジェクト: ibab/haiku
    def test_parameter_in_apply(self, params):
        """Expects `apply` to raise ValueError when the function requests "w"."""
        def make_param():
            return base.get_parameter("w", [], init=jnp.zeros)

        _, apply_fn = base.transform(make_param)

        expected_msg = "parameters must be created as part of `init`"
        with self.assertRaisesRegex(ValueError, expected_msg):
            apply_fn(params)
コード例 #2
0
    def test_ema_on_changing_data(self):
        """Checks that the EMA output lags behind a change in the parameters.

        A Linear layer's parameters are fed once through `EMAParamsTree(0.2)`,
        then doubled and fed through again. The second EMA output must have
        the same tree structure as the doubled parameters but must equal
        neither the doubled parameters nor the values from before doubling.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            return moving_averages.EMAParamsTree(0.2)(x)

        init_fn, apply_fn = base.without_apply_rng(
            base.transform_with_state(g))
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn(None, params_state, params)
        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn(None, params_state, changed_params)

        tree.assert_same_structure(changed_params, ema_params)
        flat_ema = tree.flatten(ema_params)
        # ema_params should be different from the changed params (the average
        # has not caught up yet) and from the previous params (it has moved).
        # The original loop only checked the latter, contradicting its comment.
        for reference in (changed_params, params):
            for p1, p2 in zip(tree.flatten(reference), flat_ema):
                self.assertEqual(p1.shape, p2.shape)
                with self.assertRaisesRegex(AssertionError,
                                            "Not equal to tolerance"):
                    np.testing.assert_allclose(p1, p2, atol=1e-6)
コード例 #3
0
    def test_connect_conv_same(self, n):
        """Checks output and state shapes of an `n`-dimensional ConvNDLSTM.

        The input has 4 channels and the module is built with 3 output
        channels; the output and both state entries must have shape
        (batch, 16, ..., 16, 3).
        """
        batch_size = 2
        input_shape = [16] * n

        data_ = jnp.zeros([batch_size, *input_shape, 4])
        state_shape_ = [batch_size, *input_shape, 3]
        state_init_ = tuple(jnp.zeros(state_shape_) for _ in range(2))

        def f(data, state_init):
            lstm = recurrent.ConvNDLSTM(n,
                                        input_shape=input_shape,
                                        output_channels=3,
                                        kernel_shape=3)
            return lstm(data, state_init)

        init_fn, apply_fn = base.transform(f)
        params = init_fn(random.PRNGKey(42), data_, state_init_)
        out, state = apply_fn(params, data_, state_init_)

        expected_output_shape = (batch_size, *input_shape, 3)
        for array in (out, state[0], state[1]):
            self.assertEqual(array.shape, expected_output_shape)
コード例 #4
0
    def test_functionalize(self):
        """Lifts an inner transformed function into an outer one and vmaps it.

        Verifies the parameter dictionary produced by the outer `init` and
        the output of the outer `apply` on a (3, 2) array of zeros.
        """
        def inner_fn(x):
            assert x.ndim == 1
            return Bias()(x)

        def outer_fn(x):
            assert x.ndim == 2
            x = Bias()(x)
            inner = base.without_apply_rng(base.transform_with_state(inner_fn))
            lifted_init = lift.lift(inner.init)
            inner_p, inner_s = lifted_init(base.next_rng_key(), x[0])
            # Params and state are shared; only the data is mapped over.
            vmap_inner = jax.vmap(inner.apply, in_axes=(None, None, 0))
            mapped = vmap_inner(inner_p, inner_s, x)
            return mapped[0]

        init_key, apply_key = jax.random.split(jax.random.PRNGKey(428))
        data = np.zeros((3, 2))

        expected_params = {
            "bias": {
                "b": np.ones(())
            },
            "lifted_init_fn": {
                "sentinel": np.zeros(())
            },
            "lifted_init_fn/bias": {
                "b": np.ones(())
            },
        }

        outer = base.transform(outer_fn, apply_rng=True)
        outer_params = outer.init(init_key, data)
        self.assertEqual(outer_params, expected_params)

        out = outer.apply(outer_params, apply_key, data)
        np.testing.assert_equal(out, 2 * np.ones((3, 2)))
コード例 #5
0
    def test_used_inside_transform(self):
        """Checks a custom creator sees every parameter at init, none at apply."""
        log = []

        def counting_creator(next_creator, name, shape, dtype, init):
            # Record the name, then defer to the next creator in the chain.
            log.append(name)
            return next_creator(name, shape, dtype, init)

        def net():
            with base.custom_creator(counting_creator):
                module = MultipleForwardMethods()
                return module()

        init_fn, apply_fn = base.transform(net)

        expected_names = [
            "multiple_forward_methods/~/scalar_module/w",  # __init__
            "multiple_forward_methods/scalar_module/w",  # __call__
            "multiple_forward_methods/~encode/scalar_module/w",  # encode
            "multiple_forward_methods/~decode/scalar_module/w",  # decode
        ]
        params = init_fn(None)
        self.assertEqual(log, expected_names)

        # Reset the log in place so the closure keeps seeing the same list.
        log.clear()
        apply_fn(params)
        self.assertEmpty(log)
コード例 #6
0
    def test_convolution(self, with_bias):
        """Checks DepthwiseConv2D output with constant initializers.

        Channel c of the input is filled with c + 1, so with a 3x3 kernel and
        "VALID" padding each output channel must be a single constant:
        (c + 1) * 3 * 3, plus 1 when a bias is used.
        """
        def f():
            data = np.ones([1, 10, 10, 3])
            # Adds 0, 1, 2 to channels 0, 1, 2 respectively.
            data[0] += np.arange(3.0)
            net = depthwise_conv.DepthwiseConv2D(
                channel_multiplier=1,
                kernel_shape=3,
                stride=1,
                padding="VALID",
                with_bias=with_bias,
                data_format="channels_last",
                **create_constant_initializers(1.0, 1.0, with_bias))
            return net(jnp.array(data))

        init_fn, apply_fn = base.transform(f)
        out = apply_fn(init_fn(random.PRNGKey(428)))
        self.assertEqual(out.shape, (1, 8, 8, 3))

        unique_values = [np.unique(out[0, :, :, c]) for c in range(3)]
        for values in unique_values:
            self.assertLen(values, 1)

        bias_term = 1 if with_bias else 0
        for channel, values in enumerate(unique_values):
            expected = (channel + 1) * 3.0 * 3.0 + bias_term
            self.assertEqual(values[0], expected)
コード例 #7
0
ファイル: conv_test.py プロジェクト: shyamalschandra/haiku
  def test_computation_padding_valid(self, with_bias):
    """Checks Conv2D values on a 5x5 input of ones with "VALID" padding.

    With constant initializers each output element is 9, plus 1 when a bias
    is used.
    """
    def f():
      net = conv.Conv2D(
          output_channels=1,
          kernel_shape=3,
          stride=1,
          padding="VALID",
          with_bias=with_bias,
          **create_constant_initializers(1.0, 1.0, with_bias))
      return net(jnp.ones([1, 5, 5, 1]))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))
    out = apply_fn(params)

    self.assertEqual(out.shape, (1, 3, 3, 1))
    squeezed = np.squeeze(out, axis=(0, 3))

    expected_out = np.full((3, 3), 10.0 if with_bias else 9.0)
    np.testing.assert_equal(squeezed, expected_out)
コード例 #8
0
 def test_bias_dims_negative_out_of_order(self):
   """Checks `bias_dims=[-1, -2]` on a [1, 2, 3] input gives a (2, 3) bias."""
   expected_shape = (2, 3)

   def f():
     module = bias.Bias(bias_dims=[-1, -2])
     module(jnp.ones([1, 2, 3]))
     self.assertEqual(module.bias_shape, expected_shape)

   transformed = base.transform(f)
   params = transformed.init(None)
   self.assertEqual(params["bias"]["b"].shape, expected_shape)
コード例 #9
0
ファイル: conv_test.py プロジェクト: shyamalschandra/haiku
  def test_computation_padding_same(self, with_bias):
    """Checks Conv3D values on a 5x5x5 input of ones with "SAME" padding.

    The tabulated values are the expected output with bias; without bias
    every value is one lower.
    """
    values_with_bias = [
        9, 13, 13, 13, 9, 13, 19, 19, 19, 13, 13, 19, 19, 19, 13, 13, 19, 19,
        19, 13, 9, 13, 13, 13, 9, 13, 19, 19, 19, 13, 19, 28, 28, 28, 19, 19,
        28, 28, 28, 19, 19, 28, 28, 28, 19, 13, 19, 19, 19, 13, 13, 19, 19, 19,
        13, 19, 28, 28, 28, 19, 19, 28, 28, 28, 19, 19, 28, 28, 28, 19, 13, 19,
        19, 19, 13, 13, 19, 19, 19, 13, 19, 28, 28, 28, 19, 19, 28, 28, 28, 19,
        19, 28, 28, 28, 19, 13, 19, 19, 19, 13, 9, 13, 13, 13, 9, 13, 19, 19,
        19, 13, 13, 19, 19, 19, 13, 13, 19, 19, 19, 13, 9, 13, 13, 13, 9
    ]
    expected_out = np.asarray(values_with_bias, dtype=float).reshape((5, 5, 5))
    if not with_bias:
      expected_out = expected_out - 1

    def f():
      net = conv.Conv3D(
          output_channels=1,
          kernel_shape=3,
          stride=1,
          padding="SAME",
          with_bias=with_bias,
          **create_constant_initializers(1.0, 1.0, with_bias))
      return net(jnp.ones([1, 5, 5, 5, 1]))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))
    out = apply_fn(params)

    self.assertEqual(out.shape, (1, 5, 5, 5, 1))
    np.testing.assert_equal(np.squeeze(out, axis=(0, 4)), expected_out)
コード例 #10
0
ファイル: conv_test.py プロジェクト: shyamalschandra/haiku
 def testIncorrectN(self, n):
   """Expects ConvND construction with spatial rank `n` to raise ValueError."""
   def build():
     return conv.ConvND(n, output_channels=1, kernel_shape=3)

   init_fn, _ = base.transform(build)
   expected_msg = (
       "only support convolution operations for num_spatial_dims=1, 2 or 3")
   with self.assertRaisesRegex(ValueError, expected_msg):
     init_fn(None)
コード例 #11
0
ファイル: conv_test.py プロジェクト: shyamalschandra/haiku
  def test_computation_padding_valid(self, with_bias):
    """Checks Conv3DTranspose values on a 3x3x3 input of ones ("VALID").

    The tabulated values are the expected output without bias; with bias
    every value is one higher.
    """
    values_without_bias = [
        1, 2, 3, 2, 1, 2, 4, 6, 4, 2, 3, 6, 9, 6, 3, 2, 4, 6, 4, 2, 1, 2, 3, 2,
        1, 2, 4, 6, 4, 2, 4, 8, 12, 8, 4, 6, 12, 18, 12, 6, 4, 8, 12, 8, 4, 2,
        4, 6, 4, 2, 3, 6, 9, 6, 3, 6, 12, 18, 12, 6, 9, 18, 27, 18, 9, 6, 12,
        18, 12, 6, 3, 6, 9, 6, 3, 2, 4, 6, 4, 2, 4, 8, 12, 8, 4, 6, 12, 18, 12,
        6, 4, 8, 12, 8, 4, 2, 4, 6, 4, 2, 1, 2, 3, 2, 1, 2, 4, 6, 4, 2, 3, 6, 9,
        6, 3, 2, 4, 6, 4, 2, 1, 2, 3, 2, 1.
    ]
    expected_out = np.asarray(values_without_bias).reshape((5, 5, 5))
    if with_bias:
      expected_out = expected_out + 1

    def f():
      net = conv.Conv3DTranspose(
          output_channels=1,
          kernel_shape=3,
          stride=1,
          padding="VALID",
          with_bias=with_bias,
          **create_constant_initializers(1.0, 1.0, with_bias))
      return net(jnp.ones([1, 3, 3, 3, 1]))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))
    out = apply_fn(params)

    self.assertEqual(out.shape, (1, 5, 5, 5, 1))
    np.testing.assert_equal(np.squeeze(out, axis=(0, 4)), expected_out)
コード例 #12
0
ファイル: reshape_test.py プロジェクト: shyamalschandra/haiku
    def test_flatten(self):
        """Checks Flatten(preserve_dims=2) maps [2, 3, 4, 5] to (2, 3, 20)."""
        def f():
            flatten = reshape.Flatten(preserve_dims=2)
            return flatten(jnp.zeros([2, 3, 4, 5]))

        init_fn, apply_fn = base.transform(f)
        output = apply_fn(init_fn(None))
        self.assertEqual(output.shape, (2, 3, 20))
コード例 #13
0
 def outer_fn(x):
     """Applies `Bias`, then vmaps the lifted, transformed `inner_fn` over `x`.

     Requires `x` to be rank 2; `inner_fn` is initialised on the first row.
     """
     assert x.ndim == 2
     biased = Bias()(x)
     inner = base.transform(inner_fn, state=True)
     lifted_init = lift.lift(inner.init)
     inner_p, inner_s = lifted_init(base.next_rng_key(), biased[0])
     # Params and state are shared; only the data is mapped over.
     vmap_inner = jax.vmap(inner.apply, in_axes=(None, None, 0))
     mapped = vmap_inner(inner_p, inner_s, biased)
     return mapped[0]
コード例 #14
0
    def test_tree_update_stats(self):
        """Checks the effect of `update_stats` on `EMAParamsTree`.

        With `update_stats=False`, two consecutive applications on the same
        input must return equal values. With `update_stats=True`, two
        consecutive applications on the same input must return different
        values.
        """
        def f():
            return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))

        init_fn, _ = base.transform(f)
        params = init_fn(random.PRNGKey(428))

        def g(x):
            """This should never update internal stats."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=False)

        init_fn, apply_fn_g = base.transform(g, state=True)
        _, params_state = init_fn(None, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_g(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_g(None, params_state,
                                               changed_params)

        # ema_params should be the same as ema_params2 with update_stats=False!
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            np.testing.assert_allclose(p1, p2)

        def h(x):
            """This will behave like normal."""
            return moving_averages.EMAParamsTree(0.2)(x, update_stats=True)

        init_fn, apply_fn_h = base.transform(h, state=True)
        # Start again from fresh state for the update_stats=True case.
        _, params_state = init_fn(None, params)
        params, params_state = apply_fn_h(None, params_state, params)

        # Let's modify our params.
        changed_params = tree.map_structure(lambda t: 2. * t, params)
        ema_params, params_state = apply_fn_h(None, params_state,
                                              changed_params)
        ema_params2, params_state = apply_fn_h(None, params_state,
                                               changed_params)

        # ema_params should be different from ema_params2 with
        # update_stats=True!
        for p1, p2 in zip(tree.flatten(ema_params2), tree.flatten(ema_params)):
            self.assertEqual(p1.shape, p2.shape)
            with self.assertRaisesRegex(AssertionError,
                                        "Not equal to tolerance"):
                np.testing.assert_allclose(p1, p2, atol=1e-6)
コード例 #15
0
ファイル: module_test.py プロジェクト: shyamalschandra/haiku
 def test_params_nested(self):
   """Checks parameter names created by a module named "outer"."""
   def build():
     return MultipleForwardMethods(name="outer")()

   init_fn, _ = base.transform(build)
   params = init_fn(None)

   module_names = (
       "outer/~/scalar_module",
       "outer/scalar_module",
       "outer/~encode/scalar_module",
       "outer/~decode/scalar_module",
   )
   expected = {name: {"w": jnp.zeros([])} for name in module_names}
   self.assertEqual(params, expected)
コード例 #16
0
 def test_invalid_rng(self):
   """Expects both `init` and `apply` to reject a non-RNG argument."""
   transformed = base.transform(lambda: None, apply_rng=True)
   init_msg = "Init must be called with an RNG as the first argument"
   apply_msg = "Apply must be called with an RNG as the second argument"

   with self.assertRaisesRegex(ValueError, init_msg):
     transformed.init("nonsense")
   with self.assertRaisesRegex(ValueError, apply_msg):
     transformed.apply({}, "nonsense")
コード例 #17
0
ファイル: reshape_test.py プロジェクト: shyamalschandra/haiku
    def test_invalid_multiple_wildcard(self):
        """Expects Reshape with output_shape=[-1, -1] to raise ValueError."""
        def f():
            inputs = np.ones([1, 2, 3])
            return reshape.Reshape(output_shape=[-1, -1])(inputs)

        transformed = base.transform(f)
        init_fn, _ = transformed
        with self.assertRaises(ValueError):
            init_fn(None)
コード例 #18
0
ファイル: basic_test.py プロジェクト: shyamalschandra/haiku
  def test_sequential(self):
    """Checks Sequential([Linear(2), relu]) maps a [3, 2] input to (3, 2)."""
    def f():
      layers = [basic.Linear(2), jax.nn.relu]
      return basic.Sequential(layers)(jnp.zeros([3, 2]))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))
    output = apply_fn(params)
    self.assertEqual(output.shape, (3, 2))
コード例 #19
0
ファイル: base_test.py プロジェクト: ibab/haiku
    def test_params(self):
        """Checks a parameter created outside any module is stored under "~"."""
        def f():
            return base.get_parameter("w", [], init=jnp.zeros)

        init_fn, _ = base.transform(f)
        expected = {"~": {"w": jnp.zeros([])}}
        self.assertEqual(init_fn(None), expected)
コード例 #20
0
ファイル: reshape_test.py プロジェクト: shyamalschandra/haiku
    def test_invalid_type(self):
        """Expects Reshape with a string in output_shape to raise TypeError."""
        def f():
            inputs = np.ones([1, 2, 3])
            return reshape.Reshape(output_shape=[7, "string"])(inputs)

        transformed = base.transform(f)
        init_fn, _ = transformed
        with self.assertRaises(TypeError):
            init_fn(None)
コード例 #21
0
ファイル: basic_test.py プロジェクト: shyamalschandra/haiku
  def test_linear_rank3(self):
    """Checks Linear(2) parameter and output shapes for a (2, 5, 6) input."""
    def f():
      layer = basic.Linear(output_size=2)
      return layer(jnp.zeros((2, 5, 6)))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))

    linear_params = params.linear
    self.assertEqual(linear_params.w.shape, (6, 2))
    self.assertEqual(params.linear.b.shape, (2,))

    output = apply_fn(params)
    self.assertEqual(output.shape, (2, 5, 2))
コード例 #22
0
ファイル: basic_test.py プロジェクト: shyamalschandra/haiku
  def test_linear_without_bias_has_zero_in_null_space(self):
    """Checks a bias-free Linear has no "b" param and maps zeros to zeros."""
    def f():
      layer = basic.Linear(output_size=6, with_bias=False)
      return layer(jnp.zeros((5, 6)))

    init_fn, apply_fn = base.transform(f)
    params = init_fn(random.PRNGKey(428))

    self.assertEqual(params.linear.w.shape, (6, 6))
    self.assertFalse(hasattr(params.linear, "b"))

    output = apply_fn(params)
    np.testing.assert_array_almost_equal(output, jnp.zeros((5, 6)))
コード例 #23
0
ファイル: reshape_test.py プロジェクト: shyamalschandra/haiku
    def test_reshape(self, preserve_dims, expected_output_shape):
        """Checks Reshape((-1, D), preserve_dims) on a [B, H, W, C, D] input."""
        full_shape = [B, H, W, C, D]

        def f(inputs):
            module = reshape.Reshape(output_shape=(-1, D),
                                     preserve_dims=preserve_dims)
            return module(inputs)

        init_fn, apply_fn = base.transform(f)
        params = init_fn(None, jnp.ones(full_shape))
        outputs = apply_fn(params, np.ones(full_shape))
        self.assertEqual(outputs.shape, expected_output_shape)
コード例 #24
0
ファイル: filtering_test.py プロジェクト: ibab/haiku
    def test_transforms_with_filer(self):
        """Checks grad, value_and_grad and jacobian restricted to "w" params."""
        # Note to make sense of test:
        #
        # out = (w0 + b0) * w1 + b1
        #     = w0 * w1 + b0 * w1 + b1
        # doutdw0 = w1
        # doutdw1 = w0 + b0
        # with w0 = 1.0, b0 = 1.5, w1 = 3.0, b1 = 4.5

        def is_weight(module_name, name, _):
            # Select only the parameters named "w".
            return name == "w"

        init_fn, apply_fn = base.transform(get_net)
        inputs = jnp.ones((1, 1))
        params = init_fn(jax.random.PRNGKey(428), inputs)

        df_fn = jax_fn_with_filter(
            jax_fn=jax.grad, f=apply_fn, predicate=is_weight)
        df = df_fn(params, inputs)
        expected_grads = {("first_layer/w", (3.0, )),
                          ("second_layer/w", (2.5, ))}
        self.assertEqual(to_set(df), expected_grads)

        fn = jax_fn_with_filter(
            jax_fn=jax.value_and_grad, f=apply_fn, predicate=is_weight)
        v = fn(params, inputs)
        self.assertEqual(v[0], jnp.array([12.0]))
        self.assertEqual(to_set(df), to_set(v[1]))

        def get_stacked_net(x):
            y = get_net(x)
            return jnp.stack([y, 2.0 * y])

        _, apply_fn = base.transform(get_stacked_net)
        jf_fn = jax_fn_with_filter(
            jax_fn=jax.jacobian, f=apply_fn, predicate=is_weight)
        jf = jf_fn(params, inputs)

        expected_jacobian = {("first_layer/w", (3.0, 6.0)),
                             ("second_layer/w", (2.5, 5.0))}
        self.assertEqual(to_set(jf), expected_jacobian)
コード例 #25
0
  def test_naked_parameter_in_tilde_collection(self):
    """Checks two parameters created outside any module are stored under "~"."""
    def net():
      first = base.get_parameter("w1", [], init=jnp.zeros)
      second = base.get_parameter("w2", [], init=jnp.ones)
      self.assertIsNot(first, second)

    init_fn, _ = base.transform(net)
    params = init_fn(None)
    expected_collection = {"w1": jnp.zeros([]), "w2": jnp.ones([])}
    self.assertEqual(params, {"~": expected_collection})
コード例 #26
0
    def test_inline_use(self):
        """Checks init/apply on the object returned by `base.transform`."""
        def f():
            module = ScalarModule()
            return module()

        transformed = base.transform(f)

        params = transformed.init(jax.random.PRNGKey(42))
        result = transformed.apply(params)
        self.assertEqual(result, 0)
コード例 #27
0
    def test_grad_and_jit(self):
        """Checks `stateful.grad` of SquareModule under `jax.jit` gives 2 * x."""
        def f(x):
            return stateful.grad(SquareModule())(x)

        x = jnp.array(3.)
        transformed = base.transform(f, state=True)

        jitted_init = jax.jit(transformed.init)
        params, state = jitted_init(None, x)

        jitted_apply = jax.jit(transformed.apply)
        g, state = jitted_apply(params, state, x)
        np.testing.assert_allclose(g, 2 * x, rtol=1e-3)
コード例 #28
0
  def test_unable_to_mutate_name(self):
    """Expects a creator that changes the parameter name to raise ValueError."""
    def mutates_name(next_creator, name, shape, dtype, init):
      renamed = name + "_foo"
      next_creator(renamed, shape, dtype, init)

    def make_param():
      return base.get_parameter("w", [], init=jnp.ones)

    init_fn, _ = base.transform(make_param)

    expected_msg = "Modifying .*name.* not supported"
    with self.assertRaisesRegex(ValueError, expected_msg):
      with base.custom_creator(mutates_name):
        init_fn(None)
コード例 #29
0
ファイル: base_test.py プロジェクト: ibab/haiku
    def test_inline_use(self):
        """Checks init/apply on the object returned by `base.transform`."""
        def f():
            return base.get_parameter("w", [], init=jnp.zeros)

        transformed = base.transform(f)

        params = transformed.init(jax.random.PRNGKey(42))
        result = transformed.apply(params)
        self.assertEqual(result, 0)
コード例 #30
0
  def test_with_rng(self, seed):
    """Checks `base.with_rng(key)` overrides the key passed to `apply`.

    The expected sample is drawn from the second half of a split of `key`.
    A function run inside `with_rng(key)` must produce it even though
    `apply` is given an unrelated key; a function run without it must not.
    """
    key = jax.random.PRNGKey(seed)
    unrelated_key = jax.random.PRNGKey(seed * 2 + 1)
    _, next_key = jax.random.split(key)
    expected_output = jax.random.uniform(next_key, ())

    def sample():
      return jax.random.uniform(base.next_rng_key(), ())

    def without_decorator():
      return sample()

    plain = base.transform(without_decorator, apply_rng=True)
    without_decorator_out = plain.apply(None, unrelated_key).item()

    def with_decorator():
      with base.with_rng(key):
        return sample()

    overridden = base.transform(with_decorator, apply_rng=True)
    with_decorator_out = overridden.apply(None, unrelated_key).item()

    self.assertNotEqual(without_decorator_out, expected_output)
    self.assertEqual(with_decorator_out, expected_output)