def _conv1d(inputs, output_dim, name=None):
    """ Performs 1d convolution
        :param inputs: The tensor inputs - [axis_0, ..., axis_n-1, input_dim]
        :param output_dim: The size of the 1d convolution
        :param name: Name of the scope - To share weights between calls
        :return: The tensors after the convolution - [axis_0, ..., axis_n-1, output_dim]
    """
    with variable_scope.variable_scope(name):
        # Static dims where known, dynamic dims otherwise
        input_dims = [array_ops.shape(inputs)[axis] if dim.value is None else dim.value
                      for axis, dim in enumerate(inputs.shape)]
        input_prefix_dims, input_last_dim = input_dims[:-1], input_dims[-1]
        weight = variable_scope.get_variable('w',
                                             shape=[input_last_dim, output_dim],
                                             initializer=init_ops.random_normal_initializer(stddev=0.02))
        beta = variable_scope.get_variable('b',
                                           shape=[output_dim],
                                           initializer=init_ops.constant_initializer(0.))
        inputs = gen_array_ops.reshape(inputs, [-1, input_last_dim])                 # [B, input_last_dim]
        outputs = math_ops.matmul(inputs, weight) + beta                             # [B, output_dim]
        return gen_array_ops.reshape(outputs, input_prefix_dims + [output_dim])      # [..., output_dim]
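# Hypothetical usage sketch for _conv1d (not part of the original module; the
# tensor shapes and scope name below are illustrative assumptions). It shows that
# the helper is a position-wise projection over the last axis:
def _example_conv1d_usage():
    """ Sketch: projecting a [batch, seq_len, 768] tensor to 256 units. """
    inputs = array_ops.placeholder(dtypes.float32, [None, 25, 768])   # [batch, seq_len, emb]
    outputs = _conv1d(inputs, output_dim=256, name='proj')            # [batch, seq_len, 256]
    return outputs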
def _linear(self, inputs, proj_dim, name):
    """ Computes a linear unit inside a full block - Projects to 'proj_dim' and back to 'input_dim'
        :param inputs: The inputs tensor - [axis_0, ..., axis_n-1, input_dim]
        :param proj_dim: The dimension of the projection
        :param name: Name of the scope - To share weights between calls
        :return: A tensor of shape - [axis_0, ..., axis_n-1, input_dim]
    """
    with variable_scope.variable_scope(name):
        input_dim = inputs.shape[-1].value
        output_h1 = gelu(self._conv1d(inputs, proj_dim, 'mlp_fc1'))
        output_h2 = self._conv1d(output_h1, input_dim, 'mlp_fc2')
        return output_h2
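# `gelu` is referenced by _linear above but not defined in this section. A sketch
# of the tanh approximation popularized by GPT-2 follows in case the helper is
# needed; this is an assumption, not necessarily the original definition.
def gelu(inputs):
    """ Gaussian Error Linear Unit, tanh approximation (Hendrycks & Gimpel, 2016). """
    # 0.7978845608... is sqrt(2 / pi), precomputed to avoid an extra import
    return 0.5 * inputs * (1. + math_ops.tanh(
        0.7978845608028654 * (inputs + 0.044715 * math_ops.pow(inputs, 3.))))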
def __init__(self,
             num_units,
             memory,
             memory_sequence_length=None,
             normalize=False,
             probability_fn=None,
             score_mask_value=None,
             dtype=None,
             name_or_scope='BahdanauAttention'):
    """ Construct the Attention mechanism.
        :param num_units: The depth of the query mechanism.
        :param memory: The memory to query; usually the output of an RNN encoder.
                       This tensor should be shaped `[batch_size, max_time, ...]`.
        :param memory_sequence_length: (optional) Sequence lengths for the batch entries in memory.
                                       If provided, the memory tensor rows are masked with zeros
                                       for values past the respective sequence lengths.
        :param normalize: Python boolean. Whether to normalize the energy term.
        :param probability_fn: (optional) A `callable`. Converts the score to probabilities.
                               The default is @{tf.nn.softmax}. Other options include
                               @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
                               Its signature should be: `probabilities = probability_fn(score)`.
        :param score_mask_value: (optional) The mask value for score before passing into
                                 `probability_fn`. The default is -inf. Only used if
                                 `memory_sequence_length` is not None.
        :param dtype: The data type for the query and memory layers of the attention mechanism.
        :param name_or_scope: String or VariableScope to use when creating ops.
    """
    # pylint: disable=too-many-arguments
    if probability_fn is None:
        probability_fn = nn_ops.softmax
    if dtype is None:
        dtype = dtypes.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    self._num_units = num_units
    self._normalize = normalize
    self._name_or_scope = name_or_scope
    with variable_scope.variable_scope(name_or_scope, default_name='BahdanauAttention'):
        super(BahdanauAttention, self).__init__(
            query_layer=core.Dense(num_units, name='query_layer', use_bias=False, dtype=dtype),
            memory_layer=core.Dense(num_units, name='memory_layer', use_bias=False, dtype=dtype),
            memory=memory,
            probability_fn=wrapped_probability_fn,
            memory_sequence_length=memory_sequence_length,
            score_mask_value=score_mask_value)
def __call__(self, query, state):
    """ Score the query based on the keys and values.
        :param query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`.
        :param state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]`
                      (`alignments_size` is memory's `max_time`).
        :return: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]`
                 (`alignments_size` is memory's `max_time`).
    """
    with variable_scope.variable_scope(self._name_or_scope, 'bahdanau_attention', [query]):
        processed_query = self.query_layer(query) if self.query_layer else query
        score = _bahdanau_score(processed_query, self._keys, self._normalize)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state
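# `_bahdanau_score` is referenced above but not defined in this section. The sketch
# below matches the standard additive-attention score from tf.contrib.seq2seq; it is
# an assumption about the missing helper, not necessarily the original code.
def _bahdanau_score(processed_query, keys, normalize):
    """ Additive (Bahdanau) score: reduce_sum(v * tanh(keys + query), axis=-1). """
    dtype = processed_query.dtype
    num_units = keys.shape[-1].value                              # assumed static
    processed_query = array_ops.expand_dims(processed_query, 1)   # [batch, 1, num_units]
    v = variable_scope.get_variable('attention_v', [num_units], dtype=dtype)
    if normalize:
        # Weight normalization of v (Salimans & Kingma, 2016), as in tf.contrib.seq2seq
        g = variable_scope.get_variable('attention_g', shape=(), dtype=dtype,
                                        initializer=init_ops.constant_initializer((1. / num_units) ** 0.5))
        b = variable_scope.get_variable('attention_b', [num_units], dtype=dtype,
                                        initializer=init_ops.zeros_initializer())
        v = g * v * gen_math_ops.rsqrt(math_ops.reduce_sum(math_ops.square(v)))
        return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query + b), [2])
    return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])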
def _block(self, inputs, past_attn, name):
    """ Computes a transformer block
        :param inputs: The inputs tensor - [batch, seq_len, emb_dim]
        :param past_attn: The past attention - [batch, 2, nb_heads, past_len, emb_dim // nb_heads]
        :param name: Name of the scope - To share weights between calls
        :return: A tuple consisting of:
                 1) The output of the transformer block - [batch, seq_len, emb_dim]
                 2) The present attention - [batch, 2, nb_heads, seq_len, emb_dim // nb_heads]
    """
    with variable_scope.variable_scope(name):
        input_dim = inputs.shape[-1].value
        h_out = inputs
        # Self-attention sub-layer with residual connection
        h_attn, present_attn = self._attn(self._norm(h_out, 'block_norm1'), input_dim, past_attn, 'block_attn')
        h_out = h_out + h_attn
        # Feed-forward (MLP) sub-layer with residual connection
        h_mlp = self._linear(self._norm(h_out, 'block_norm2'), input_dim * 4, 'block_mlp')
        h_out = h_out + h_mlp
        return h_out, present_attn
def _norm(inputs, name, axis=-1):
    """ Applies layer normalization to the input tensor: normalizes to mean=0, std_dev=1,
        then applies a learned gamma and beta.
        :param inputs: The tensor inputs to normalize
        :param name: Name of the scope - To share weights between calls
        :param axis: Axis to normalize. Defaults to the last one.
        :return: A tensor of the same shape as inputs, but normalized and transformed
    """
    with variable_scope.variable_scope(name):
        axis_dim = inputs.shape[axis].value
        gamma = variable_scope.get_variable('gamma', [axis_dim],
                                            initializer=init_ops.constant_initializer(1.))
        beta = variable_scope.get_variable('beta', [axis_dim],
                                           initializer=init_ops.constant_initializer(0.))
        mean = math_ops.reduce_mean(inputs, axis=axis, keepdims=True)
        var = math_ops.reduce_mean(gen_math_ops.square(inputs - mean), axis=axis, keepdims=True)
        norm_inputs = (inputs - mean) * gen_math_ops.rsqrt(var + 1e-8)
        outputs = gamma * norm_inputs + beta
        return outputs
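# Hypothetical sanity check for _norm (not part of the original module; scope name
# and input values are illustrative). At initialization gamma=1 and beta=0, so the
# output should have ~zero mean and ~unit variance along the normalized axis:
def _example_norm_check():
    """ Sketch: per-row mean of the normalized output is ~0 at initialization. """
    inputs = constant_op.constant([[1., 2., 3., 4.]])     # [1, 4]
    outputs = _norm(inputs, 'norm_demo')                  # mean ~0, var ~1 per row
    return math_ops.reduce_mean(outputs, axis=-1)         # ~0 before training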
def call(self, inputs, state):
    """ Performs a step of attention-wrapped RNN.
        1) Mix the `inputs` and previous step's `attention` output via `cell_input_fn`.
        2) Call the wrapped `cell` with this input and its previous state.
        3) Score the cell's output with `attention_mechanism`.
        4) Calculate the alignments by passing the score through the `normalizer`.
        5) Calculate the context vector as the inner product between the alignments and the
           attention_mechanism's values (memory).
        6) Calculate the attention output by concatenating the cell output and context through
           the attention layer (a linear layer with `attention_layer_size` outputs).
        :param inputs: (Possibly nested tuple of) Tensor, the input at this time step.
        :param state: An instance of `AttentionWrapperState` containing tensors from the previous time step.
        :return: A tuple `(attention_or_cell_output, next_state)`, where:
                 - `attention_or_cell_output` depending on `output_attention`.
                 - `next_state` is an instance of `AttentionWrapperState` containing the state
                   calculated at this time step.
    """
    # pylint: disable=arguments-differ
    if not isinstance(state, AttentionWrapperState):
        raise TypeError('Expected state to be instance of AttentionWrapperState. '
                        'Received %s instead.' % type(state))

    # Step 1: Calculate the true inputs to the cell based on the previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    cell_batch_size = tensor_shape.dimension_value(cell_output.shape[0]) or array_ops.shape(cell_output)[0]
    error_message = (
        'When applying AttentionWrapper %s: ' % self.name +
        'Non-matching batch sizes between the memory (encoder output) and the query '
        '(decoder output). Are you using the BeamSearchDecoder? You may need to tile your '
        'memory input via the tf.contrib.seq2seq.tile_batch function with argument '
        'multiple=beam_width.')

    with variable_scope.variable_scope(self._name_or_scope, 'AttentionWrapper', [inputs, state]):
        with ops.control_dependencies(self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = array_ops.identity(cell_output, name='checked_cell_output')

        if self._is_multi:
            previous_attention_state = state.attention_state
            previous_alignment_history = state.alignment_history
        else:
            previous_attention_state = [state.attention_state]
            previous_alignment_history = [state.alignment_history]

        # Computing attention
        all_alignments = []
        all_attentions = []
        all_attention_states = []
        maybe_all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            attention, alignments, next_attention_state = _compute_attention(
                attention_mechanism, cell_output, previous_attention_state[i],
                self._attention_layers[i] if self._attention_layers else None)
            alignment_history = previous_alignment_history[i].write(state.time, alignments) \
                if self._alignment_history else ()
            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_attentions.append(attention)
            maybe_all_histories.append(alignment_history)

        # Building next state
        attention = array_ops.concat(all_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(maybe_all_histories))

        # Returning
        if self._output_attention:
            return attention, next_state
        return cell_output, next_state
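# `_compute_attention` is referenced above but not defined in this section. The sketch
# below matches the helper of the same name in tf.contrib.seq2seq (steps 3-6 of the
# docstring above); it is an assumption about the missing code, not the original.
def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer):
    """ Computes the attention, alignments, and next attention state for one mechanism. """
    alignments, next_attention_state = attention_mechanism(cell_output, state=attention_state)
    expanded_alignments = array_ops.expand_dims(alignments, 1)                    # [batch, 1, memory_time]
    context = math_ops.matmul(expanded_alignments, attention_mechanism.values)    # [batch, 1, memory_size]
    context = array_ops.squeeze(context, [1])                                     # [batch, memory_size]
    if attention_layer is not None:
        attention = attention_layer(array_ops.concat([cell_output, context], 1))
    else:
        attention = context
    return attention, alignments, next_attention_state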
def __init__(self,
             cell,
             attention_mechanism,
             attention_layer_size=None,
             alignment_history=False,
             cell_input_fn=None,
             output_attention=True,
             initial_cell_state=None,
             name_or_scope='AttentionWrapper',
             attention_layer=None):
    """ Construct the `AttentionWrapper`.
        :param cell: An instance of `RNNCell`.
        :param attention_mechanism: A list of `AttentionMechanism` instances or a single instance.
        :param attention_layer_size: A list of Python integers or a single Python integer,
                                     the depth of the attention (output) layer(s).
        :param alignment_history: Python boolean, whether to store alignment history from all time
                                  steps in the final output state.
        :param cell_input_fn: (optional) A `callable`. The default is:
                              concat([inputs, attention], axis=-1)
        :param output_attention: Python bool. If `True` (default), the output at each time step is
                                 the attention value.
        :param initial_cell_state: The initial state value to use for the cell when the user calls
                                   `zero_state()`.
        :param name_or_scope: String or VariableScope to use when creating ops.
        :param attention_layer: A list of `tf.layers.Layer` instances or a single `tf.layers.Layer`
                                instance taking the context and cell output as inputs to generate
                                attention at each time step. If None (default), use the context as
                                attention at each time step.

        **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`,
        then you must ensure that:

        - The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch`
          (NOT `tf.tile`).
        - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to
          `true_batch_size * beam_width`.
        - The initial state created with `zero_state` above contains a `cell_state` value containing
          properly tiled final state from the encoder.
    """
    # pylint: disable=too-many-arguments
    self._name_or_scope = name_or_scope
    with variable_scope.variable_scope(name_or_scope, 'AttentionWrapper'):
        super(AttentionWrapper, self).__init__()
        rnn_cell_impl.assert_like_rnncell("cell", cell)

        # Attention mechanism
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attn_mechanism in attention_mechanisms:
                if not isinstance(attn_mechanism, AttentionMechanism):
                    raise TypeError('attention_mechanism must contain only instances of '
                                    'AttentionMechanism, saw type: %s' % type(attn_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError('attention_mechanism must be an AttentionMechanism or list of '
                                'multiple AttentionMechanism instances, saw type: %s'
                                % type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism,)

        # Cell input function
        if cell_input_fn is None:
            cell_input_fn = lambda inputs, attention: array_ops.concat([inputs, attention], -1)
        elif not callable(cell_input_fn):
            raise TypeError('cell_input_fn must be callable, saw type: %s'
                            % type(cell_input_fn).__name__)

        # Attention layer size
        if attention_layer_size is not None and attention_layer is not None:
            raise ValueError('Only one of attention_layer_size and attention_layer should be set')

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(
                attention_layer_size if isinstance(attention_layer_size, (list, tuple))
                else (attention_layer_size,))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError('If provided, attention_layer_size must contain exactly one '
                                 'integer per attention_mechanism, saw: %d vs %d'
                                 % (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(
                core.Dense(attention_layer_size,
                           name='attention_layer',
                           use_bias=False,
                           dtype=attention_mechanisms[i].dtype)
                for i, attention_layer_size in enumerate(attention_layer_sizes))
            self._attention_layer_size = sum(attention_layer_sizes)

        elif attention_layer is not None:
            self._attention_layers = tuple(
                attention_layer if isinstance(attention_layer, (list, tuple)) else (attention_layer,))
            if len(self._attention_layers) != len(attention_mechanisms):
                raise ValueError('If provided, attention_layer must contain exactly one layer per '
                                 'attention_mechanism, saw: %d vs %d'
                                 % (len(self._attention_layers), len(attention_mechanisms)))
            self._attention_layer_size = sum(
                tensor_shape.dimension_value(
                    layer.compute_output_shape(
                        [None, cell.output_size
                         + tensor_shape.dimension_value(mechanism.values.shape[-1])])[-1])
                for layer, mechanism in zip(self._attention_layers, attention_mechanisms))

        else:
            self._attention_layers = None
            self._attention_layer_size = sum(
                tensor_shape.dimension_value(attention_mechanism.values.shape[-1])
                for attention_mechanism in attention_mechanisms)

        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history

        if initial_cell_state is None:
            self._initial_cell_state = None
        else:
            final_state_tensor = nest.flatten(initial_cell_state)[-1]
            state_batch_size = (tensor_shape.dimension_value(final_state_tensor.shape[0])
                                or array_ops.shape(final_state_tensor)[0])
            error_message = (
                'When constructing AttentionWrapper %s: ' % self._base_name +
                'Non-matching batch sizes between the memory (encoder output) and '
                'initial_cell_state. Are you using the BeamSearchDecoder? You may need to tile '
                'your initial state via the tf.contrib.seq2seq.tile_batch function with argument '
                'multiple=beam_width.')
            with ops.control_dependencies(self._batch_size_checks(state_batch_size, error_message)):
                self._initial_cell_state = nest.map_structure(
                    lambda state: array_ops.identity(state, name='check_initial_cell_state'),
                    initial_cell_state)
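# Hypothetical wiring sketch for AttentionWrapper (not part of the original module;
# the placeholder shapes, unit counts, and LSTM cell are illustrative assumptions):
def _example_attention_wrapper():
    """ Sketch: wrapping an LSTM cell with Bahdanau attention over encoder outputs. """
    encoder_outputs = array_ops.placeholder(dtypes.float32, [None, 50, 256])   # [batch, time, depth]
    encoder_lengths = array_ops.placeholder(dtypes.int32, [None])              # [batch]
    mechanism = BahdanauAttention(num_units=256,
                                  memory=encoder_outputs,
                                  memory_sequence_length=encoder_lengths)
    cell = rnn_cell_impl.LSTMCell(256)
    return AttentionWrapper(cell, mechanism, attention_layer_size=128)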
def _attn(self, inputs, attn_dim, past_attn, name):
    """ Performs multi-head attention inside a transformer block
        :param inputs: The tensor inputs - [batch, seq_len, emb_size]
        :param attn_dim: The dimension of the attention (and output)
        :param past_attn: The past attention - [batch, 2, nb_heads, past_len, emb_size // nb_heads]
        :param name: Name of the scope - To share weights between calls
        :return: A tuple consisting of:
                 1) The output of the attention - [batch, seq_len, attn_dim]
                 2) The present attention - [batch, 2, nb_heads, seq_len, emb_size // nb_heads]
    """
    assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'
    with variable_scope.variable_scope(name):

        # Computing the query, key, and value vectors
        query_keys_values = self._conv1d(inputs, 3 * attn_dim, 'attn_fc1')     # [batch, seq_len, 3 * attn_dim]
        query, keys, values = array_ops.split(query_keys_values, 3, axis=-1)   # 3x [batch, seq_len, attn_dim]

        # Splitting into nb_heads of size attn_dim // nb_heads
        # Output format is [batch, nb_heads, seq_len, attn_dim // nb_heads]
        query = self._split_in_heads(query)     # [bz, nb_heads, seq_len, head_sz]
        keys = self._split_in_heads(keys)       # [bz, nb_heads, seq_len, head_sz]
        values = self._split_in_heads(values)   # [bz, nb_heads, seq_len, head_sz]
        head_size = query.shape[-1].value

        # Stacking keys and values to get the present_attn
        present_attn = array_ops.stack([keys, values], axis=1)    # [bz, 2, nb_heads, seq_len, head_sz]

        # Adding past_attn to keys and values
        past_keys, past_values = array_ops.unstack(past_attn, 2, axis=1)   # 2x [bz, nb_heads, past_len, head_sz]
        keys = array_ops.concat([past_keys, keys], axis=-2)                # [bz, nb_heads, total_len, head_sz]
        values = array_ops.concat([past_values, values], axis=-2)          # [bz, nb_heads, total_len, head_sz]

        # Performing multi-head attention
        attn_w = math_ops.matmul(query, keys, transpose_b=True)   # [bz, nb_heads, seq_len, total_len]
        attn_w = attn_w * gen_math_ops.rsqrt(math_ops.cast(head_size, attn_w.dtype) + 1e-8)
        attn_mask = self._mask_attn_weights(attn_w)                # [bz, 1, seq_len, total_len]
        # Masked positions get a large negative logit so softmax assigns them ~0 weight
        attn_w = attn_w * attn_mask - math_ops.cast(1e10, attn_w.dtype) * (1. - attn_mask)
        attn_w = nn_ops.softmax(attn_w)
        attn = math_ops.matmul(attn_w, values)                     # [bz, nb_heads, seq_len, head_sz]

        # Merging attention heads, then 1d conv before returning
        out_attn = self._merge_heads(attn)                         # [bz, seq_len, attn_dim]
        out_attn = self._conv1d(out_attn, attn_dim, 'attn_fc2')    # [bz, seq_len, attn_dim]

        # Returning
        return out_attn, present_attn
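# `_mask_attn_weights` is referenced above but not defined in this section. A sketch
# of a GPT-2-style causal mask follows; it is an assumption about the missing helper.
# The offset accounts for cached past keys (total_len = past_len + seq_len), and the
# [1, 1, seq_len, total_len] result broadcasts over batch and heads.
def _mask_attn_weights(attn_w):
    """ Builds a lower-triangular (causal) mask for the attention weights. """
    seq_len = array_ops.shape(attn_w)[-2]
    total_len = array_ops.shape(attn_w)[-1]
    row = math_ops.range(total_len - seq_len, total_len)[:, None]   # query positions
    col = math_ops.range(total_len)                                 # key positions
    mask = math_ops.cast(row >= col, attn_w.dtype)                  # [seq_len, total_len]
    return mask[None, None, :, :]                                   # [1, 1, seq_len, total_len]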
def _step(self, inputs, past_attns, time, feeder_cell, feeder_state):
    """ Performs the block operation on n-layers
        :param inputs: The tensor inputs (embedding of each word) - [batch, seq_len, emb_size]
        :param past_attns: The past attentions - [batch, nb_layers, 2, nb_heads, past_length, emb_size // nb_heads]
        :param time: A tensor representing the current time step
        :param feeder_cell: None or a feeder cell that returns a RNN cell output to use for conditioning
        :param feeder_state: None or the initial state of the feeder cell
        :return: A tuple consisting of:
                 1) The cell outputs - [batch, seq_len, emb_size]
                 2) The present attention - [batch, nb_layers, 2, nb_heads, seq_len, emb_size // nb_heads]
                 3) The new state of the feeder cell
    """
    with variable_scope.variable_scope(self._scope, default_name='step'):
        past_length = array_ops.shape(past_attns)[-2]   # How many past attention steps we have
        seq_len = array_ops.shape(inputs)[-2]           # How many steps we are computing for the current time
        emb_size = inputs.shape[-1].value               # The size of the embedding
        assert emb_size == self._emb_size, 'Expected an embedding size of %d' % self._emb_size

        # 1) Computing the word embedding of each token
        assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'
        out_h = inputs                                  # [bz, seq, emb]

        # 2) Computing the position embedding of each token
        # If we know the context was padded, the effective past length is the context length + nb of time steps
        if self._past_seq_lengths is not None:
            past_length = gen_math_ops.minimum(past_length, self._past_seq_lengths + time)[:, None]  # [bz, 1]
        else:
            past_length = gen_array_ops.fill([self._batch_size, 1], value=past_length)               # [bz, 1]
        step_ix = math_ops.range(seq_len)[None, :]                                       # [1, seq_len]
        token_positions = gen_math_ops.add(past_length, step_ix)                         # [batch, seq_len]
        token_positions = gen_math_ops.minimum(self._position_emb_size - 1, token_positions)  # [batch, seq_len]
        h_pos = self._position_embedding_fn(token_positions)                             # [bz, seq, emb]
        out_h = out_h + h_pos

        # 3) If we have a feeder cell, we also need to condition 'h' on it.
        next_feeder_state = feeder_state
        if feeder_cell is not None:
            assert feeder_state is not None, 'A feeder state is required if a feeder cell is provided.'
            assert inputs.shape[1].value == 1, 'The seq dimension must be 1 to use a feeder_cell'
            feeder_outputs, next_feeder_state = feeder_cell(array_ops.squeeze(inputs, axis=1), feeder_state)
            h_feed = feeder_outputs                                                      # [bz, feeder_sz]
            if feeder_outputs.shape[-1].value != emb_size:
                h_feed = core.Dense(emb_size, activation=None, name='h_feed')(h_feed)    # [bz, emb]
            h_feed = gen_array_ops.tile(h_feed[:, None, :], [1, seq_len, 1])             # [bz, seq, emb]
            out_h = out_h + h_feed

        # Transformer
        presents = []
        pasts = array_ops.unstack(past_attns, axis=1)   # list of [batch, 2, heads, past_len, head_sz]
        assert len(pasts) == self._nb_layers, \
            'Expected the past attention to have %d layers.' % self._nb_layers
        for layer_ix, past_attn in enumerate(pasts):
            out_h, present = self._block(out_h, past_attn, 'layer.%d' % layer_ix)
            presents += [present]
        presents = array_ops.stack(presents, axis=1)

        # Normalizing and returning
        cell_outputs = self._norm(out_h, 'norm_h')      # [batch, seq, emb]
        return cell_outputs, presents, next_feeder_state
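# Hypothetical note on `past_attns` at the first step (an assumption consistent with
# the shapes documented above, following the usual GPT-2 caching convention): with
# nothing cached yet, the past axis has length 0, so past_length evaluates to 0 and
# token positions start at the beginning of the sequence.
def _example_empty_past(batch_size, nb_layers, nb_heads, emb_size):
    """ Sketch: the zero-length `past_attns` that would seed _step at time 0. """
    return array_ops.zeros([batch_size, nb_layers, 2, nb_heads, 0, emb_size // nb_heads])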
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   invariants_map=None,
                   swap_memory=False,
                   scope=None):
    """ Performs dynamic decoding with `decoder`.
        :param decoder: A `Decoder` instance.
        :param output_time_major: If True, outputs [time, batch, ...], otherwise outputs [batch, time, ...]
        :param impute_finished: If True, states for finished batch entries are copied through
                                to the end of decoding and their outputs are zeroed out.
        :param maximum_iterations: Int or None. The maximum number of steps (otherwise decode until it's done)
        :param parallel_iterations: Argument passed to `tf.while_loop`.
        :param invariants_map: Optional. Dictionary of tensor path (in initial_state) to its shape invariant.
        :param swap_memory: Argument passed to `tf.while_loop`.
        :param scope: Optional variable scope to use.
        :return: A tuple of 1) final_outputs, 2) final_state, 3) final_sequence_length
    """
    if not isinstance(decoder, seq2seq.Decoder):
        raise TypeError('Expected decoder to be type Decoder, but saw: %s' % type(decoder))

    with variable_scope.variable_scope(scope, 'decoder') as varscope:

        # Determine context types.
        ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = control_flow_util.GetContainingWhileContext(ctxt) is not None

        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is possible that train
        # steps could be wrapped in a tf.while_loop. In that scenario caching prevents forward
        # computations in loop iterations from re-reading the updated weights.
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Setting maximum iterations
        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(maximum_iterations,
                                                       dtype=dtypes.int32,
                                                       name='maximum_iterations')
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError('maximum_iterations must be a scalar')

        def _inv_shape(maybe_ta):
            """ Returns the invariant shape """
            if isinstance(maybe_ta, tensor_array_ops.TensorArray):
                return maybe_ta.flow.shape
            return maybe_ta.shape

        def _invariants(structure):
            """ Returns the invariants of a structure """
            return nest.map_structure(_inv_shape, structure)

        def _map_invariants(structure):
            """ Returns the invariants of a structure, but replaces the invariant using the
                value in invariants_map """
            return nest.map_structure_with_paths(
                lambda path, tensor: (invariants_map or {}).get(path, _inv_shape(tensor)),
                structure)

        # Initializing decoder
        initial_finished, initial_inputs, initial_state = decoder.initialize()
        zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError('maximum_iterations is required for XLA compilation.')
        if maximum_iterations is not None:
            initial_finished = gen_math_ops.logical_or(initial_finished, maximum_iterations <= 0)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished, dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        # Creating initial output TA
        def _shape(batch_size, from_shape):
            """ Returns the batch_size concatenated with the from_shape """
            if (not isinstance(from_shape, tensor_shape.TensorShape) or from_shape.ndims == 0):
                return tensor_shape.TensorShape(None)
            batch_size = tensor_util.constant_value(
                ops.convert_to_tensor(batch_size, name='batch_size'))
            return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(shape, dtype):
            """ Creates a tensor array """
            return tensor_array_ops.TensorArray(dtype=dtype,
                                                size=0 if dynamic_size else maximum_iterations,
                                                dynamic_size=dynamic_size,
                                                element_shape=_shape(decoder.batch_size, shape))

        initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, finished,
                      unused_sequence_lengths):
            """ While loop condition """
            return gen_math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """ Internal while_loop body. """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = gen_math_ops.logical_or(decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                gen_math_ops.logical_not(finished),
                gen_array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(lambda out, zero: array_ops.where(finished, zero, out),
                                          next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays, multiple dynamic dims, and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                elif None in new.shape.as_list()[1:]:
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
                    next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=(initial_time, initial_outputs_ta, initial_state, initial_inputs,
                       initial_finished, initial_sequence_lengths),
            shape_invariants=(_invariants(initial_time), _invariants(initial_outputs_ta),
                              _map_invariants(initial_state), _invariants(initial_inputs),
                              _invariants(initial_finished),
                              _invariants(initial_sequence_lengths)),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]
        final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(final_outputs, final_state,
                                                          final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

        return final_outputs, final_state, final_sequence_lengths
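# Hypothetical usage sketch for dynamic_decode (not part of the original module;
# the BasicDecoder construction assumes tf.contrib.seq2seq-style cell/helper/state
# arguments supplied by the caller):
def _example_dynamic_decode(cell, helper, initial_state):
    """ Sketch: decoding with a contrib seq2seq BasicDecoder, capped at 200 steps. """
    decoder = seq2seq.BasicDecoder(cell, helper, initial_state)
    return dynamic_decode(decoder, impute_finished=True, maximum_iterations=200)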