Example no. 1
0
    def forward(self, x):
        """Normalize ``x`` over channel groups and return the result.

        Args:
            x (~chainer.Variable): Input mini-batch. Axis 0 is the batch
                dimension and axis 1 is the channel dimension; at least one
                further spatial dimension (e.g. height, width) is required.

        Returns:
            ~chainer.Variable: Output of the group normalization.

        """
        # Parameters are created lazily: the channel count is unknown
        # until the first input arrives.
        if self.gamma.array is None:
            if x.ndim < 2:
                raise ValueError('Input dimension must be at least 2, '
                                 'including batch size dimension '
                                 '(first dimension).')
            self._initialize_params(x.shape[1])

        return group_normalization.group_normalization(
            x, self.groups, self.gamma, self.beta, self.eps)
Example no. 2
0
 def forward_preprocess(self, cb_args):
     """Swap the link's weight for a group-normalized view of it.

     Invoked by the hook machinery just before ``cb_args.link`` runs
     its forward computation.
     """
     link = cb_args.link
     first_input = cb_args.args[0]
     # Deferred initialization: the weight shape is only known once the
     # first input is available.  `and` short-circuits, so the weight is
     # not inspected after initialization has happened.
     needs_init = (not self._initialized
                   and getattr(link, self.weight_name).array is None)
     if needs_init:
         if first_input is None:
             raise ValueError('Input variable does not exist!')
         link._initialize_params(first_input.shape[1])
     weight = getattr(link, self.weight_name)
     with chainer.using_device(link.device):
         stat_shape = (weight.shape[1],)
         gamma = link.xp.ones(stat_shape, dtype=weight.dtype)
         beta = link.xp.zeros(stat_shape, dtype=weight.dtype)
     # Keep a reference to the raw weight so that link.W (or its
     # equivalent) consistently appears to users as a chainer.Parameter.
     self.original_weight = weight
     # note: `normalized_weight` is ~chainer.Variable
     normalized_weight = group_normalization.group_normalization(
         weight, groups=1, gamma=gamma, beta=beta, eps=self.eps)
     setattr(link, self.weight_name, normalized_weight)
Example no. 3
0
    def __call__(self, x):
        """Apply group normalization to given input.

        Args:
            x (~chainer.Variable): Batch tensors.
                First dimension of this value must be the size of minibatch and
                second dimension must be the number of channels.
                Moreover, this value must have one or more following
                dimensions, such as height and width.

        Returns:
            ~chainer.Variable: Output of the group normalization.

        Raises:
            ValueError: If the parameters are still uninitialized and ``x``
                does not have more than 2 dimensions.

        """
        # Parameters are initialized lazily from the channel count of the
        # first input.
        if self.gamma.data is None:
            if x.ndim <= 2:
                # Fixed typo in the user-facing message: 'grater' -> 'greater'.
                raise ValueError('Input dimension must be greater than 2, '
                                 'including batch size dimension '
                                 '(first dimension).')
            channels = x.shape[1]
            self._initialize_params(channels)

        return group_normalization.group_normalization(
            x, self.groups, self.gamma, self.beta, self.eps)