Example 1
def test_batch_norm_invalid_dimensions(
        device, x_shape, gamma_shape, beta_shape, running_mean_shape,
        running_var_shape, axis, float_dtype):
    x, gamma, beta, running_mean, running_var = (
        _create_batch_norm_ndarray_args(
            chainerx, device, x_shape, gamma_shape, beta_shape,
            running_mean_shape, running_var_shape, float_dtype))

    with pytest.raises(chainerx.DimensionError):
        chainerx.batch_norm(
            x, gamma, beta, running_mean=running_mean, running_var=running_var,
            eps=1e-2, decay=0.9, axis=axis)
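
The helper _create_batch_norm_ndarray_args is referenced throughout these examples but never shown. Below is a minimal sketch of what it might look like, assuming it only needs to produce deterministic arrays of the requested shapes and dtype for either backend; the body is an assumption, not the actual test helper.

import numpy

def _create_batch_norm_ndarray_args(
        xp, device, x_shape, gamma_shape, beta_shape,
        running_mean_shape, running_var_shape, float_dtype):
    # Hypothetical sketch: build deterministic arrays of the requested
    # shapes so that chainerx and numpy receive identical inputs.
    def create(shape):
        size = int(numpy.prod(shape))
        a = numpy.arange(size, dtype=float_dtype).reshape(shape)
        if xp is numpy:
            return a
        # chainerx.array accepts a device argument for placement.
        return xp.array(a, device=device)

    return tuple(create(s) for s in (
        x_shape, gamma_shape, beta_shape,
        running_mean_shape, running_var_shape))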
Example 2
    def forward_chainerx(self, inputs):
        x, gamma, beta = inputs

        if self.is_cuda and self.contiguous == 'C':
            # Testing pre-condition.
            assert x.is_contiguous
            assert gamma.is_contiguous
            assert beta.is_contiguous

        running_mean = chainerx.array(self.running_mean,
                                      copy=True).astype(x.dtype)
        running_var = chainerx.array(self.running_var,
                                     copy=True).astype(x.dtype)

        y = chainerx.batch_norm(x,
                                gamma,
                                beta,
                                running_mean=running_mean,
                                running_var=running_var,
                                **self.optional_args)

        self.running_mean_chx = running_mean
        self.running_var_chx = running_var

        return y,
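
The copy=True followed by astype(x.dtype) above matters because chainerx.batch_norm updates the running statistics in place (as the tests below verify), and the cast keeps them in the input dtype. A self-contained illustration of just that conversion step:

import numpy
import chainerx

# numpy-side statistics are copied into chainerx and cast to the
# input dtype before being handed to batch_norm.
running_mean_np = numpy.zeros(3, dtype=numpy.float64)
x = chainerx.ones((2, 3), dtype=chainerx.float16)

running_mean = chainerx.array(running_mean_np, copy=True).astype(x.dtype)
assert running_mean.dtype == chainerx.float16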
Example 3
    def __call__(self, x):
        return chx.batch_norm(x,
                              self.gamma,
                              self.beta,
                              running_mean=self.avg_mean,
                              running_var=self.avg_var,
                              axis=(0, 2, 3))
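
For context, here is a self-contained call in the same style, assuming an NCHW input with per-channel parameters of shape (C,); the shapes and values are illustrative, not taken from the snippet above.

import chainerx as chx

x = chx.ones((2, 3, 4, 4), dtype=chx.float32)  # NCHW input
gamma = chx.ones((3,), dtype=chx.float32)      # per-channel scale
beta = chx.zeros((3,), dtype=chx.float32)      # per-channel shift
avg_mean = chx.zeros((3,), dtype=chx.float32)
avg_var = chx.ones((3,), dtype=chx.float32)

# axis=(0, 2, 3) normalizes over the batch and spatial dimensions,
# leaving one statistic per channel.
y = chx.batch_norm(x, gamma, beta, running_mean=avg_mean,
                   running_var=avg_var, axis=(0, 2, 3))
assert y.shape == x.shape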
Example 4
def test_batch_norm(device, x_shape, reduced_shape, eps, decay, axis,
                    float_dtype):
    def create_args(xp):
        return _create_batch_norm_ndarray_args(xp, device, x_shape,
                                               reduced_shape, reduced_shape,
                                               reduced_shape, reduced_shape,
                                               float_dtype)

    x_chx, gamma_chx, beta_chx, running_mean_chx, running_var_chx = (
        create_args(chainerx))
    x_np, gamma_np, beta_np, running_mean_np, running_var_np = (
        create_args(numpy))

    # Save copies of running values before updating to later check that they
    # are updated.
    initial_running_mean = running_mean_chx.copy()
    initial_running_var = running_var_chx.copy()

    optional_args = {}
    if eps is not None:
        optional_args['eps'] = eps
    if decay is not None:
        optional_args['decay'] = decay
    if axis is not None:
        optional_args['axis'] = axis

    y_chx = chainerx.batch_norm(x_chx,
                                gamma_chx,
                                beta_chx,
                                running_mean=running_mean_chx,
                                running_var=running_var_chx,
                                **optional_args)
    y_np = chainer.functions.batch_normalization(x_np,
                                                 gamma_np,
                                                 beta_np,
                                                 running_mean=running_mean_np,
                                                 running_var=running_var_np,
                                                 **optional_args).data

    # Check that the running values are updated.
    assert not numpy.allclose(chainerx.to_numpy(initial_running_mean),
                              chainerx.to_numpy(running_mean_chx))
    assert not numpy.allclose(chainerx.to_numpy(initial_running_var),
                              chainerx.to_numpy(running_var_chx))

    chainerx.testing.assert_allclose_ex(y_chx, y_np, rtol=1e-6, atol=1e-5)
    chainerx.testing.assert_allclose_ex(running_mean_chx,
                                        running_mean_np,
                                        rtol=1e-6,
                                        atol=1e-6)
    chainerx.testing.assert_allclose_ex(running_var_chx,
                                        running_var_np,
                                        rtol=1e-6,
                                        atol=1e-6)
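
The running values change because batch_norm folds the batch statistics into them with an exponential moving average controlled by decay. Below is a numpy sketch of that update; the m / (m - 1) variance adjustment mirrors chainer.functions.batch_normalization and is an assumption for the chainerx side.

import numpy

def ema_update(running, batch_stat, decay):
    # running <- running * decay + batch_stat * (1 - decay)
    return running * decay + batch_stat * (1.0 - decay)

x = numpy.random.rand(8, 3).astype(numpy.float32)
mean = x.mean(axis=0)
m = x.shape[0]
var = x.var(axis=0) * (m / (m - 1.0))  # assumed unbiased adjustment

running_mean = ema_update(numpy.zeros(3, numpy.float32), mean, decay=0.9)
running_var = ema_update(numpy.ones(3, numpy.float32), var, decay=0.9)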
Example 5
    def forward_chainerx(self, inputs):
        x, gamma, beta = inputs

        running_mean = chainerx.array(self.running_mean, copy=True)
        running_var = chainerx.array(self.running_var, copy=True)

        y = chainerx.batch_norm(
            x, gamma, beta, running_mean=running_mean, running_var=running_var,
            **self.optional_args)

        # Record running values for later checks.
        self.running_mean_chx = running_mean
        self.running_var_chx = running_var

        return y,
Example 6
def test_batch_norm(
        device, x_shape, reduced_shape, eps, decay, axis, float_dtype):
    def create_args(xp):
        return _create_batch_norm_ndarray_args(
            xp, device, x_shape, reduced_shape, reduced_shape, reduced_shape,
            reduced_shape, float_dtype)

    x_chx, gamma_chx, beta_chx, running_mean_chx, running_var_chx = (
        create_args(chainerx))
    x_np, gamma_np, beta_np, running_mean_np, running_var_np = (
        create_args(numpy))

    # Save copies of running values before updating to later check that they
    # are updated.
    initial_running_mean = running_mean_chx.copy()
    initial_running_var = running_var_chx.copy()

    optional_args = {}
    if eps is not None:
        optional_args['eps'] = eps
    if decay is not None:
        optional_args['decay'] = decay
    if axis is not None:
        optional_args['axis'] = axis

    y_chx = chainerx.batch_norm(
        x_chx, gamma_chx, beta_chx, running_mean=running_mean_chx,
        running_var=running_var_chx, **optional_args)
    y_np = chainer.functions.batch_normalization(
        x_np, gamma_np, beta_np, running_mean=running_mean_np,
        running_var=running_var_np, **optional_args).data

    # Check that the running values are updated.
    assert not numpy.allclose(chainerx.to_numpy(
        initial_running_mean), chainerx.to_numpy(running_mean_chx))
    assert not numpy.allclose(chainerx.to_numpy(
        initial_running_var), chainerx.to_numpy(running_var_chx))

    chainerx.testing.assert_allclose_ex(
        y_chx, y_np, rtol=1e-6, atol=1e-5,
        float16_rtol=1e-2, float16_atol=1e-2)
    chainerx.testing.assert_allclose_ex(
        running_mean_chx, running_mean_np,
        rtol=1e-6, atol=1e-6, float16_rtol=1e-2, float16_atol=1e-2)
    chainerx.testing.assert_allclose_ex(
        running_var_chx, running_var_np,
        rtol=1e-6, atol=1e-6, float16_rtol=1e-2, float16_atol=1e-2)
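
Example 6 differs from Example 4 mainly in the extra float16_rtol/float16_atol arguments. The looser half-precision tolerances are plausible given float16's machine epsilon:

import numpy

# float16 carries roughly 3 decimal digits, so rounding accumulated
# over a normalize-scale-shift chain can approach the 1e-2 tolerance.
print(numpy.finfo(numpy.float16).eps)  # ~9.77e-04
print(numpy.finfo(numpy.float32).eps)  # ~1.19e-07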
Example 7
    def forward_chainerx(self, inputs):
        # TODO(niboshi): Support conditions implemented as fallback

        # Running statistics are required.
        if self.running_mean is None or self.running_var is None:
            return chainer.Fallback

        # Fall back if the running statistics are non-contiguous CUDA arrays
        # since they are not supported by cuDNN.
        # Both running statistics are assumed to belong to the same
        # backend.
        if self.running_mean.device.backend.name == 'cuda' and not (
                self.running_mean.is_contiguous
                and self.running_var.is_contiguous):
            return chainer.Fallback

        x, gamma, beta = inputs
        axis_chx = _chainerx_compute_axis(x.ndim, gamma.ndim, self.axis)
        if not _chainerx_is_supported(x.device, axis_chx):
            return chainer.Fallback

        y = chainerx.batch_norm(
            x, gamma, beta, self.running_mean, self.running_var,
            self.eps, self.decay, axis_chx)
        return y,
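
_chainerx_compute_axis is not defined in this excerpt. Below is a hypothetical sketch consistent with the axis=(0, 2, 3) pattern from Example 3, where the default reduces over the batch dimension and every trailing dimension not covered by gamma; the body is an assumption.

def _chainerx_compute_axis(x_ndim, gamma_ndim, axis):
    # Hypothetical: with no explicit axis, reduce over dimension 0 and
    # every dimension after those covered by gamma.
    if axis is not None:
        return tuple(axis) if isinstance(axis, (tuple, list)) else (axis,)
    return tuple([0] + list(range(gamma_ndim + 1, x_ndim)))

# A 4-d NCHW input with 1-d gamma gives the familiar (0, 2, 3).
assert _chainerx_compute_axis(4, 1, None) == (0, 2, 3)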