コード例 #1
0
ファイル: recurrent.py プロジェクト: CharlesShang/tflearn
def _rnn_template(incoming, cell, dropout=None, return_seq=False,
                  return_state=False, initial_state=None, dynamic=False,
                  scope=None, name="LSTM"):
    """ RNN Layer Template.

    Runs `cell` over `incoming` with `_rnn`, then registers the cell
    variables, the last timestep output and the layer output in graph
    collections.

    Arguments:
        incoming: A tensor, or a list holding one entry per timestep. Any
            non-list input must have at least 3 dimensions; its first two
            axes are swapped and it is unpacked into a per-timestep list.
        cell: The RNN cell to run. Its `W` and `b` attributes (each either
            a single variable or a sequence of variables) are added to the
            layer variables collection.
        dropout: `None`, a `float` used for both the input and output keep
            probability, or a `tuple`/`list` whose first two items are the
            input and output keep probabilities. Falsy values disable
            dropout.
        return_seq: If True, return the list of per-timestep outputs
            instead of a single output.
        return_state: If True, return `(output, state)` instead of `output`.
        initial_state: Forwarded to `_rnn` as `initial_state`.
        dynamic: If True, a sequence length op is derived from `incoming`
            and forwarded to `_rnn`; when `return_seq` is False the output
            is selected with `advanced_indexing_op` using that length.
        scope: Passed to `tf.variable_op_scope` as the scope argument.
        name: Default name passed to `tf.variable_op_scope`.

    Returns:
        The layer output, or an `(output, state)` tuple when `return_state`
        is True.

    Raises:
        Exception: If `dropout` is truthy but neither a float nor a
            tuple/list.
        AssertionError: If a non-list `incoming` has fewer than 3
            dimensions.
    """
    sequence_length = None
    if dynamic:
        sequence_length = retrieve_seq_length_op(
            incoming if isinstance(incoming, tf.Tensor) else tf.pack(incoming))

    input_shape = utils.get_incoming_shape(incoming)

    with tf.variable_op_scope([incoming], scope, name) as scope:
        name = scope.name

        # Keep a handle on the unwrapped cell: its variables are collected
        # below, after `cell` may have been replaced by a DropoutWrapper.
        _cell = cell
        # Apply dropout
        if dropout:
            if type(dropout) in [tuple, list]:
                in_keep_prob = dropout[0]
                out_keep_prob = dropout[1]
            elif isinstance(dropout, float):
                in_keep_prob, out_keep_prob = dropout, dropout
            else:
                raise Exception("Invalid dropout type (must be a 2-D tuple of "
                                "float)")
            cell = DropoutWrapper(cell, in_keep_prob, out_keep_prob)

        inference = incoming
        # If a tensor given, convert it to a per timestep list
        # NOTE(review): `np.array` is a factory function, not a type, so this
        # membership test can only ever match `list` -- confirm whether
        # `np.ndarray` was intended before changing it.
        if type(inference) not in [list, np.array]:
            ndim = len(input_shape)
            assert ndim >= 3, "Input dim should be at least 3."
            # Swap the first two axes, leaving the remaining axes in place.
            axes = [1, 0] + list(range(2, ndim))
            inference = tf.transpose(inference, (axes))
            inference = tf.unpack(inference)

        outputs, state = _rnn(cell, inference, dtype=tf.float32,
                              initial_state=initial_state, scope=name,
                              sequence_length=sequence_length)

        # Retrieve RNN Variables
        c = tf.GraphKeys.LAYER_VARIABLES + '/' + scope.name
        for v in [_cell.W, _cell.b]:
            if hasattr(v, "__len__"):
                for var in v:
                    tf.add_to_collection(c, var)
            else:
                tf.add_to_collection(c, v)
        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, outputs[-1])

    if dynamic:
        if return_seq:
            # Honor `return_seq` in dynamic mode too; previously the sequence
            # was always collapsed to a single output here.
            o = outputs
        else:
            outputs = tf.transpose(tf.pack(outputs), [1, 0, 2])
            o = advanced_indexing_op(outputs, sequence_length)
    else:
        o = outputs if return_seq else outputs[-1]

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, o)

    return (o, state) if return_state else o
コード例 #2
0
def _rnn_template(incoming,
                  cell,
                  dropout=None,
                  return_seq=False,
                  return_state=False,
                  initial_state=None,
                  dynamic=False,
                  scope=None,
                  name="LSTM"):
    """ RNN Layer Template.

    Shared builder for recurrent layers: feeds `incoming` through `cell`
    via `_rnn` and records the resulting variables and tensors in graph
    collections.

    Arguments:
        incoming: A tensor or a per-timestep list. Anything that is not a
            list needs at least 3 dimensions; it is transposed so its
            first two axes trade places and then unpacked into a list.
        cell: RNN cell handed to `_rnn`. Its `W` and `b` attributes are
            stored in the layer variables collection (each may be one
            variable or a sequence of them).
        dropout: Falsy for no dropout; a `float` applied as both input
            and output keep probability; or a `tuple`/`list` giving the
            input and output keep probabilities as its first two items.
        return_seq: When True the per-timestep output list is returned.
        return_state: When True the result is `(output, state)`.
        initial_state: Passed through to `_rnn`.
        dynamic: When True a sequence length op is built from `incoming`
            and passed to `_rnn`; without `return_seq` the output is then
            chosen by `advanced_indexing_op` with that length.
        scope: Scope argument for `tf.variable_scope`.
        name: Default name for `tf.variable_scope`.

    Returns:
        The selected output, paired with the final state if
        `return_state` is True.

    Raises:
        Exception: For a truthy `dropout` that is neither a float nor a
            tuple/list.
        AssertionError: For a non-list `incoming` with fewer than 3
            dimensions.
    """
    if dynamic:
        length_source = (incoming if isinstance(incoming, tf.Tensor)
                         else tf.pack(incoming))
        seq_len = retrieve_seq_length_op(length_source)
    else:
        seq_len = None

    shape = utils.get_incoming_shape(incoming)

    with tf.variable_scope(scope, name, values=[incoming]) as var_scope:
        layer_name = var_scope.name

        # The unwrapped cell owns the variables collected further down.
        base_cell = cell
        if dropout:
            if isinstance(dropout, float):
                keep_in = keep_out = dropout
            elif type(dropout) in (tuple, list):
                keep_in, keep_out = dropout[0], dropout[1]
            else:
                raise Exception("Invalid dropout type (must be a 2-D tuple of "
                                "float)")
            cell = DropoutWrapper(cell, keep_in, keep_out)

        steps = incoming
        # Anything that is not already a list is turned into a
        # per-timestep list by swapping its first two axes and unpacking.
        if type(steps) not in (list, np.array):
            rank = len(shape)
            assert rank >= 3, "Input dim should be at least 3."
            perm = [1, 0]
            perm.extend(range(2, rank))
            steps = tf.unpack(tf.transpose(steps, perm))

        outputs, state = _rnn(cell,
                              steps,
                              dtype=tf.float32,
                              initial_state=initial_state,
                              scope=layer_name,
                              sequence_length=seq_len)

        # Register the cell's weights and biases under this layer's key.
        variables_key = tf.GraphKeys.LAYER_VARIABLES + '/' + var_scope.name
        for param in (base_cell.W, base_cell.b):
            members = param if hasattr(param, "__len__") else [param]
            for member in members:
                tf.add_to_collection(variables_key, member)
        # Register the final timestep output as an activation.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, outputs[-1])

    if return_seq:
        o = outputs
    elif dynamic:
        stacked = tf.transpose(tf.pack(outputs), [1, 0, 2])
        o = advanced_indexing_op(stacked, seq_len)
    else:
        o = outputs[-1]

    # Register the layer output under this layer's name.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + layer_name, o)

    if return_state:
        return o, state
    return o