Example #1
    def __call__(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = self.xp.ones(self.avg_mean.shape, dtype=x.dtype)

        if self.beta is not None:
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = self.xp.zeros(self.avg_mean.shape, dtype=x.dtype)

        if configuration.config.train:
            if finetune:
                # Fine-tuning mode: accumulate statistics with a cumulative
                # average, so the effective decay approaches 1 as N grows.
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            # Normalize with batch statistics corrected by the clipped
            # factors r and d, then update the running averages in place.
            func = batch_renormalization.BatchRenormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay, self.rmax,
                self.dmax)
            ret = func(x, gamma, beta)

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            var = self.avg_var
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
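
For reference, here is a minimal NumPy sketch of what BatchRenormalizationFunction computes in training mode, following the Batch Renormalization paper. The function name batch_renorm_forward, the simple exponential-moving-average update, and the channel-last layout are illustrative assumptions, not the library's exact implementation.

import numpy as np

def batch_renorm_forward(x, gamma, beta, run_mean, run_var,
                         rmax, dmax, eps=2e-5, decay=0.9):
    # Per-channel batch statistics (channels on the last axis).
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    sigma = np.sqrt(var + eps)
    run_sigma = np.sqrt(run_var + eps)

    # Correction factors, clipped to [1/rmax, rmax] and [-dmax, dmax];
    # they are treated as constants in the backward pass.
    r = np.clip(sigma / run_sigma, 1.0 / rmax, rmax)
    d = np.clip((mu - run_mean) / run_sigma, -dmax, dmax)

    # Renormalized output.
    y = gamma * ((x - mu) / sigma * r + d) + beta

    # Update the running statistics (illustrative moving-average form).
    new_mean = decay * run_mean + (1 - decay) * mu
    new_var = decay * run_var + (1 - decay) * var
    return y, new_mean, new_var

At inference the running statistics are used directly, which is why the else branch above can fall back to plain fixed_batch_normalization.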
Example #2
    def check_backward(self, args, y_grad):
        with chainer.using_config('train', self.train):
            gradient_check.check_backward(
                batch_renormalization.BatchRenormalizationFunction(
                    mean=None,
                    var=None,
                    decay=self.decay,
                    eps=self.eps,
                    rmax=self.rmax,
                    dmax=self.dmax,
                    freeze_running_statistics=True), args, y_grad,
                **self.check_backward_options)
Example #3
    def check_backward(self, args, y_grad):
        with chainer.using_config('train', self.train):
            # Freezing the update of running statistics is needed in order to
            # make gradient check work, since the parameters r and d in batch
            # renormalization are calculated from the input, but should be
            # treated as constants during gradient computation, as stated in
            # the paper.
            gradient_check.check_backward(
                batch_renormalization.BatchRenormalizationFunction(
                    mean=self.running_mean, var=self.running_var,
                    decay=self.decay, eps=self.eps, rmax=self.rmax,
                    dmax=self.dmax,
                    freeze_running_statistics=True), args, y_grad,
                **self.check_backward_options)
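
For context, the fixtures such a test builds (for example in setUp) could look like the following. The shapes, value ranges, and tolerance values are illustrative assumptions, not taken from the original test suite.

import numpy as np

# Hypothetical fixtures for the gradient check above.
channels = 3
shape = (5, channels)
x = np.random.uniform(-1, 1, shape).astype(np.float32)
gamma = np.random.uniform(0.5, 1, (channels,)).astype(np.float32)
beta = np.random.uniform(-1, 1, (channels,)).astype(np.float32)
args = (x, gamma, beta)
y_grad = np.random.uniform(-1, 1, shape).astype(np.float32)

running_mean = np.random.uniform(-1, 1, (channels,)).astype(np.float32)
running_var = np.random.uniform(0.5, 1, (channels,)).astype(np.float32)

check_backward_options = {'atol': 1e-4, 'rtol': 1e-3}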
Example #4
    def __call__(self, x, finetune=False):
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_renormalization.BatchRenormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay,
                self.rmax, self.dmax, self.freeze_running_statistics)
            if self.freeze_running_statistics:
                # Pass previously cached r and d (None on the first call) so
                # that repeated calls use the same correction constants.
                func.r = self.r
                func.d = self.d
            ret = func(x, gamma, beta)
            if self.freeze_running_statistics and self.r is None:
                # Cache the r and d computed by the first call for reuse.
                self.r = func.r
                self.d = func.d

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_renormalization.fixed_batch_renormalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
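
Finally, a hypothetical usage sketch for the link that exposes this __call__, assumed here to be chainer.links.BatchRenormalization; the constructor arguments are illustrative.

import numpy as np
import chainer
import chainer.links as L

# Assumed link and hyperparameters, shown for illustration only.
brn = L.BatchRenormalization(3, rmax=3.0, dmax=5.0)
x = np.random.uniform(-1, 1, (8, 3)).astype(np.float32)

# Training mode: the train branch above runs, normalizing with batch
# statistics corrected by r and d and updating avg_mean / avg_var.
with chainer.using_config('train', True):
    y_train = brn(x)

# Test mode: the else branch runs, applying the fixed running statistics.
with chainer.using_config('train', False):
    y_test = brn(x)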