示例#1
0
    def __call__(self, *xs):
        """Add a bias to the input, broadcasting it along ``self.axis``.

        Where the bias comes from depends on the link: if it has a ``b``
        attribute, that attribute is the bias and exactly one input is
        expected; otherwise the caller passes both the input and the bias.

        Args:
            xs (list of ~chainer.Variable): ``(x,)`` when the link holds a
                bias attribute ``b``, otherwise ``(x, y)`` where ``y`` is the
                bias added to ``x``.

        Returns:
            The output of ``bias.bias`` applied to the input, the selected
            bias operand, and ``self.axis``.
        """
        axis = self.axis
        has_own_bias = hasattr(self, 'b')

        # The number of positional inputs is only verified in debug mode.
        n_expected = 1 if has_own_bias else 2
        if chainer.is_debug():
            assert len(xs) == n_expected

        if has_own_bias:
            # Single input; the bias is this link's own ``b`` attribute.
            x, = xs
            bias_operand = self.b
        else:
            # Two inputs; the second one is the bias.
            x, bias_operand = xs
        return bias.bias(x, bias_operand, axis)
示例#2
0
    def __call__(self, x):
        """Layer-normalize ``x``, then apply the learned scale and shift.

        Parameters are created lazily on the first call. While
        ``self.gamma.data`` is still ``None``, they are initialized with a
        unit size equal to the element count of ``x`` divided by its
        leading dimension.

        Args:
            x (~chainer.Variable): Batch vectors, expected to have shape
                ``(batch_size, unit_size)``, e.g., the output of
                :func:`~chainer.functions.linear`.

        Returns:
            ~chainer.Variable: ``self._normalize(x)`` scaled by
            ``self.gamma`` and then biased by ``self.beta``.

        """
        if self.gamma.data is None:
            # Elements per sample = total element count / batch size.
            per_sample = x.size // x.shape[0]
            self._initialize_params(per_sample)

        normalized = self._normalize(x)
        scaled = scale.scale(normalized, self.gamma)
        return bias.bias(scaled, self.beta)
    def __call__(self, x):
        """Layer-normalize ``x``, then apply the learned scale and shift.

        When the link still reports uninitialized parameters, they are
        created first. This happens inside the device context for
        ``self._device_id``, with a unit size equal to the element count of
        ``x`` divided by its leading dimension.

        Args:
            x (~chainer.Variable): Batch vectors, expected to have shape
                ``(batch_size, unit_size)``, e.g., the output of
                :func:`~chainer.functions.linear`.

        Returns:
            ~chainer.Variable: ``self._normalize(x)`` scaled by
            ``self.gamma`` and then biased by ``self.beta``.

        """
        if self.has_uninitialized_params:
            with cuda.get_device_from_id(self._device_id):
                # Elements per sample = total element count / batch size.
                per_sample = x.size // x.shape[0]
                self._initialize_params(per_sample)

        normalized = self._normalize(x)
        scaled = scale.scale(normalized, self.gamma)
        return bias.bias(scaled, self.beta)