def _rnn_template(incoming, cell, dropout=None, return_seq=False,
                  return_state=False, initial_state=None, dynamic=False,
                  scope=None, name="LSTM"):
    """RNN Layer Template.

    Builds an RNN over `incoming` using `cell`, optionally wrapping the
    cell in dropout, and registers the created variables, activations and
    output tensor in the layer collections.

    Arguments:
        incoming: `Tensor` or list of timestep tensors — the input sequence.
        cell: an RNN cell instance exposing `W` and `b` variables.
        dropout: `float` or 2-element tuple/list of floats — input/output
            keep probabilities for `DropoutWrapper`. `None` disables it.
        return_seq: if True, return the whole output sequence instead of
            only the last output.
        return_state: if True, also return the final cell state.
        initial_state: optional initial state for the RNN.
        dynamic: if True, compute true per-example sequence lengths and
            return each example's last *relevant* output.
        scope: optional variable scope to reuse.
        name: layer name, used as the default scope name.

    Returns:
        The output tensor (or sequence), plus the final state when
        `return_state` is True.

    Raises:
        Exception: if `dropout` is neither a float nor a 2-element
            tuple/list.
    """
    sequence_length = None
    if dynamic:
        # True per-example lengths, needed to index the last relevant
        # output below.
        sequence_length = retrieve_seq_length_op(
            incoming if isinstance(incoming, tf.Tensor) else tf.pack(incoming))

    input_shape = utils.get_incoming_shape(incoming)

    with tf.variable_op_scope([incoming], scope, name) as scope:
        name = scope.name
        # Keep a handle on the raw cell: DropoutWrapper does not expose W/b.
        _cell = cell
        # Apply dropout
        if dropout:
            if type(dropout) in [tuple, list]:
                in_keep_prob = dropout[0]
                out_keep_prob = dropout[1]
            elif isinstance(dropout, float):
                in_keep_prob, out_keep_prob = dropout, dropout
            else:
                raise Exception("Invalid dropout type (must be a 2-D tuple of "
                                "float)")
            cell = DropoutWrapper(cell, in_keep_prob, out_keep_prob)

        inference = incoming
        # If a single tensor is given, convert it to a per-timestep list.
        # NOTE(review): `np.array` is the constructor, not the ndarray type,
        # so ndarray inputs also take this branch — presumably intended,
        # since tf.transpose accepts ndarrays; confirm before changing.
        if type(inference) not in [list, np.array]:
            ndim = len(input_shape)
            assert ndim >= 3, "Input dim should be at least 3."
            # Swap batch and time axes: (batch, time, ...) -> (time, batch, ...)
            axes = [1, 0] + list(range(2, ndim))
            inference = tf.transpose(inference, (axes))
            inference = tf.unpack(inference)

        outputs, state = _rnn(cell, inference, dtype=tf.float32,
                              initial_state=initial_state, scope=name,
                              sequence_length=sequence_length)

        # Retrieve RNN Variables
        c = tf.GraphKeys.LAYER_VARIABLES + '/' + scope.name
        for v in [_cell.W, _cell.b]:
            if hasattr(v, "__len__"):
                for var in v:
                    tf.add_to_collection(c, var)
            else:
                tf.add_to_collection(c, v)

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, outputs[-1])

    if dynamic:
        if return_seq:
            # BUG FIX: previously the sequence was always collapsed to the
            # last relevant output, ignoring `return_seq`; honor it first.
            o = outputs
        else:
            # Repack to (batch, time, units) and index each example at its
            # true last timestep.
            outputs = tf.transpose(tf.pack(outputs), [1, 0, 2])
            o = advanced_indexing_op(outputs, sequence_length)
    else:
        o = outputs if return_seq else outputs[-1]

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, o)

    return (o, state) if return_state else o
def _rnn_template(incoming, cell, dropout=None, return_seq=False,
                  return_state=False, initial_state=None, dynamic=False,
                  scope=None, name="LSTM"):
    """RNN Layer Template.

    Runs `cell` over `incoming` (a tensor or a per-timestep list),
    optionally wrapping the cell in dropout. Registers the cell's
    variables, the final activation and the output tensor in the layer
    collections. Returns the output (sequence or last step, depending on
    `return_seq`/`dynamic`), plus the final state if `return_state`.
    """
    seq_len = None
    if dynamic:
        # True per-example lengths are required to pick each example's
        # last relevant output below.
        seq_len = retrieve_seq_length_op(
            incoming if isinstance(incoming, tf.Tensor) else tf.pack(incoming))

    shape = utils.get_incoming_shape(incoming)

    with tf.variable_scope(scope, name, values=[incoming]) as scope:
        name = scope.name
        # The un-wrapped cell keeps the W/b handles registered below.
        raw_cell = cell

        # Optional dropout around the cell; accepts a float (same keep
        # probability both ways) or an (input, output) pair.
        if dropout:
            if type(dropout) in [tuple, list]:
                keep_in, keep_out = dropout[0], dropout[1]
            elif isinstance(dropout, float):
                keep_in = keep_out = dropout
            else:
                raise Exception("Invalid dropout type (must be a 2-D tuple of "
                                "float)")
            cell = DropoutWrapper(cell, keep_in, keep_out)

        x = incoming
        # A single-tensor input is split into a per-timestep list.
        if type(x) not in [list, np.array]:
            ndim = len(shape)
            assert ndim >= 3, "Input dim should be at least 3."
            # (batch, time, ...) -> (time, batch, ...), then one tensor
            # per timestep.
            perm = [1, 0] + list(range(2, ndim))
            x = tf.unpack(tf.transpose(x, perm))

        outputs, state = _rnn(cell, x, dtype=tf.float32,
                              initial_state=initial_state, scope=name,
                              sequence_length=seq_len)

        # Register the cell's variables under this layer's collection.
        key = tf.GraphKeys.LAYER_VARIABLES + '/' + scope.name
        for v in [raw_cell.W, raw_cell.b]:
            if hasattr(v, "__len__"):
                for each in v:
                    tf.add_to_collection(key, each)
            else:
                tf.add_to_collection(key, v)

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, outputs[-1])

    if dynamic:
        if return_seq:
            o = outputs
        else:
            # Repack to (batch, time, units), then index each example at
            # its true last timestep.
            o = advanced_indexing_op(
                tf.transpose(tf.pack(outputs), [1, 0, 2]), seq_len)
    else:
        o = outputs if return_seq else outputs[-1]

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, o)

    return (o, state) if return_state else o