Example #1
0
    def __call__(self, *cshsx):
        """Compute the new state of a Child-Sum TreeLSTM node.

        The positional arguments are laid out as
        ``c_1, ..., c_n, h_1, ..., h_n, x``: the cell vectors of the ``n``
        children, then their output vectors, then the input vector. Child
        entries that are ``None`` are filled in by ``self._pad_zero_nodes``.
        If ``x`` is ``None``, a zero input of width ``self.in_size`` is
        built. Its batch size and dtype come from the first non-``None``
        cell vector, or failing that the first non-``None`` output vector.

        Args:
            cshsx (list of :class:`~chainer.Variable`): Variable arguments
                which include all cell vectors and all output vectors of
                variable children, and an input vector.

        Returns:
            tuple of ~chainer.Variable: Returns
            :math:`(c_{new}, h_{new})`, where :math:`c_{new}` represents
            new cell state vector, and :math:`h_{new}` is new output
            vector.

        Raises:
            ValueError: If ``x`` and every child vector are ``None``.

        """
        half = len(cshsx) // 2
        cs = cshsx[:half]
        hs = cshsx[half:-1]
        x = cshsx[-1]
        assert len(cshsx) % 2 == 1
        assert len(cs) == len(hs)

        if x is None:
            # Cells are searched before outputs for a shape/dtype template.
            base = next((v for v in cs + hs if v is not None), None)
            if base is None:
                raise ValueError('All inputs (cs, hs, x) are None.')
            x = self.xp.zeros((base.shape[0], self.in_size),
                              dtype=base.dtype)

        projected_x = self.W_x(x)
        x_aio, x_f = split_axis.split_axis(
            projected_x, [3 * self.state_size], axis=1)

        if not hs:
            # Leaf node: only the input feeds the a/i/o pre-activations and
            # the forget-gate slice of the projection is unused.
            a, i, o = split_axis.split_axis(x_aio, 3, axis=1)
            c = sigmoid.sigmoid(i) * tanh.tanh(a)
            return c, sigmoid.sigmoid(o) * tanh.tanh(c)

        pad_shape = (x.shape[0], self.state_size)
        hs = self._pad_zero_nodes(hs, pad_shape, dtype=x.dtype)
        cs = self._pad_zero_nodes(cs, pad_shape, dtype=x.dtype)

        # The a/i/o pre-activations see the sum of all children's outputs.
        aio_in = self.W_h_aio(sum(hs)) + x_aio

        # One forget-gate block per child: run W_h_f once over the child
        # outputs stacked along axis 0, split the result back per child,
        # and lay the per-child blocks side by side along axis 1.
        stacked_hs = concat.concat(hs, axis=0)
        per_child = split_axis.split_axis(
            self.W_h_f(stacked_hs), len(hs), axis=0)
        h_f = concat.concat(per_child, axis=1)
        f_in = h_f + concat.concat([x_f] * len(hs), axis=1)

        tree_lstm_in = concat.concat([aio_in, f_in], axis=1)
        return tree_lstm.tree_lstm(*(cs + (tree_lstm_in,)))
Example #2
0
    def forward(self, *cshsx):
        """Run one Child-Sum TreeLSTM step.

        Arguments are passed as ``c_1, ..., c_n, h_1, ..., h_n, x``: the
        cell vectors of the ``n`` children, their output vectors, and the
        input vector. ``None`` children are filled in through
        ``self._pad_zero_nodes``. A ``None`` input is replaced with zeros of
        width ``self.in_size``, shaped after the first non-``None`` child
        vector (cells are checked before outputs).

        Args:
            cshsx (list of :class:`~chainer.Variable`): Variable arguments
                which include all cell vectors and all output vectors of
                variable children, and an input vector.

        Returns:
            tuple of ~chainer.Variable: Returns
            :math:`(c_{new}, h_{new})`, where :math:`c_{new}` represents
            new cell state vector, and :math:`h_{new}` is new output
            vector.

        Raises:
            ValueError: If the input and all child vectors are ``None``.

        """
        n_child = len(cshsx) // 2
        cs, hs, x = cshsx[:n_child], cshsx[n_child:-1], cshsx[-1]
        assert len(cshsx) % 2 == 1
        assert len(cs) == len(hs)

        if x is None:
            for base in cs + hs:
                if base is not None:
                    break
            else:
                raise ValueError('All inputs (cs, hs, x) are None.')
            x = self.xp.zeros(
                (base.shape[0], self.in_size), dtype=base.dtype)

        wx = self.W_x(x)
        wx_aio, wx_f = split_axis.split_axis(
            wx, [3 * self.state_size], axis=1)

        if not hs:
            # No children: the a/i/o gates depend on the input alone.
            a, i, o = split_axis.split_axis(wx_aio, 3, axis=1)
            c_new = sigmoid.sigmoid(i) * tanh.tanh(a)
            h_new = sigmoid.sigmoid(o) * tanh.tanh(c_new)
            return c_new, h_new

        child_shape = (x.shape[0], self.state_size)
        hs = self._pad_zero_nodes(hs, child_shape, dtype=x.dtype)
        cs = self._pad_zero_nodes(cs, child_shape, dtype=x.dtype)

        n_hs = len(hs)
        aio_in = self.W_h_aio(sum(hs)) + wx_aio
        # Apply W_h_f to every child in one call, then split per child and
        # concatenate the per-child forget-gate blocks along axis 1.
        f_blocks = split_axis.split_axis(
            self.W_h_f(concat.concat(hs, axis=0)), n_hs, axis=0)
        f_in = (concat.concat(f_blocks, axis=1)
                + concat.concat([wx_f] * n_hs, axis=1))
        return tree_lstm.tree_lstm(
            *(cs + (concat.concat([aio_in, f_in], axis=1),)))
Example #3
0
    def forward(self, *cshsx):
        """Returns new cell state and output of N-ary TreeLSTM.

        Each child position ``k`` (1-based) has its own projection link
        ``W_h{k}``. Children whose output is ``None`` contribute nothing to
        the pre-activation, and ``None`` cells are filled in via
        ``self._pad_zero_nodes``. A ``None`` input is replaced with zeros of
        width ``self.in_size``, shaped after the first non-``None`` child
        vector (cells are checked before outputs).

        Args:
            cshsx (list of :class:`~chainer.Variable`): Arguments which include
                all cell vectors and all output vectors of fixed-length
                children, and an input vector. The number of arguments must be
                same as ``n_ary * 2 + 1``.

        Returns:
            tuple of ~chainer.Variable: Returns :math:`(c_{new}, h_{new})`,
            where :math:`c_{new}` represents new cell state vector,
            and :math:`h_{new}` is new output vector.

        Raises:
            ValueError: If the input and all child vectors are ``None``.

        """
        n = self.n_ary
        assert len(cshsx) == n * 2 + 1
        cs, hs, x = cshsx[:n], cshsx[n:-1], cshsx[-1]

        if x is None:
            base = next((v for v in cs + hs if v is not None), None)
            if base is None:
                raise ValueError('All inputs (cs, hs, x) are None.')
            x = self.xp.zeros(
                (base.shape[0], self.in_size), dtype=base.dtype)

        tree_lstm_in = self.W_x(x)
        for idx, h in enumerate(hs):
            if h is None:
                continue
            tree_lstm_in += getattr(self, 'W_h{}'.format(idx + 1))(h)

        cs = self._pad_zero_nodes(
            cs, (x.shape[0], self.state_size), dtype=x.dtype)
        return tree_lstm.tree_lstm(*(cs + (tree_lstm_in,)))
    def __call__(self, *cshsx):
        """Returns new cell state and output of N-ary TreeLSTM.

        Each child position ``k`` (1-based) has its own projection link
        ``W_h{k}``. Children whose output is ``None`` contribute nothing to
        the pre-activation, and ``None`` cells are filled in via
        ``self._pad_zero_nodes``. A ``None`` input is replaced with zeros of
        width ``self.in_size``, shaped after the first non-``None`` child
        vector (cells are checked before outputs).

        Args:
            cshsx (list of :class:`~chainer.Variable`): Arguments which include
                all cell vectors and all output vectors of fixed-length
                children, and an input vector. The number of arguments must be
                same as ``n_ary * 2 + 1``.

        Returns:
            tuple of ~chainer.Variable: Returns :math:`(c_{new}, h_{new})`,
            where :math:`c_{new}` represents new cell state vector,
            and :math:`h_{new}` is new output vector.

        Raises:
            ValueError: If the input and all child vectors are ``None``.

        """
        assert (len(cshsx) == self.n_ary * 2 + 1)
        cs = cshsx[:self.n_ary]
        hs = cshsx[self.n_ary:-1]
        x = cshsx[-1]

        if x is None:
            # Borrow batch size and dtype from the first non-None child
            # vector, checking cells before outputs (single pass).
            base = next((v for v in cs + hs if v is not None), None)
            if base is None:
                # Name the arguments that were checked so the caller can
                # tell which inputs were missing.
                raise ValueError('All inputs (cs, hs, x) are None.')
            x = self.xp.zeros((base.shape[0], self.in_size),
                              dtype=base.dtype)

        tree_lstm_in = self.W_x(x)

        # Child k (1-based) is projected by its own link ``W_h{k}``;
        # absent (None) children are skipped.
        for i, h in enumerate(hs, start=1):
            if h is not None:
                tree_lstm_in += getattr(self, 'W_h{}'.format(i))(h)

        cs = self._pad_zero_nodes(cs, (x.shape[0], self.state_size),
                                  dtype=x.dtype)

        return tree_lstm.tree_lstm(*(cs + (tree_lstm_in, )))