def zero_state(self, batch_size, dtype):
    """ Return an initial (zero) state tuple for this `SelfAttentionWrapper`.
        :param batch_size: `0D` integer tensor: the batch size.
        :param dtype: The internal state data type.
        :return: A `SelfAttentionWrapperState` tuple containing zeroed out tensors.
    """
    with ops.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
        # Using batch_size * 0, rather than just 0, to keep the time dimension dynamic
        initial_cell_state = self._cell.zero_state(batch_size, dtype)
        initial_memory = array_ops.zeros([batch_size, batch_size * 0, self._memory_size], dtype=self._dtype)
        return SelfAttentionWrapperState(cell_state=initial_cell_state,
                                         time=array_ops.zeros([], dtype=dtypes.int32),
                                         memory=initial_memory)
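# Hedged illustration (not part of the wrapper): the `batch_size * 0` trick above creates a
# time axis of runtime length zero that is still computed from a tensor, so the memory keeps a
# dynamic shape and can later be grown with `concat` at each step. A minimal sketch using the
# same TF 1.x ops; `memory_size` and the helper name are illustrative assumptions.
def _example_dynamic_empty_memory(batch_size, memory_size=128):
    """ Grows an initially empty, dynamically-shaped memory by one time step. """
    empty_memory = array_ops.zeros([batch_size, batch_size * 0, memory_size])   # time axis has runtime length 0
    step_output = array_ops.zeros([batch_size, 1, memory_size])                 # output of a single step
    return array_ops.concat([empty_memory, step_output], axis=1)                # time axis grows to 1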
def get_zero_memory_and_attn():
    """ At time = 0, we don't concatenate to the memory, and the attention is all zeros. """
    next_memory = state.memory
    next_attention = array_ops.zeros([batch_size, self._attention_layer_size], dtype=inputs.dtype)
    with ops.control_dependencies([next_memory, next_attention]):
        return array_ops.identity(next_memory), array_ops.identity(next_attention)
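# Hedged note: a nested zero-time branch like the one above is typically consumed by a
# conditional, with `identity` under `control_dependencies` ensuring the returned tensors are
# created inside the branch. An illustrative sketch; both branch functions are assumptions:
def _example_select_branch(time, get_zero_memory_and_attn, get_next_memory_and_attn):
    """ Returns the zero-time branch when time == 0, the regular step branch otherwise. """
    from tensorflow.python.ops import control_flow_ops, math_ops
    return control_flow_ops.cond(math_ops.equal(time, 0),
                                 true_fn=get_zero_memory_and_attn,
                                 false_fn=get_next_memory_and_attn)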
def __init__(self, cell, order_embedding, candidate_embedding, candidates, sequence_length, initial_state,
             beam_width, input_layer=None, output_layer=None, time_major=False):
    """ Initialize the CustomBeamHelper
        :param cell: An `RNNCell` instance.
        :param order_embedding: The order embedding vector - Size: (batch, ord_emb_size)
        :param candidate_embedding: The candidate embedding vector - Size: (batch, cand_emb_size)
        :param candidates: The candidates at each time step - Size: (batch, nb_cand, max_candidates)
        :param sequence_length: The length of each sequence - Size: (batch,)
        :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        :param beam_width: Python integer, the number of beams.
        :param input_layer: Optional. A layer to apply on the inputs.
        :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                             to apply to the RNN output prior to storing the result or sampling.
        :param time_major: If True, indicates that the first dimension is time; otherwise it is batch size.
    """
    # pylint: disable=super-init-not-called,too-many-arguments
    rnn_cell_impl.assert_like_rnncell('cell', cell)                 # pylint: disable=protected-access
    assert isinstance(beam_width, int), 'beam_width should be a Python integer'

    self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
    if self._sequence_length.get_shape().ndims != 1:
        raise ValueError('Expected vector for sequence_length. Shape: %s' % self._sequence_length.get_shape())

    candidates = ops.convert_to_tensor(candidates, name='candidates')
    candidates = nest.map_structure(_transpose_batch_time, candidates) if not time_major else candidates

    self._cell = cell
    self._order_embedding_fn = _get_embedding_fn(order_embedding)
    self._candidate_embedding_fn = _get_embedding_fn(candidate_embedding)
    self._candidate_tas = nest.map_structure(_unstack_ta, candidates)
    self._input_layer = input_layer if input_layer is not None else lambda x: x
    self._output_layer = output_layer

    self._input_size = order_embedding.shape[-1]
    if input_layer is not None:
        self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

    self._batch_size = array_ops.size(sequence_length)
    self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
    self._end_token = -1
    self._beam_width = beam_width
    self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams, initial_state,
                                                  self._cell.state_size)
    self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                       depth=self._beam_width,
                                       on_value=False,
                                       off_value=True,
                                       dtype=dtypes.bool)

    # Compute input shape
    self._zero_inputs = \
        CandidateInputs(inputs=array_ops.zeros_like(self._split_batch_beams(
                            self._input_layer(self._order_embedding_fn(self._start_tokens)), self._input_size)),
                        candidates=array_ops.zeros_like(candidates[0, :]),
                        candidates_emb=array_ops.zeros_like(self._candidate_embedding_fn(candidates[0, :])))
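# Hedged usage sketch: one way this helper could be wired for beam-search decoding. The
# variable names (`decoder_cell`, `order_emb`, `cand_emb`, ...) are illustrative, and the
# initial state is tiled to batch * beam_width, which `_maybe_split_batch_beams` expects.
def _example_build_candidate_helper(decoder_cell, order_emb, cand_emb, candidates, seq_len,
                                    encoder_state, beam_width=5):
    """ Builds a CustomBeamHelper over candidate orders (illustrative wiring only). """
    from tensorflow.contrib.seq2seq import tile_batch
    tiled_state = tile_batch(encoder_state, multiplier=beam_width)
    return CustomBeamHelper(cell=decoder_cell,
                            order_embedding=order_emb,
                            candidate_embedding=cand_emb,
                            candidates=candidates,
                            sequence_length=seq_len,
                            initial_state=tiled_state,
                            beam_width=beam_width)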
def zero_state(self, batch_size, dtype):
    """ Return an initial (zero) state tuple for this `IdentityCell`.
        :param batch_size: `0D` integer tensor: the batch size.
        :param dtype: The internal state data type.
        :return: A zeroed out scalar representing the initial state of the cell.
    """
    with ops.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
        return array_ops.zeros([], dtype=dtypes.int32)
def zero_state(self, batch_size, dtype):
    """ Return an initial (zero) state tuple for this `TransformerCell`.
        :param batch_size: `0D` integer tensor: the batch size.
        :param dtype: The internal state data type.
        :return: A `TransformerCellState` tuple representing the initial state of the cell.
    """
    with ops.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
        if self._feeder_cell is None:
            feeder_init_state = array_ops.zeros([], dtype=dtype)
        elif self._feeder_init_state is not None:
            feeder_init_state = self._feeder_init_state
        else:
            feeder_init_state = self._feeder_cell.zero_state(batch_size, dtype)

        # Empty past attentions
        # Using 0 * batch_size, rather than just 0, to keep the time dimension dynamic
        if self._past_attns is None:
            head_size = self._emb_size // self._nb_heads
            past_attns_shape = [batch_size, self._nb_layers, 2, self._nb_heads, 0 * batch_size, head_size]
            self._past_attns = array_ops.zeros(past_attns_shape, dtype=dtypes.float32)

        # No context - Returning a zero past attention
        if self._context is None:
            return TransformerCellState(past_attentions=self._past_attns,
                                        feeder_state=feeder_init_state,
                                        time=array_ops.zeros([], dtype=dtypes.int32))

        # Context provided - Computing attention by running a single block step
        _, present_attns, _ = self._step(inputs=self._context_word_embedding_fn(self._context),
                                         past_attns=self._past_attns,
                                         time=0,
                                         feeder_cell=None,
                                         feeder_state=None)
        return TransformerCellState(past_attentions=present_attns,
                                    feeder_state=feeder_init_state,
                                    time=array_ops.zeros([], dtype=dtypes.int32))
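# Hedged usage sketch: like any RNNCell-like object, the cell is stepped from its zero state.
# `transformer_cell` and `first_inputs` are assumptions for this illustration; only the
# zero_state / __call__ contract shown here is taken from the code above.
def _example_first_transformer_step(transformer_cell, first_inputs, batch_size):
    """ Runs a single decoding step starting from the cell's zero state. """
    initial_state = transformer_cell.zero_state(batch_size, dtypes.float32)
    outputs, next_state = transformer_cell(first_inputs, initial_state)
    return outputs, next_state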
def zero_state(self, batch_size, dtype):
    """ Return an initial (zero) state tuple for this `ArrayConcatWrapper`.
        :param batch_size: `0D` integer tensor: the batch size.
        :param dtype: The internal state data type.
        :return: An `ArrayConcatWrapperState` tuple containing zeroed out tensors and, possibly, empty TA objects.
    """
    with ops.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
        return ArrayConcatWrapperState(cell_state=self._cell.zero_state(batch_size, dtype),
                                       time=array_ops.zeros([], dtype=dtypes.int32))
def zero_state(self, batch_size, dtype):
    """ Return an initial (zero) state tuple for this `AttentionWrapper`.
        :param batch_size: `0D` integer tensor: the batch size.
        :param dtype: The internal state data type.
        :return: An `AttentionWrapperState` tuple containing zeroed out tensors and, possibly, empty `TensorArray` objects.
    """
    with ops.name_scope(type(self).__name__ + 'ZeroState', values=[batch_size]):
        if self._initial_cell_state is not None:
            cell_state = self._initial_cell_state
        else:
            cell_state = self._cell.zero_state(batch_size, dtype)
        error_message = (
            'When calling zero_state of AttentionWrapper %s: ' % self._base_name +
            'Non-matching batch sizes between the memory (encoder output) and the requested batch '
            'size. Are you using the BeamSearchDecoder? If so, make sure your encoder output has been '
            'tiled to beam_width via tf.contrib.seq2seq.tile_batch, and the batch_size= argument '
            'passed to zero_state is batch_size * beam_width.')
        with ops.control_dependencies(self._batch_size_checks(batch_size, error_message)):
            cell_state = nest.map_structure(lambda state: array_ops.identity(state, name='checked_cell_state'),
                                            cell_state)
        initial_alignments = [attention_mechanism.initial_alignments(batch_size, dtype)
                              for attention_mechanism in self._attention_mechanisms]
        return AttentionWrapperState(
            cell_state=cell_state,
            time=array_ops.zeros([], dtype=dtypes.int32),
            attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype),
            alignments=self._item_or_tuple(initial_alignments),
            attention_state=self._item_or_tuple(attention_mechanism.initial_state(batch_size, dtype)
                                                for attention_mechanism in self._attention_mechanisms),
            alignment_history=self._item_or_tuple(tensor_array_ops.TensorArray(dtype,
                                                                               size=0,
                                                                               dynamic_size=True,
                                                                               element_shape=alignment.shape)
                                                  if self._alignment_history else ()
                                                  for alignment in initial_alignments))
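# Hedged sketch of the beam-search caveat in the error message above: with a BeamSearchDecoder,
# the memory must be tiled via tf.contrib.seq2seq.tile_batch and zero_state must be requested
# with batch_size * beam_width. All names below are illustrative.
def _example_zero_state_for_beam_search(attention_wrapper, batch_size, beam_width):
    """ Requests a zero state correctly sized for beam-search decoding. """
    return attention_wrapper.zero_state(batch_size=batch_size * beam_width, dtype=dtypes.float32)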
def __init__(self, cell, memory, alignments, sequence_length, probability_fn=None, score_mask_value=None,
             attention_layer_size=None, cell_input_fn=None, output_attention=False, name=None):
    """ Constructs an AttentionWrapper with static alignments (attention weights)
        :param cell: An instance of `RNNCell`.
        :param memory: The memory to query [batch_size, memory_time, memory_size]
        :param alignments: A tensor of probabilities of shape [batch_size, time_steps, memory_time]
        :param sequence_length: Sequence lengths for the batch entries in memory. Size (b,)
        :param probability_fn: A `callable`. Converts the score to probabilities. The default is `tf.nn.softmax`.
        :param score_mask_value: The mask value for score before passing into `probability_fn`. Default is -inf.
        :param attention_layer_size: The size of the attention layer. Uses the context if None.
        :param cell_input_fn: (optional) A `callable` to aggregate attention.
                              Default: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
        :param output_attention: If True, outputs the attention; if False, outputs the cell output.
        :param name: Name to use when creating ops.
    """
    # pylint: disable=too-many-arguments
    # Initializing RNN Cell
    super(StaticAttentionWrapper, self).__init__(name=name)
    rnn_cell_impl.assert_like_rnncell('cell', cell)

    # Setting values
    self._cell = cell
    self._memory = memory
    self._attention_layer_size = attention_layer_size
    self._output_attention = output_attention
    self._memory_time = alignments.get_shape()[-1].value
    self._memory_size = memory.get_shape()[-1].value
    self._sequence_length = sequence_length

    # Validating attention layer size
    if self._attention_layer_size is None:
        self._attention_layer_size = self._memory_size

    # Validating cell_input_fn
    if cell_input_fn is None:
        cell_input_fn = lambda inputs, attention: array_ops.concat([inputs, attention], -1)
    elif not callable(cell_input_fn):
        raise TypeError('cell_input_fn must be callable, saw type: %s' % type(cell_input_fn).__name__)
    self._cell_input_fn = cell_input_fn

    # Probability Function
    if probability_fn is None:
        probability_fn = nn_ops.softmax
    if score_mask_value is None:
        score_mask_value = dtypes.as_dtype(self._memory.dtype).as_numpy_dtype(-np.inf)
    self._probability_fn = lambda score, _: probability_fn(
        _maybe_mask_score(score, sequence_length, score_mask_value), _)

    # Storing alignments as TA
    # Padding with 1 additional zero, to prevent error on read(0)
    alignments = array_ops.pad(alignments, [(0, 0), (0, 1), (0, 0)])
    alignments = nest.map_structure(_transpose_batch_time, alignments)      # (max_time + 1, b, memory_time)
    self._alignments_ta = nest.map_structure(_unstack_ta, alignments)       # [time_step + 1, batch, memory_time]
    self._initial_alignment = self._alignments_ta.read(0)
    self._initial_attention = self._compute_attention(self._initial_alignment, self._memory)[0]

    # Storing zero inputs
    batch_size = array_ops.shape(memory)[0]
    self._zero_cell_output = array_ops.zeros([batch_size, cell.output_size])
    self._zero_attention = array_ops.zeros([batch_size, self._attention_layer_size])
    self._zero_state = self.zero_state(batch_size, dtypes.float32)
    self._zero_alignment = array_ops.zeros_like(self._initial_alignment)
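# Hedged usage sketch: wrapping a cell with precomputed (static) alignments. Tensor names are
# illustrative; the shapes follow the constructor docstring above.
def _example_static_attention(cell, memory, alignments, sequence_length):
    """ memory: (batch, memory_time, memory_size); alignments: (batch, time_steps, memory_time). """
    return StaticAttentionWrapper(cell=cell,
                                  memory=memory,
                                  alignments=alignments,
                                  sequence_length=sequence_length,
                                  output_attention=False)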
def __init__(self, cell, embedding, mask, sequence_length, initial_state, beam_width,
             input_layer=None, output_layer=None, time_major=False):
    """ Initialize the CustomBeamHelper
        :param cell: An `RNNCell` instance.
        :param embedding: The embedding vector
        :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
        :param sequence_length: The length of each input (b,)
        :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        :param beam_width: Python integer, the number of beams.
        :param input_layer: Optional. A layer to apply on the inputs.
        :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                             to apply to the RNN output prior to storing the result or sampling.
        :param time_major: If True, indicates that the first dimension is time; otherwise it is batch size.
    """
    # pylint: disable=super-init-not-called,too-many-arguments
    rnn_cell_impl.assert_like_rnncell('cell', cell)                 # pylint: disable=protected-access
    assert isinstance(mask, SparseTensor), 'The mask must be a SparseTensor'
    assert isinstance(beam_width, int), 'beam_width should be a Python integer'

    self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
    if self._sequence_length.get_shape().ndims != 1:
        raise ValueError('Expected vector for sequence_length. Shape: %s' % self._sequence_length.get_shape())

    self._cell = cell
    self._embedding_fn = _get_embedding_fn(embedding)
    self._mask = mask
    self._time_major = time_major
    self.vocab_size = VOCABULARY_SIZE
    self._input_layer = input_layer if input_layer is not None else lambda x: x
    self._output_layer = output_layer

    self._input_size = embedding.shape[-1]
    if input_layer is not None:
        self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

    self._batch_size = array_ops.size(sequence_length)
    self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
    self._end_token = -1
    self._beam_width = beam_width
    self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams, initial_state,
                                                  self._cell.state_size)
    self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                       depth=self._beam_width,
                                       on_value=False,
                                       off_value=True,
                                       dtype=dtypes.bool)

    # zero_mask is (batch, beam, vocab_size)
    self._zero_mask = _slice_mask(self._mask, slicing=[-1, 0, GO_ID, -1], squeeze=True, time_major=self._time_major)
    self._zero_mask = gen_array_ops.tile(array_ops.expand_dims(self._zero_mask, axis=1), [1, self._beam_width, 1])
    self._zero_inputs = \
        MaskedInputs(inputs=array_ops.zeros_like(
                         self._split_batch_beams(self._input_layer(self._embedding_fn(self._start_tokens)),
                                                 self._input_size)),
                     mask=self._zero_mask)
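# Hedged illustration: the mask is asserted above to be a SparseTensor of dense shape
# (b, dec_len, vocab_size, vocab_size). A tiny made-up construction with standard TF ops,
# purely to show the expected layout (all sizes and indices are invented for the example):
def _example_build_sparse_mask():
    """ Allows token 2 after token 1 at decoding step 0 of batch item 0. """
    from tensorflow.python.framework import sparse_tensor
    return sparse_tensor.SparseTensor(indices=[[0, 0, 1, 2]],
                                      values=[1.],
                                      dense_shape=[4, 10, 32, 32])      # (b, dec_len, vocab, vocab)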