예제 #1
0
def n_step_lstm_base(
        n_layers, dropout_ratio, hx, cx, ws, bs, xs, train, use_cudnn,
        use_bi_direction):
    """Base function for Stack LSTM/BiLSTM functions.

    This function is used at :func:`chainer.functions.n_step_lstm` and
    :func:`chainer.functions.n_step_bilstm`. Whether a uni-directional or a
    bi-directional LSTM is computed is selected by ``use_bi_direction``.

    When ``use_cudnn`` is ``True``, the arrays are not NumPy arrays, cuDNN is
    enabled and its version is 5.0 or later, the whole stack is computed by a
    single cuDNN-backed function. Otherwise the stack is unrolled layer by
    layer and time step by time step with :func:`~chainer.functions.lstm`.

    Args:
        n_layers (int): Number of layers.
        dropout_ratio (float): Dropout ratio. In the non-cuDNN path, dropout
            is applied to the input of every layer except the first one.
        hx (chainer.Variable): Variable holding stacked hidden states.
            Its shape is ``(S, B, N)`` where ``S`` is ``n_layers`` for a
            uni-directional LSTM and ``2 * n_layers`` for a bi-directional
            one, ``B`` is the mini-batch size, and ``N`` is the dimension of
            the hidden units. ``hx[D * i + d]`` is the state of layer ``i``
            in direction ``d`` (0: forward, 1: backward), where ``D`` is 2
            for a bi-directional LSTM and 1 otherwise.
        cx (chainer.Variable): Variable holding stacked cell states.
            It has the same shape and layout as ``hx``.
        ws (list of list of chainer.Variable): Weight matrices. ``ws[k]``
            holds the weights of the same (layer, direction) pair as
            ``hx[k]``, so ``len(ws) == S``. Each ``ws[k]`` is a list of eight
            matrices; ``ws[k][j]`` corresponds to ``W_j`` in the equation of
            :func:`chainer.functions.n_step_lstm`. Matrices with ``j < 4``
            are multiplied with the layer input and have shape ``(N, I_k)``,
            where ``I_k`` is the input size ``I`` for the first layer and
            ``N`` (uni-directional) or ``2 * N`` (bi-directional) for the
            other layers. Matrices with ``j >= 4`` are multiplied with the
            hidden state and have shape ``(N, N)``.
        bs (list of list of chainer.Variable): Bias vectors, laid out like
            ``ws``. Each ``bs[k]`` is a list of eight vectors of shape
            ``(N,)``; ``bs[k][j]`` corresponds to ``b_j`` in the equation.
        xs (list of chainer.Variable): A list of :class:`~chainer.Variable`
            holding input values. Each element ``xs[t]`` holds the input for
            time ``t``. Its shape is ``(B_t, I)``, where ``B_t`` is the
            mini-batch size for time ``t`` and ``I`` is the size of the input
            units. Variable-length sequences are supported: sort the
            sequences in descending order of length and transpose them (for
            example with :func:`~chainer.functions.transpose_sequence`), so
            that ``xs[t].shape[0] >= xs[t + 1].shape[0]`` holds.
        train (bool): If ``True``, this function executes dropout.
        use_cudnn (bool): If ``True``, this function uses cuDNN if available.
        use_bi_direction (bool): If ``True``, this function computes a
            bi-directional LSTM.

    Returns:
        tuple: A tuple ``(hy, cy, ys)``.

        - ``hy`` is the updated hidden states, with the same shape as ``hx``.
        - ``cy`` is the updated cell states, with the same shape as ``cx``.
        - ``ys`` is a tuple of :class:`~chainer.Variable`. Each element
          ``ys[t]`` holds the hidden states of the last layer for the input
          ``xs[t]``. Its shape is ``(B_t, N)`` for a uni-directional LSTM and
          ``(B_t, 2 * N)`` for a bi-directional one (forward and backward
          states concatenated), where ``B_t == xs[t].shape[0]``.

    .. seealso::

       :func:`chainer.functions.n_step_lstm`
       :func:`chainer.functions.n_step_bilstm`

    """

    xp = cuda.get_array_module(hx, hx.data)

    if use_cudnn and xp is not numpy and cuda.cudnn_enabled and \
       _cudnn_version >= 5000:
        states = get_random_state().create_dropout_states(dropout_ratio)
        # The cuDNN function takes every input as one flat positional list:
        # hx, cx, all weights, all biases, then the time steps.
        inputs = tuple(itertools.chain(
            (hx, cx),
            itertools.chain.from_iterable(ws),
            itertools.chain.from_iterable(bs),
            xs))
        if use_bi_direction:
            rnn = NStepBiLSTM(n_layers, states, train=train)
        else:
            rnn = NStepLSTM(n_layers, states, train=train)

        ret = rnn(*inputs)
        hy, cy = ret[:2]
        ys = ret[2:]
        return hy, cy, ys

    else:
        direction = 2 if use_bi_direction else 1
        split_size = n_layers * direction
        # Split the stacked states into one (B, N) state per
        # (layer, direction) pair, indexed by ``direction * layer + di``.
        hx = split_axis.split_axis(hx, split_size, axis=0, force_tuple=True)
        hx = [reshape.reshape(h, h.shape[1:]) for h in hx]
        cx = split_axis.split_axis(cx, split_size, axis=0, force_tuple=True)
        cx = [reshape.reshape(c, c.shape[1:]) for c in cx]

        # Reorder the input-side (j < 4) and hidden-side (j >= 4) gate
        # parameters and stack each group so that one ``linear`` call yields
        # the pre-activations of all four gates (presumably in the gate order
        # ``lstm.lstm`` expects).
        xws = [_stack_weight([w[2], w[0], w[1], w[3]]) for w in ws]
        hws = [_stack_weight([w[6], w[4], w[5], w[7]]) for w in ws]
        xbs = [_stack_weight([b[2], b[0], b[1], b[3]]) for b in bs]
        hbs = [_stack_weight([b[6], b[4], b[5], b[7]]) for b in bs]

        def _one_directional_loop(layer, di, layer_xs):
            """Run direction ``di`` (0: forward, 1: backward) of one layer.

            Returns the final full-batch hidden and cell states and the list
            of per-step hidden outputs, in the order ``layer_xs`` was read.
            """
            layer_idx = direction * layer + di
            h = hx[layer_idx]
            c = cx[layer_idx]
            h_list = []
            for x in layer_xs:
                batch = x.shape[0]
                # Sequences are sorted by length, so only the first ``batch``
                # rows take part in this step; the other rows keep their
                # current state and are re-attached after the update.
                if h.shape[0] > batch:
                    h, h_rest = split_axis.split_axis(h, [batch], axis=0)
                    c, c_rest = split_axis.split_axis(c, [batch], axis=0)
                else:
                    h_rest = None
                    c_rest = None

                if layer != 0:
                    # NOTE: each direction draws its own dropout mask.
                    x = dropout.dropout(x, ratio=dropout_ratio, train=train)
                lstm_in = (
                    linear.linear(x, xws[layer_idx], xbs[layer_idx]) +
                    linear.linear(h, hws[layer_idx], hbs[layer_idx]))

                c_bar, h_bar = lstm.lstm(c, lstm_in)
                if h_rest is not None:
                    h = concat.concat([h_bar, h_rest], axis=0)
                    c = concat.concat([c_bar, c_rest], axis=0)
                else:
                    h = h_bar
                    c = c_bar
                h_list.append(h_bar)
            return h, c, h_list

        xs_next = xs
        hy = []
        cy = []
        for layer in six.moves.range(n_layers):
            h, c, h_forward = _one_directional_loop(layer, 0, xs_next)
            hy.append(h)
            cy.append(c)

            if use_bi_direction:
                # The backward LSTM reads the sequence in reverse; flip its
                # outputs back so they line up with the forward outputs.
                h, c, h_backward = _one_directional_loop(
                    layer, 1, reversed(xs_next))
                hy.append(h)
                cy.append(c)

                h_backward.reverse()
                xs_next = [concat.concat([hf, hb], axis=1)
                           for hf, hb in zip(h_forward, h_backward)]
            else:
                xs_next = h_forward

        hy = stack.stack(hy)
        cy = stack.stack(cy)
        return hy, cy, tuple(xs_next)
예제 #2
0
def n_step_lstm_base(n_layers, dropout_ratio, hx, cx, ws, bs, xs,
                     use_bi_direction, **kwargs):
    """Base function for Stack LSTM/BiLSTM functions.

    This function is used at :func:`chainer.functions.n_step_lstm` and
    :func:`chainer.functions.n_step_bilstm`. Whether a uni-directional or a
    bi-directional LSTM is computed is selected by ``use_bi_direction``.

    When the arrays are not NumPy arrays and :func:`chainer.should_use_cudnn`
    allows it (cuDNN 5.0 or later is required), the whole stack is computed by
    a single cuDNN-backed function. Otherwise the stack is unrolled layer by
    layer and time step by time step with :func:`~chainer.functions.lstm`.

    Args:
        n_layers (int): The number of layers.
        dropout_ratio (float): Dropout ratio. In the non-cuDNN path, dropout
            is applied to the input of every layer except the first one.
        hx (~chainer.Variable): Variable holding stacked hidden states.
            Its shape is ``(S, B, N)`` where ``S`` is ``n_layers`` for a
            uni-directional LSTM and ``2 * n_layers`` for a bi-directional
            one, ``B`` is the mini-batch size, and ``N`` is the dimension of
            the hidden units. ``hx[D * i + d]`` is the state of layer ``i``
            in direction ``d`` (0: forward, 1: backward), where ``D`` is 2
            for a bi-directional LSTM and 1 otherwise.
        cx (~chainer.Variable): Variable holding stacked cell states.
            It has the same shape and layout as ``hx``.
        ws (list of list of :class:`~chainer.Variable`): Weight matrices.
            ``ws[k]`` holds the weights of the same (layer, direction) pair
            as ``hx[k]``, so ``len(ws) == S``. Each ``ws[k]`` is a list of
            eight matrices; ``ws[k][j]`` corresponds to :math:`W_j` in the
            equation. Matrices with ``j < 4`` are multiplied with the layer
            input and have shape ``(N, I_k)``, where ``I_k`` is the input
            size ``I`` for the first layer and ``N`` (uni-directional) or
            ``2 * N`` (bi-directional) for the other layers. Matrices with
            ``j >= 4`` are multiplied with the hidden state and have shape
            ``(N, N)``.
        bs (list of list of :class:`~chainer.Variable`): Bias vectors, laid
            out like ``ws``. Each ``bs[k]`` is a list of eight vectors of
            shape ``(N,)``; ``bs[k][j]`` corresponds to :math:`b_j` in the
            equation.
        xs (list of :class:`~chainer.Variable`):
            A list of :class:`~chainer.Variable`
            holding input values. Each element ``xs[t]`` holds the input for
            time ``t``. Its shape is ``(B_t, I)``, where ``B_t`` is the
            mini-batch size for time ``t``. The sequences must be transposed;
            :func:`~chainer.functions.transpose_sequence` can be used to
            transpose a list of sequences given as
            :class:`~chainer.Variable` objects. When sequences have different
            lengths, they must be sorted in descending order of their lengths
            before transposing, so ``xs`` needs to satisfy
            ``xs[t].shape[0] >= xs[t + 1].shape[0]``.
        use_bi_direction (bool): If ``True``, this function computes a
            bi-directional LSTM.

    The former ``train`` and ``use_cudnn`` keyword arguments are no longer
    supported and are rejected with an error pointing to
    ``chainer.using_config``; any other keyword argument is rejected as well.

    Returns:
        tuple: A tuple ``(hy, cy, ys)``.

        - ``hy`` is the updated hidden states, with the same shape as ``hx``.
        - ``cy`` is the updated cell states, with the same shape as ``cx``.
        - ``ys`` is a tuple of :class:`~chainer.Variable`. Each element
          ``ys[t]`` holds the hidden states of the last layer for the input
          ``xs[t]``. Its shape is ``(B_t, N)`` for a uni-directional LSTM and
          ``(B_t, 2 * N)`` for a bi-directional one (forward and backward
          states concatenated), where ``B_t == xs[t].shape[0]``.

    .. seealso::

       :func:`chainer.functions.n_step_lstm`
       :func:`chainer.functions.n_step_bilstm`

    """

    argument.check_unexpected_kwargs(
        kwargs,
        train='train argument is not supported anymore. '
        'Use chainer.using_config',
        use_cudnn='use_cudnn argument is not supported anymore. '
        'Use chainer.using_config')
    argument.assert_kwargs_empty(kwargs)

    xp = cuda.get_array_module(hx, hx.data)

    if xp is not numpy and chainer.should_use_cudnn('>=auto', 5000):
        states = get_random_state().create_dropout_states(dropout_ratio)
        # The cuDNN function takes every input as one flat positional list:
        # hx, cx, all weights, all biases, then the time steps.
        inputs = tuple(
            itertools.chain((hx, cx), itertools.chain.from_iterable(ws),
                            itertools.chain.from_iterable(bs), xs))
        if use_bi_direction:
            rnn = NStepBiLSTM(n_layers, states)
        else:
            rnn = NStepLSTM(n_layers, states)

        ret = rnn(*inputs)
        hy, cy = ret[:2]
        ys = ret[2:]
        return hy, cy, ys

    else:
        direction = 2 if use_bi_direction else 1
        split_size = n_layers * direction
        # Split the stacked states into one (B, N) state per
        # (layer, direction) pair, indexed by ``direction * layer + di``.
        hx = split_axis.split_axis(hx, split_size, axis=0, force_tuple=True)
        hx = [reshape.reshape(h, h.shape[1:]) for h in hx]
        cx = split_axis.split_axis(cx, split_size, axis=0, force_tuple=True)
        cx = [reshape.reshape(c, c.shape[1:]) for c in cx]

        # Reorder the input-side (j < 4) and hidden-side (j >= 4) gate
        # parameters and stack each group so that one ``linear`` call yields
        # the pre-activations of all four gates (presumably in the gate order
        # ``lstm.lstm`` expects).
        xws = [_stack_weight([w[2], w[0], w[1], w[3]]) for w in ws]
        hws = [_stack_weight([w[6], w[4], w[5], w[7]]) for w in ws]
        xbs = [_stack_weight([b[2], b[0], b[1], b[3]]) for b in bs]
        hbs = [_stack_weight([b[6], b[4], b[5], b[7]]) for b in bs]

        def _one_directional_loop(layer, di, layer_xs):
            """Run direction ``di`` (0: forward, 1: backward) of one layer.

            Returns the final full-batch hidden and cell states and the list
            of per-step hidden outputs, in the order ``layer_xs`` was read.
            """
            layer_idx = direction * layer + di
            h = hx[layer_idx]
            c = cx[layer_idx]
            h_list = []
            for x in layer_xs:
                batch = x.shape[0]
                # Sequences are sorted by length, so only the first ``batch``
                # rows take part in this step; the other rows keep their
                # current state and are re-attached after the update.
                if h.shape[0] > batch:
                    h, h_rest = split_axis.split_axis(h, [batch], axis=0)
                    c, c_rest = split_axis.split_axis(c, [batch], axis=0)
                else:
                    h_rest = None
                    c_rest = None

                if layer != 0:
                    # NOTE: each direction draws its own dropout mask.
                    x = dropout.dropout(x, ratio=dropout_ratio)
                lstm_in = (
                    linear.linear(x, xws[layer_idx], xbs[layer_idx]) +
                    linear.linear(h, hws[layer_idx], hbs[layer_idx]))

                c_bar, h_bar = lstm.lstm(c, lstm_in)
                if h_rest is not None:
                    h = concat.concat([h_bar, h_rest], axis=0)
                    c = concat.concat([c_bar, c_rest], axis=0)
                else:
                    h = h_bar
                    c = c_bar
                h_list.append(h_bar)
            return h, c, h_list

        xs_next = xs
        hy = []
        cy = []
        for layer in six.moves.range(n_layers):
            h, c, h_forward = _one_directional_loop(layer, 0, xs_next)
            hy.append(h)
            cy.append(c)

            if use_bi_direction:
                # The backward LSTM reads the sequence in reverse; flip its
                # outputs back so they line up with the forward outputs.
                h, c, h_backward = _one_directional_loop(
                    layer, 1, reversed(xs_next))
                hy.append(h)
                cy.append(c)

                h_backward.reverse()
                xs_next = [
                    concat.concat([hf, hb], axis=1)
                    for hf, hb in zip(h_forward, h_backward)
                ]
            else:
                xs_next = h_forward

        hy = stack.stack(hy)
        cy = stack.stack(cy)
        return hy, cy, tuple(xs_next)