Example #1
    def forward(self, inputs, device):
        x, gamma, beta, mean, var = inputs
        # The fixed variant is deprecated, so the test asserts that
        # calling it emits a DeprecationWarning.
        with testing.assert_warns(DeprecationWarning):
            y = batch_renormalization.fixed_batch_renormalization(
                x, gamma, beta, mean, var, eps=self.eps)
        return y,
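This test pins down the call signature: x is normalized with the precomputed mean and var, then scaled by gamma and shifted by beta. Below is a minimal standalone sketch of the same call; the array shapes and the eps value are illustrative assumptions, not taken from the test above.

    import numpy as np

    from chainer.functions.normalization import batch_renormalization

    # Illustrative inputs: a batch of 5 samples with 3 channels.
    x = np.random.randn(5, 3).astype(np.float32)
    gamma = np.ones(3, dtype=np.float32)   # per-channel scale
    beta = np.zeros(3, dtype=np.float32)   # per-channel shift
    mean = x.mean(axis=0)                  # precomputed statistics
    var = x.var(axis=0)

    # Deprecated in recent Chainer versions; calling it emits a
    # DeprecationWarning but still returns the normalized output.
    y = batch_renormalization.fixed_batch_renormalization(
        x, gamma, beta, mean, var, eps=2e-5)
    print(y.shape)  # (5, 3)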
Example #2
    def check_forward(self, args):
        with chainer.using_config('train', self.train):
            y = batch_renormalization.fixed_batch_renormalization(
                *[chainer.Variable(i) for i in args], eps=self.eps)
        self.assertEqual(y.data.dtype, self.dtype)

        y_expect = _batch_renormalization(self.expander, self.gamma, self.beta,
                                          self.x, self.mean, self.var, 1, 0)

        testing.assert_allclose(y_expect, y.data, **self.check_forward_options)
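Here the test wraps each input array in a chainer.Variable, evaluates the function under an explicit train/eval configuration, checks the output dtype, and compares against the pure-NumPy reference _batch_renormalization. The trailing arguments 1 and 0 are presumably the renormalization correction terms r and d; with r = 1 and d = 0, batch renormalization reduces to ordinary fixed batch normalization, which is what the fixed function computes.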
Example #4
    def __call__(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with cuda.get_device(self._device_id):
                gamma = variable.Variable(
                    self.xp.ones(self.avg_mean.shape, dtype=x.dtype))

        if self.beta is not None:
            beta = self.beta
        else:
            with cuda.get_device(self._device_id):
                beta = variable.Variable(
                    self.xp.zeros(self.avg_mean.shape, dtype=x.dtype))

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_renormalization.BatchRenormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay, self.rmax,
                self.dmax, self.freeze_running_statistics)
            if self.freeze_running_statistics:
                func.r = self.r
                func.d = self.d
            ret = func(x, gamma, beta)
            if self.freeze_running_statistics and self.r is None:
                self.r = func.r
                self.d = func.d

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_renormalization.fixed_batch_renormalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
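This __call__ covers both modes of the link. In training mode it runs BatchRenormalizationFunction, optionally reusing frozen r and d corrections, and writes the function's updated running statistics back into avg_mean and avg_var. In evaluation mode it falls back to fixed_batch_renormalization with the stored running averages, so no statistics are modified.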
Example #5
    def __call__(self, x, finetune=False):
        if self.gamma is not None:
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))

        if self.beta is not None:
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_renormalization.BatchRenormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay,
                self.rmax, self.dmax, self.freeze_running_statistics)
            if self.freeze_running_statistics:
                func.r = self.r
                func.d = self.d
            ret = func(x, gamma, beta)
            if self.freeze_running_statistics and self.r is None:
                self.r = func.r
                self.d = func.d

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = batch_renormalization.fixed_batch_renormalization(
                x, gamma, beta, mean, var, self.eps)
        return ret
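Example #5 is the same link code as Example #4; the only difference is the device handling, which uses cuda.get_device_from_id, the replacement for the deprecated cuda.get_device.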
Example #6
    def _forward(self, *args):
        with testing.assert_warns(DeprecationWarning):
            return batch_renormalization.fixed_batch_renormalization(
                *args, eps=self.eps)
Example #9
    def _forward(self, *args):
        return batch_renormalization.fixed_batch_renormalization(
            *args, eps=self.eps)
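Unlike the wrapper in Example #6, this variant calls the function without asserting a DeprecationWarning, presumably because it comes from a version in which the function had not yet been deprecated.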