Example #1
    def forward(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype)

        if self.beta is not None:
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            ret = batch_renormalization.batch_renormalization(
                x, gamma, beta, self.rmax, self.dmax,
                self.eps, self.avg_mean, self.avg_var, decay,
                update_statistics=True)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
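The forward above switches on chainer.config.train: batch statistics (with the running averages updated) during training, the stored avg_mean/avg_var otherwise. A minimal usage sketch, assuming chainer.links.BatchRenormalization exposes this behavior and accepts rmax/dmax in its constructor; sizes and hyperparameters are illustrative:

import numpy
import chainer
import chainer.links as L

# Illustrative channel count and clipping bounds.
brn = L.BatchRenormalization(3, rmax=3.0, dmax=5.0)
x = numpy.random.randn(10, 3).astype(numpy.float32)

with chainer.using_config('train', True):
    y_train = brn(x)   # batch statistics; avg_mean/avg_var are updated

with chainer.using_config('train', False):
    y_test = brn(x)    # stored running statistics via fixed_batch_normalization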
Example #2
    def forward(self, inputs, device):
        x, gamma, beta, gy = inputs
        running_mean = device.send(self.running_mean.copy())
        running_var = device.send(self.running_var.copy())
        y = batch_renormalization.batch_renormalization(
            x,
            gamma,
            beta,
            self.rmax,
            self.dmax,
            eps=self.eps,
            running_mean=running_mean,
            running_var=running_var,
            update_statistics=self.update_statistics)

        # backward gradients
        gx, ggamma, gbeta = self._compute_backward(x, gamma, beta, y, gy)

        return (
            y,
            chainer.Variable(running_mean),
            chainer.Variable(running_var),
            chainer.Variable(gx),
            chainer.Variable(ggamma),
            chainer.Variable(gbeta),
        )
Example #3
    def forward(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with chainer.using_device(self.device):
                gamma = self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype)

        if self.beta is not None:
            beta = self.beta
        else:
            with chainer.using_device(self.device):
                beta = self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            ret = batch_renormalization.batch_renormalization(
                x, gamma, beta, self.rmax, self.dmax,
                self.eps, self.avg_mean, self.avg_var, decay,
                update_statistics=True)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
Example #4

        def f_tested(x, gamma, beta, running_mean, running_var):
            return batch_renormalization.batch_renormalization(
                x,
                gamma,
                beta,
                self.rmax,
                self.dmax,
                eps=self.eps,
                running_mean=running_mean,
                running_var=running_var)
Example #5

    def forward(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with chainer.using_device(self.device):
                gamma = self.xp.ones(self.avg_mean.shape, dtype=x.dtype)

        if self.beta is not None:
            beta = self.beta
        else:
            with chainer.using_device(self.device):
                beta = self.xp.zeros(self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            avg_mean = self.avg_mean
            avg_var = self.avg_var
            update_statistics = True

            if chainer.config.in_recomputing:
                # Do not update statistics when extra forward computation is
                # called.
                if finetune:
                    self.N -= 1  # Revert the count
                avg_mean = self._prev_avg_mean
                avg_var = self._prev_avg_var
                update_statistics = False
            elif chainer.config._will_recompute:
                self._prev_avg_mean = avg_mean.copy()
                self._prev_avg_var = avg_var.copy()

            ret = batch_renormalization.batch_renormalization(
                x,
                gamma,
                beta,
                self.rmax,
                self.dmax,
                self.eps,
                avg_mean,
                avg_var,
                decay,
                update_statistics=update_statistics)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
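This variant additionally guards against double-updating the running statistics when the same forward pass is recomputed (as chainer.functions.forget does): the first pass snapshots avg_mean/avg_var when chainer.config._will_recompute is set, and the recomputation pass restores the snapshot and disables update_statistics. A hypothetical sketch of that protocol, assuming the installed chainer.links.BatchRenormalization implements the guard above; in practice both config keys are managed by the recomputation machinery rather than set by hand:

import numpy
import chainer
import chainer.links as L

brn = L.BatchRenormalization(3)    # illustrative channel count
x = numpy.random.randn(8, 3).astype(numpy.float32)

with chainer.using_config('train', True):
    # First pass: announce the upcoming recomputation so the link snapshots
    # its running statistics into _prev_avg_mean/_prev_avg_var.
    with chainer.using_config('_will_recompute', True):
        brn(x)
    saved_mean = brn.avg_mean.copy()

    # Recomputation pass: the snapshot is used and update_statistics is
    # False, so the running averages stay where the first pass left them.
    with chainer.using_config('in_recomputing', True):
        brn(x)

assert numpy.allclose(brn.avg_mean, saved_mean)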
Example #6
    def forward(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with chainer.using_device(self.device):
                gamma = self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype)

        if self.beta is not None:
            beta = self.beta
        else:
            with chainer.using_device(self.device):
                beta = self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            avg_mean = self.avg_mean
            avg_var = self.avg_var

            if chainer.config.in_recomputing:
                # Do not update statistics when extra forward computation is
                # called.
                if finetune:
                    self.N -= 1  # Revert the count
                avg_mean = self.xp.zeros_like(self.avg_mean)
                avg_var = self.xp.zeros_like(self.avg_var)

            ret = batch_renormalization.batch_renormalization(
                x, gamma, beta, self.rmax, self.dmax,
                self.eps, avg_mean, avg_var, decay,
                update_statistics=True)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
Example #7
    def check_forward(self, args):
        with chainer.using_config('train', self.train):
            y = batch_renormalization.batch_renormalization(
                *[chainer.Variable(i) for i in args],
                rmax=self.rmax, dmax=self.dmax, running_mean=self.running_mean,
                running_var=self.running_var, decay=self.decay, eps=self.eps)
        self.assertEqual(y.data.dtype, self.dtype)

        sigma_batch = numpy.sqrt(self.var)
        running_sigma = numpy.sqrt(self.running_var + self.eps)
        r = numpy.clip(sigma_batch / running_sigma, 1.0 / self.rmax, self.rmax)
        d = numpy.clip((self.mean - self.running_mean) / running_sigma,
                       -self.dmax, self.dmax)
        y_expect = _batch_renormalization(
            self.expander, self.gamma, self.beta, self.x, self.mean, self.var,
            r[self.expander], d[self.expander])

        testing.assert_allclose(
            y_expect, y.data, **self.check_forward_options)
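The expected output is produced by a module-level helper _batch_renormalization that is not shown on this page. A plausible sketch of what it computes (standardize with the batch statistics, apply the clipped correction factors r and d, then the affine gamma/beta transform); whether eps is already folded into var here is an assumption:

import numpy

def _batch_renormalization(expander, gamma, beta, x, mean, var, r, d):
    # Standardize with batch statistics, apply the clipped corrections r, d,
    # then the learned affine transform. Sketch only: the exact handling of
    # eps inside `var` is an assumption.
    mean = mean[expander]
    std = numpy.sqrt(var)[expander]
    x_hat = (x - mean) / std * r + d
    return gamma[expander] * x_hat + beta[expander]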
Example #8
    def check_forward(self, args):
        with chainer.using_config('train', self.train):
            y = batch_renormalization.batch_renormalization(
                *[chainer.Variable(i) for i in args],
                rmax=self.rmax, dmax=self.dmax, running_mean=self.running_mean,
                running_var=self.running_var, decay=self.decay, eps=self.eps)
        self.assertEqual(y.data.dtype, self.dtype)

        sigma_batch = numpy.sqrt(self.var)
        running_sigma = numpy.sqrt(self.running_var + self.eps)
        r = numpy.clip(sigma_batch / running_sigma, 1.0 / self.rmax, self.rmax)
        d = numpy.clip((self.mean - self.running_mean) / running_sigma,
                       -self.dmax, self.dmax)
        y_expect = _batch_renormalization(
            self.expander, self.gamma, self.beta, self.x, self.mean, self.var,
            r[self.expander], d[self.expander])

        testing.assert_allclose(
            y_expect, y.data, **self.check_forward_options)
Example #9
    def forward(self, inputs, device):
        x, gamma, beta, gy = inputs
        running_mean = device.send(self.running_mean.copy())
        running_var = device.send(self.running_var.copy())
        y = batch_renormalization.batch_renormalization(
            x, gamma, beta,
            self.rmax, self.dmax,
            eps=self.eps,
            running_mean=running_mean,
            running_var=running_var,
            update_statistics=self.update_statistics)

        # backward gradients
        gx, ggamma, gbeta = self._compute_backward(x, gamma, beta, y, gy)

        return (
            y,
            chainer.Variable(running_mean),
            chainer.Variable(running_var),
            chainer.Variable(gx),
            chainer.Variable(ggamma),
            chainer.Variable(gbeta),
        )
Example #10
        def f_tested(x, gamma, beta, running_mean, running_var):
            return batch_renormalization.batch_renormalization(
                x, gamma, beta, self.rmax, self.dmax,
                eps=self.eps, running_mean=running_mean,
                running_var=running_var)
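The closure above appears inside a test method, with rmax, dmax and eps taken from the test case. A hypothetical, self-contained sketch of calling the same functional API directly and backpropagating through it; the shapes and hyperparameter values are illustrative assumptions:

import numpy
import chainer
from chainer.functions.normalization import batch_renormalization

x = chainer.Variable(numpy.random.randn(5, 3).astype(numpy.float32))
gamma = chainer.Variable(numpy.ones(3, dtype=numpy.float32))
beta = chainer.Variable(numpy.zeros(3, dtype=numpy.float32))
running_mean = numpy.zeros(3, dtype=numpy.float32)
running_var = numpy.ones(3, dtype=numpy.float32)

with chainer.using_config('train', True):
    y = batch_renormalization.batch_renormalization(
        x, gamma, beta, 3.0, 5.0, eps=2e-5,
        running_mean=running_mean, running_var=running_var)

# Backpropagate a simple scalar loss through the renormalization.
loss = chainer.functions.sum(y * y)
loss.backward()
print(x.grad.shape, gamma.grad.shape)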