def get_state_shape(s):
    """Return a zero tensor whose shape is `batch_size` combined with `s`.

    `batch_size`, `dtype` and `_concat` are not defined in this block; they
    are resolved from the surrounding scope. Outside eager execution the
    statically computed shape is additionally attached to the result.
    """
    dynamic_shape = _concat(batch_size, s)
    zeros = array_ops.zeros(dynamic_shape, dtype=dtype)
    if context.executing_eagerly():
        return zeros
    # Graph mode only: record the static shape on the tensor.
    zeros.set_shape(_concat(batch_size, s, static=True))
    return zeros
def build(self, _):
    """Create the trainable `init_state` weight and tile it to the batch size.

    The positional argument is ignored. The weight has a leading dimension
    of 1 followed by `self._state_size`; the tiled tensor is stored in
    `self._size` and the layer is marked as built.
    """
    from tensorflow.python.ops.rnn_cell_impl import _concat

    unit_shape = _concat(1, self._state_size, static=True)
    init_state = self.add_weight(
        name="init_state",
        shape=unit_shape,
        dtype=self.dtype,
        initializer=tf.initializers.zeros,
        trainable=True,
    )
    # Repeat along the first axis only; every other axis keeps multiple 1.
    trailing_ones = [1] * (len(unit_shape) - 1)
    self._size = tf.tile(init_state, [self._batch_size] + trailing_ones)
    self.built = True
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """Run `cell` inside a `while_loop` driven by `loop_fn`, recording every step.

    Adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    so that, in addition to the emitted outputs and the final state, the
    (possibly nested) state of every time step is collected as well.

    Args:
        cell: object accepted by `_like_rnncell`; it is called as
            `cell(input, state)` and its `state_size` / `output_size`
            attributes are used for structure checks.
        loop_fn: callable. It is first called as
            `loop_fn(time, None, None, None)` and afterwards as
            `loop_fn(next_time, next_output, cell_state, loop_state)`; each
            call returns the 5-tuple `(finished, next_input, next_state,
            emit_output, next_loop_state)`.
        parallel_iterations: forwarded to `while_loop`; a falsy value is
            replaced by 32.
        swap_memory: forwarded to `while_loop`.
        scope: variable scope to use; defaults to "rnn".

    Returns:
        A tuple `(states, outputs, final_state)`. `states` and `outputs` are
        the stacked per-step TensorArrays, transposed with permutation
        `(1, 0, 2)`; `final_state` is the state after the last iteration.

    Raises:
        TypeError: if `cell` fails the `_like_rnncell` check or `loop_fn`
            is not callable.
    """
    if not _like_rnncell(cell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    if not parallel_iterations:
        parallel_iterations = 32

    # Create a new scope in which the caching device is either determined by
    # the parent scope, or is set to place the cached Variable using the same
    # placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not context.executing_eagerly() and varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        first_step = loop_fn(time, None, None, None)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = first_step
        flat_input = nest.flatten(next_input)

        # The while_loop needs some loop state; use a dummy scalar if loop_fn
        # did not provide one.
        if init_loop_state is None:
            loop_state = constant_op.constant(0, dtype=dtypes.int32)
        else:
            loop_state = init_loop_state

        input_shapes = [inp.get_shape() for inp in flat_input]
        static_batch_size = input_shapes[0][0]
        for shape_i in input_shapes:
            # Static verification that batch sizes all match; the merged
            # result itself is not used.
            static_batch_size.merge_with(shape_i[0])

        # const_batch_size stays None when the batch size is not known
        # statically; batch_size then falls back to a dynamic tensor.
        const_batch_size = static_batch_size.value
        if const_batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]
        else:
            batch_size = const_batch_size

        nest.assert_same_structure(initial_state, cell.state_size)
        flat_state = [
            ops.convert_to_tensor(s) for s in nest.flatten(initial_state)
        ]
        state = nest.pack_sequence_as(structure=initial_state,
                                      flat_sequence=flat_state)

        def _size_of(tensor):
            """Static shape if fully defined, else the dynamic shape tensor."""
            if tensor.shape.is_fully_defined():
                return tensor.shape
            return array_ops.shape(tensor)

        if emit_structure is None:
            # No explicit emit structure: emit what the cell outputs, with
            # the dtype of the first state tensor.
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
        else:
            flat_emit = nest.flatten(emit_structure)
            flat_emit_size = [_size_of(e) for e in flat_emit]
            flat_emit_dtypes = [e.dtype for e in flat_emit]

        flat_state_size = [_size_of(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        def _new_tensor_arrays(name_template, elem_dtypes, elem_sizes):
            """Build one dynamically sized TensorArray per (dtype, size)."""
            arrays = []
            for i, (dtype_i, size_i) in enumerate(zip(elem_dtypes, elem_sizes)):
                element_shape = tensor_shape.TensorShape(
                    [const_batch_size]
                ).concatenate(_maybe_tensor_shape_from_tensor(size_i))
                arrays.append(
                    tensor_array_ops.TensorArray(
                        dtype=dtype_i,
                        dynamic_size=True,
                        element_shape=element_shape,
                        size=0,
                        name=name_template % i,
                    )
                )
            return arrays

        flat_emit_ta = _new_tensor_arrays(
            "rnn_output_%d", flat_emit_dtypes, flat_emit_size)
        emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                        flat_sequence=flat_emit_ta)

        # Zeros emitted for batch entries that are already finished.
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]
        zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                          flat_sequence=flat_zero_emit)

        # NOTE(review): each entry of flat_state_size is the shape of the
        # state tensor itself, so the element shape below is
        # [const_batch_size] + state.shape -- confirm this matches the
        # tensors written at each step.
        flat_state_ta = _new_tensor_arrays(
            "rnn_state_%d", flat_state_dtypes, flat_state_size)
        state_ta = nest.pack_sequence_as(structure=state,
                                         flat_sequence=flat_state_ta)

        def condition(unused_time, finished, *_):
            # Keep looping until every element reports finished.
            return math_ops.logical_not(math_ops.reduce_all(finished))

        def body(step, finished, current_input, state_ta, emit_ta, state,
                 loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_step = step + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_step, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, reuse the old one.
            if next_loop_state is not None:
                loop_state = next_loop_state

            def _select(previous, candidate):
                """Keep `previous` where finished, else take `candidate`."""
                # TensorArrays and scalars are passed through untouched.
                if isinstance(previous, tensor_array_ops.TensorArray):
                    return candidate
                if previous.shape.ndims == 0:
                    return candidate
                with ops.colocate_with(candidate):
                    return array_ops.where(finished, previous, candidate)

            emit_output = nest.map_structure(_select, zero_emit, emit_output)
            next_state = nest.map_structure(_select, state, next_state)

            emit_ta = nest.map_structure(
                lambda ta, emit: ta.write(step, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, st: ta.write(step, st), state_ta, next_state)

            now_finished = math_ops.logical_or(finished, next_finished)

            return (next_step, now_finished, next_input, state_ta, emit_ta,
                    next_state, loop_state)

        loop_result = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, state_ta, emit_ta, state,
                loop_state
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        # The final loop state is not part of the return value.
        state_ta, emit_ta, final_state, _ = loop_result[-4:]

        def _stack_and_swap(ta_structure):
            """Stack every TensorArray and swap its first two axes."""
            stacked = [
                array_ops.transpose(ta.stack(), (1, 0, 2))
                for ta in nest.flatten(ta_structure)
            ]
            return nest.pack_sequence_as(structure=ta_structure,
                                         flat_sequence=stacked)

        states = _stack_and_swap(state_ta)
        outputs = _stack_and_swap(emit_ta)

        return (states, outputs, final_state)
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along
    the time axis) in addition to the outputs at each timestep and the final
    state.

    Args:
        cell: object accepted by `_like_rnncell`; called as
            `cell(input, state)`, and its `state_size` / `output_size`
            attributes are used for structure checks.
        loop_fn: callable, first invoked as `loop_fn(time, None, None, None)`
            and then as `loop_fn(next_time, next_output, cell_state,
            loop_state)`. Every call returns the 5-tuple `(finished,
            next_input, next_state, emit_output, next_loop_state)`.
        parallel_iterations: forwarded to `while_loop`; a falsy value is
            replaced by 32.
        swap_memory: forwarded to `while_loop`.
        scope: variable scope to use; defaults to "rnn".

    Returns:
        (
            states for all timesteps,
            outputs for all timesteps,
            final cell state,
        )
        The first two are the stacked per-step TensorArrays transposed with
        permutation `(1, 0, 2)`.

    Raises:
        TypeError: if `cell` fails the `_like_rnncell` check or `loop_fn`
            is not callable.
    """
    if not _like_rnncell(cell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        # Newer TensorFlow releases expose `context.executing_eagerly()`
        # rather than `context.in_graph_mode()`. Prefer the former and fall
        # back to the latter so both API generations keep working.
        executing_eagerly = getattr(context, "executing_eagerly", None)
        if executing_eagerly is not None:
            in_graph_mode = not executing_eagerly()
        else:
            in_graph_mode = context.in_graph_mode()

        if in_graph_mode:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match; the merged
            # result itself is not used.
            static_batch_size.merge_with(input_shape_i[0])

        # const_batch_size stays None when the batch size is not statically
        # known; batch_size then becomes a dynamic tensor instead.
        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [
                emit.shape if emit.shape.is_fully_defined()
                else array_ops.shape(emit)
                for emit in flat_emit_structure
            ]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            # No explicit emit structure: emit what the cell outputs, using
            # the dtype of the first state tensor.
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [
            s.shape if s.shape.is_fully_defined() else array_ops.shape(s)
            for s in flat_state
        ]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)

        # Zeros emitted for batch entries that are already finished.
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]
        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        # NOTE(review): each entry of flat_state_size is the shape of the
        # state tensor itself, so the element shape below is
        # [const_batch_size] + state.shape -- confirm this matches the
        # tensors written at each step.
        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            # Keep looping until every element reports finished.
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        # The final loop state is not part of the return value.
        (state_ta, emit_ta, final_state, _) = returned[-4:]

        # Stack each TensorArray (indexed by the write position `time`) and
        # swap its first two axes.
        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
def get_state_shape(s):
    """Return a uniform random tensor shaped as `batch_size` combined with `s`.

    `batch_size`, `dtype` and `_concat` are not defined in this block; they
    are resolved from the surrounding scope.
    """
    return tf.random_uniform(_concat(batch_size, s), dtype=dtype)
def get_state_shape(s):
    """Return a zero tensor shaped as `batch_size` combined with `s`.

    The statically computed shape is attached to the result. `batch_size`,
    `dtype` and `_concat` are not defined in this block; they are resolved
    from the surrounding scope.
    """
    dynamic_shape = _concat(batch_size, s)
    static_shape = _concat(batch_size, s, static=True)
    zeros = array_ops.zeros(dynamic_shape, dtype=dtype)
    zeros.set_shape(static_shape)
    return zeros