コード例 #1
0
ファイル: stax_test.py プロジェクト: ROCmSoftwarePlatform/jax
  def testBatchNormNoScaleOrCenter(self):
    """BatchNorm with center=False and scale=False only standardizes.

    With centering and scaling disabled, the layer output reduced over the
    normalized axes must have mean ~0 and standard deviation ~1, and the
    layer must preserve the input shape.
    """
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(self.rng(), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    self.assertEqual(out_shape, input_shape)
    self.assertEqual(out.shape, input_shape)

    # Reduce over the same axes the layer normalized, so the check stays in
    # sync with the layer configuration if `axes` is ever changed.
    means = np.mean(out, axis=axes)
    std_devs = np.std(out, axis=axes)
    # unittest assertions instead of bare `assert`: bare asserts are stripped
    # when Python runs with -O, which would make this test pass vacuously.
    self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-4))
    self.assertTrue(np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4))
コード例 #2
0
ファイル: stax_test.py プロジェクト: xueeinstein/jax
    def testBatchNormShapeNHWC(self):
        """Check shapes for BatchNorm normalizing over axes (0, 1, 2).

        For a (4, 5, 6, 7) input only the last axis is left un-normalized,
        so both beta and gamma must have shape (7,), and the layer must
        return an output with the same shape as its input.
        """
        prng_key = random.PRNGKey(0)
        layer_init, layer_apply = stax.BatchNorm(axis=(0, 1, 2))
        shape = (4, 5, 6, 7)
        batch = random_inputs(self.rng(), shape)

        result_shape, layer_params = layer_init(prng_key, shape)
        result = layer_apply(layer_params, batch)

        self.assertEqual(result_shape, shape)
        beta, gamma = layer_params
        for param in (beta, gamma):
            self.assertEqual(param.shape, (7, ))
        self.assertEqual(result_shape, result.shape)
コード例 #3
0
ファイル: stax_test.py プロジェクト: xueeinstein/jax
    def testBatchNormShapeNCHW(self):
        """Check shapes for BatchNorm normalizing over axes (0, 2, 3).

        For a (4, 5, 6, 7) input only axis 1 is left un-normalized, so both
        beta and gamma must have shape (5,), and the layer must return an
        output with the same shape as its input.
        """
        prng_key = random.PRNGKey(0)
        # Regression test for https://github.com/google/jax/issues/461
        layer_init, layer_apply = stax.BatchNorm(axis=(0, 2, 3))
        shape = (4, 5, 6, 7)
        batch = random_inputs(self.rng(), shape)

        result_shape, layer_params = layer_init(prng_key, shape)
        result = layer_apply(layer_params, batch)

        self.assertEqual(result_shape, shape)
        beta, gamma = layer_params
        for param in (beta, gamma):
            self.assertEqual(param.shape, (5, ))
        self.assertEqual(result_shape, result.shape)
コード例 #4
0
 def build(self, input_shape: tuple):
     """Create the stax BatchNorm layer and initialize its parameters.

     Stores the stax apply function on ``self.apply_fn``, the initial
     parameters on ``self._params`` and the shape returned by the stax init
     function on ``self.shape`` (and, via ``self.shape``, on
     ``self.input_shape``), then marks the layer as built.

     Args:
       input_shape: shape passed to the stax init function.
     """
     layer_init, self.apply_fn = stax.BatchNorm(axis=self.axis)
     produced_shape, initial_params = layer_init(rng=self.key,
                                                 input_shape=input_shape)
     self.shape = produced_shape
     self._params = initial_params
     # NOTE(review): this records the shape returned by init, not the
     # `input_shape` argument; BatchNorm presumably preserves shape so the two
     # coincide — confirm this is intended.
     self.input_shape = self.shape
     self.built = True