Esempio n. 1
0
    def forward(self, c, h, x):
        """Returns new cell state and updated output of LSTM.

        On the first call, while ``self.upward.W`` has no array yet, the
        parameters of ``self.upward`` and of this link are initialized using
        the per-sample size of ``x`` (the product of all axes but the first).

        Args:
            c (~chainer.Variable): Cell states of LSTM units. If ``None``,
                a zero-filled state of shape
                ``(x.shape[0], self.state_size)`` with the dtype of ``x``
                is used instead.
            h (~chainer.Variable): Output at the previous time step. If
                ``None``, the lateral connection is not applied.
            x (~chainer.Variable): A new batch from the input sequence.

        Returns:
            tuple of ~chainer.Variable: Returns ``(c_new, h_new)``, where
            ``c_new`` represents new cell state, and ``h_new`` is updated
            output of LSTM units.

        """
        if self.upward.W.array is None:
            # Derive the per-sample input size from the non-batch axes.
            # This equals ``x.size // x.shape[0]`` for a non-empty batch but
            # does not raise ZeroDivisionError when the batch is empty.
            in_size = 1
            for dim in x.shape[1:]:
                in_size *= dim
            with chainer.using_device(self.device):
                self.upward._initialize_params(in_size)
                self._initialize_params()

        lstm_in = self.upward(x)
        if h is not None:
            lstm_in += self.lateral(h)
        if c is None:
            # No previous cell state: start from zeros on this link's device.
            xp = self.xp
            with chainer.using_device(self.device):
                c = variable.Variable(
                    xp.zeros((x.shape[0], self.state_size), dtype=x.dtype))
        return lstm.lstm(c, lstm_in)
Esempio n. 2
0
def _lstm(x, h, c, w, b):
    """Computes one LSTM step from per-gate weights and biases.

    ``w`` and ``b`` are indexed from 0 to 7: entries 0-3 are applied to
    ``x`` and entries 4-7 to ``h``. Within each group the entries are
    reordered as ``2, 0, 1, 3`` before being stacked with
    ``_stack_weight`` (presumably to match the gate layout expected by
    ``lstm.lstm`` -- confirm against that function).

    Args:
        x: Input passed to the first linear transform.
        h: Previous hidden state passed to the second linear transform.
        c: Previous cell state handed to ``lstm.lstm``.
        w: Indexable collection of eight weights.
        b: Indexable collection of eight biases.

    Returns:
        tuple: ``(h_new, c_new)`` -- note that this is the reverse of the
        ``(c, h)`` order returned by ``lstm.lstm``.

    """
    order = (2, 0, 1, 3)
    offset = 4  # entries for ``h`` follow the four entries for ``x``

    x_weight = _stack_weight([w[i] for i in order])
    h_weight = _stack_weight([w[offset + i] for i in order])
    x_bias = _stack_weight([b[i] for i in order])
    h_bias = _stack_weight([b[offset + i] for i in order])

    from_input = linear.linear(x, x_weight, x_bias)
    from_hidden = linear.linear(h, h_weight, h_bias)
    new_c, new_h = lstm.lstm(c, from_input + from_hidden)
    return new_h, new_c
Esempio n. 3
0
    def forward(self, x):
        """Updates the internal state and returns the LSTM outputs.

        The stored states ``self.c`` and ``self.h`` are replaced by the
        values computed in this step. When the batch of ``x`` is smaller
        than the stored ``self.h``, only the leading ``x.shape[0]`` rows of
        ``self.h`` are fed through the lateral connection; the remaining
        rows are kept unchanged and appended after the new output.

        Args:
            x (~chainer.Variable): A new batch from the input sequence.

        Returns:
            ~chainer.Variable: Outputs of updated LSTM units.

        Raises:
            TypeError: If the batch size of ``x`` is non-zero and larger
                than the batch size of the stored state ``self.h``.

        """
        # Lazily create the parameters once the input size is known.
        if self.upward.W.array is None:
            with chainer.using_device(self.device):
                self.upward._initialize_params(
                    utils.size_of_shape(x.shape[1:]))
                self._initialize_params()

        n_batch = x.shape[0]
        gate_input = self.upward(x)

        # Rows of the previous output that this step does not touch.
        carried_h = None
        if self.h is not None:
            n_prev = self.h.shape[0]
            if n_batch == 0:
                # Empty batch: nothing is updated, keep the whole state.
                carried_h = self.h
            else:
                if n_prev < n_batch:
                    raise TypeError(
                        'The batch size of x must be equal to or less than '
                        'the size of the previous state h.')
                if n_prev > n_batch:
                    active_h, carried_h = split_axis.split_axis(
                        self.h, [n_batch], axis=0)
                else:
                    active_h = self.h
                gate_input += self.lateral(active_h)

        if self.c is None:
            # First step: start from a zero cell state on this link's device.
            with chainer.using_device(self.device):
                zeros = self.xp.zeros(
                    (n_batch, self.state_size), dtype=x.dtype)
                self.c = variable.Variable(zeros)

        new_c, y = lstm.lstm(self.c, gate_input)
        self.c = new_c

        if carried_h is not None:
            if len(y.array) == 0:
                self.h = carried_h
            else:
                self.h = concat.concat([y, carried_h], axis=0)
        else:
            self.h = y

        return y