def step(self, time, inputs, cell_state):
    """ Performs a step using the beam search cell
        :param time: The current time step (scalar)
        :param inputs: A (structure of) input tensors.
        :param cell_state: A (structure of) state tensors and TensorArrays.
        :return: `(cell_outputs, next_cell_state)`.
    """
    raw_inputs = inputs
    inputs, candidates, candidates_emb = raw_inputs.inputs, raw_inputs.candidates, raw_inputs.candidates_emb

    inputs = nest.map_structure(lambda inp: self._merge_batch_beams(inp, depth_shape=inp.shape[2:]), inputs)
    cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
    cell_outputs, next_cell_state = self._cell(inputs, cell_state)                  # [batch * beam, out_sz]
    next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

    # Splitting outputs and adding a bias dimension
    # cell_outputs is [batch, beam, cand_emb_size + 1]
    cell_outputs = self._output_layer(cell_outputs) if self._output_layer is not None else cell_outputs
    cell_outputs = nest.map_structure(lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
    cell_outputs = array_ops.pad(cell_outputs, [(0, 0), (0, 0), (0, 1)], constant_values=1.)

    # Computing candidates
    # cell_outputs is reshaped to   [batch, beam, 1, cand_emb_size + 1]
    # candidates_emb is reshaped to [batch, 1, max_cand, cand_emb_size + 1]
    # output_mask is                [batch, 1, max_cand]
    # cell_outputs is finally       [batch, beam, max_cand]
    cell_outputs = math_ops.reduce_sum(array_ops.expand_dims(cell_outputs, axis=2)
                                       * array_ops.expand_dims(candidates_emb, axis=1), axis=-1)
    output_mask = math_ops.cast(array_ops.expand_dims(gen_math_ops.greater(candidates, 0), axis=1), dtypes.float32)
    cell_outputs = gen_math_ops.add(cell_outputs, (1. - output_mask) * LARGE_NEGATIVE)

    # Returning
    return cell_outputs, next_cell_state
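# The candidate scoring above reduces to a batched dot product between the projected cell output (with a
# constant 1. appended so the last candidate-embedding column acts as a per-candidate bias) and the candidate
# embeddings, followed by masking padded candidates with a large negative logit. A minimal NumPy sketch of
# that computation, under assumed toy shapes and a hypothetical LARGE_NEGATIVE value:
import numpy as np

LARGE_NEGATIVE = -1e20                                       # assumed value, stands in for the module constant

batch, beam, max_cand, emb = 2, 3, 4, 5
cell_outputs = np.random.rand(batch, beam, emb).astype(np.float32)
candidates = np.array([[3, 7, 0, 0], [5, 2, 9, 0]])          # 0 marks padded candidate slots
candidates_emb = np.random.rand(batch, max_cand, emb + 1).astype(np.float32)

# Append a constant 1. so the last embedding column acts as a bias term
cell_outputs = np.concatenate([cell_outputs, np.ones((batch, beam, 1), np.float32)], axis=-1)

# [batch, beam, 1, emb + 1] * [batch, 1, max_cand, emb + 1] -> sum over last axis -> [batch, beam, max_cand]
logits = np.sum(cell_outputs[:, :, None, :] * candidates_emb[:, None, :, :], axis=-1)

# Padded candidates (id 0) get a large negative logit so they are never selected
output_mask = (candidates > 0).astype(np.float32)[:, None, :]    # [batch, 1, max_cand]
logits = logits + (1. - output_mask) * LARGE_NEGATIVE
assert logits.shape == (batch, beam, max_cand)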
def get_next_memory_and_attn():
    """ Gets the next memory and attention """
    next_memory = array_ops.concat([state.memory,                                          # [b, t, mem_size]
                                    array_ops.expand_dims(self._input_fn(inputs), axis=1)], axis=1)
    next_attention = self._compute_attention(inputs, next_memory)
    with ops.control_dependencies([next_memory, next_attention]):
        return array_ops.identity(next_memory), array_ops.identity(next_attention)
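# The concat above grows the attention memory by one time step per decoding step before attention is
# recomputed over it. A small NumPy sketch of the shape bookkeeping (toy sizes; the projection stands in
# for self._input_fn, which is not shown here):
import numpy as np

b, t, mem_size = 2, 4, 8
memory = np.zeros((b, t, mem_size), np.float32)              # memory accumulated so far: [b, t, mem_size]
projected_input = np.ones((b, mem_size), np.float32)         # stands in for self._input_fn(inputs)

next_memory = np.concatenate([memory, projected_input[:, None, :]], axis=1)
assert next_memory.shape == (b, t + 1, mem_size)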
def next_inputs(self, time, inputs, beam_search_output, beam_search_state):
    """ Computes the inputs at the next time step given the beam outputs
        :param time: The current time step (scalar)
        :param inputs: A (structure of) input tensors.
        :param beam_search_output: The output of the beam search step
        :param beam_search_state: The state after the beam search step
        :return: `(beam_search_output, next_inputs)`
        :type beam_search_output: beam_search_decoder.BeamSearchDecoderOutput
        :type beam_search_state: beam_search_decoder.BeamSearchDecoderState
    """
    next_time = time + 1
    all_finished = math_ops.reduce_all(next_time >= self._sequence_length)

    # Sampling
    next_word_ids = beam_search_output.predicted_ids
    candidates = inputs.candidates
    nb_candidates = array_ops.shape(candidates)[1]
    sample_ids = math_ops.reduce_sum(array_ops.one_hot(next_word_ids, nb_candidates, dtype=dtypes.int32)
                                     * array_ops.expand_dims(candidates, axis=1), axis=-1)

    def get_next_inputs():
        """ Retrieves the inputs for the next time step """
        inputs_next_step = sample_ids
        inputs_emb_next_step = self._input_layer(self._order_embedding_fn(inputs_next_step))
        candidate_next_step = self._candidate_tas.read(next_time)
        candidate_emb_next_step = self._candidate_embedding_fn(candidate_next_step)

        # Prevents this branch from executing eagerly
        with ops.control_dependencies([inputs_emb_next_step, candidate_next_step, candidate_emb_next_step]):
            return CandidateInputs(inputs=array_ops.identity(inputs_emb_next_step),
                                   candidates=array_ops.identity(candidate_next_step),
                                   candidates_emb=array_ops.identity(candidate_emb_next_step))

    # Getting next inputs
    next_inputs = control_flow_ops.cond(all_finished,
                                        true_fn=lambda: self._zero_inputs,
                                        false_fn=get_next_inputs)

    # Rewriting beam search output with the correct sample ids
    beam_search_output = beam_search_decoder.BeamSearchDecoderOutput(scores=beam_search_output.scores,
                                                                     predicted_ids=sample_ids,
                                                                     parent_ids=beam_search_output.parent_ids)

    # Returning
    return beam_search_output, next_inputs
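# The sample_ids computation above converts the beam-search predicted ids (positions into the per-step
# candidate list) back into vocabulary token ids with a one-hot gather over the candidate axis. A minimal
# NumPy sketch with toy values:
import numpy as np

batch, beam, nb_candidates = 2, 3, 4
candidates = np.array([[11, 42, 7, 0], [5, 0, 0, 0]])        # [batch, nb_candidates] vocabulary token ids
predicted_ids = np.array([[0, 1, 2], [0, 0, 0]])             # [batch, beam] positions into the candidate list

one_hot = np.eye(nb_candidates, dtype=np.int64)[predicted_ids]       # [batch, beam, nb_candidates]
sample_ids = np.sum(one_hot * candidates[:, None, :], axis=-1)       # [batch, beam] vocabulary token ids
assert sample_ids.tolist() == [[11, 42, 7], [5, 5, 5]]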
def _compute_attention(self, alignments, memory):
    """ Computes the attention and alignments for a given attention_mechanism """
    # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
    expanded_alignments = array_ops.expand_dims(alignments, 1)

    # Context is the inner product of alignments and values along the memory time dimension.
    # alignments shape is   [batch_size, 1, memory_time]
    # memory is             [batch_size, memory_time, memory_size]
    # The batched matmul is over memory_time, so the output shape is [batch_size, 1, memory_size].
    # We then squeeze out the singleton dim.
    context = math_ops.matmul(expanded_alignments, memory)
    context = array_ops.squeeze(context, [1])

    attn_layer = lambda x: x
    if self._attention_layer_size != self._memory_size:
        attn_layer = core.Dense(self._attention_layer_size, name='attn_layer', use_bias=False, dtype=context.dtype)

    attention = attn_layer(context)
    return attention, alignments
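# The context computation above is a batched matmul over the memory time dimension, i.e. a weighted sum of
# memory rows. A minimal NumPy sketch of that shape arithmetic (toy sizes):
import numpy as np

batch_size, memory_time, memory_size = 2, 5, 8
alignments = np.random.rand(batch_size, memory_time).astype(np.float32)
alignments /= alignments.sum(axis=-1, keepdims=True)         # attention weights sum to 1 over time
memory = np.random.rand(batch_size, memory_time, memory_size).astype(np.float32)

# [batch, 1, time] @ [batch, time, mem_size] -> [batch, 1, mem_size] -> squeeze -> [batch, mem_size]
context = np.squeeze(alignments[:, None, :] @ memory, axis=1)

# Equivalent to an explicit weighted sum over the time axis
expected = np.einsum('bt,btm->bm', alignments, memory)
assert np.allclose(context, expected, atol=1e-5)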
def _compute_attention(self, query, memory):
    """ Computes the attention and alignments for the Bahdanau attention mechanism
        :param query: The query (inputs) to use to compute attention. Size [b, input_size]
        :param memory: The memory (previous outputs) used to compute attention [b, time_step, memory_size]
        :return: The attention. Size [b, attn_size]
    """
    assert len(memory.shape) == 3, 'Memory needs to be [batch, time, memory_size]'
    memory_time = array_ops.shape(memory)[1]
    memory_size = memory.shape[2]
    num_units = self._num_units
    assert self._memory_size == memory_size, 'Expected mem size of %s - Got %s' % (self._memory_size, memory_size)

    # Query, memory, and attention layers
    query_layer = core.Dense(num_units, name='query_layer', use_bias=False, dtype=self._dtype)
    memory_layer = lambda x: x
    if memory_size != self._num_units:
        memory_layer = core.Dense(num_units, name='memory_layer', use_bias=False, dtype=self._dtype)
    attn_layer = lambda x: x
    if self._attention_layer_size is not None and memory_size != self._attention_layer_size:
        attn_layer = core.Dense(self._attention_layer_size, name='attn_layer', use_bias=False, dtype=self._dtype)

    # Masking memory
    sequence_length = gen_math_ops.minimum(memory_time, self._sequence_length)
    sequence_mask = array_ops.sequence_mask(sequence_length, maxlen=memory_time, dtype=dtypes.float32)[..., None]
    values = memory * sequence_mask
    keys = memory_layer(values)

    # Computing scores
    processed_query = query_layer(query)
    scores = _bahdanau_score(processed_query, keys, self._normalize)

    # Getting alignments
    masked_scores = _maybe_mask_score(scores, sequence_length, self._score_mask_value)
    alignments = self._wrapped_probability_fn(masked_scores, None)               # [batch, time]

    # Getting attention
    expanded_alignments = array_ops.expand_dims(alignments, 1)                   # [batch, 1, time]
    context = math_ops.matmul(expanded_alignments, memory)                       # [batch, 1, memory_size]
    context = array_ops.squeeze(context, [1])                                    # [batch, memory_size]
    attention = attn_layer(context)                                              # [batch, attn_size]

    # Returning attention
    return attention
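# _bahdanau_score is not shown here. In its standard (unnormalized) form the additive Bahdanau score is
# score(q, k_t) = v^T tanh(k_t + W_q q), broadcast over the time axis, with alignments obtained as a
# softmax over time. A minimal NumPy sketch under that assumption (toy shapes, hypothetical learned vector v):
import numpy as np

batch, time, num_units = 2, 5, 16
processed_query = np.random.rand(batch, num_units).astype(np.float32)    # query_layer(query)
keys = np.random.rand(batch, time, num_units).astype(np.float32)         # memory_layer(values)
attention_v = np.random.rand(num_units).astype(np.float32)               # learned vector v (assumed)

# Additive (Bahdanau) score: v^T tanh(keys + query) -> [batch, time]
scores = np.sum(attention_v * np.tanh(keys + processed_query[:, None, :]), axis=-1)

# Alignments are the softmax of the (masked) scores over the time axis
alignments = np.exp(scores - scores.max(axis=-1, keepdims=True))
alignments /= alignments.sum(axis=-1, keepdims=True)
assert alignments.shape == (batch, time)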
def __init__(self, cell, embedding, mask, sequence_length, initial_state, beam_width,
             input_layer=None, output_layer=None, time_major=False):
    """ Initialize the CustomBeamHelper
        :param cell: An `RNNCell` instance.
        :param embedding: The embedding vector
        :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
        :param sequence_length: The length of each input (b,)
        :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        :param beam_width: Python integer, the number of beams.
        :param input_layer: Optional. A layer to apply on the inputs
        :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                             to apply to the RNN output prior to storing the result or sampling.
        :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
    """
    # pylint: disable=super-init-not-called,too-many-arguments
    rnn_cell_impl.assert_like_rnncell('cell', cell)                              # pylint: disable=protected-access
    assert isinstance(mask, SparseTensor), 'The mask must be a SparseTensor'
    assert isinstance(beam_width, int), 'beam_width should be a Python integer'

    self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
    if self._sequence_length.get_shape().ndims != 1:
        raise ValueError("Expected vector for sequence_length. Shape: %s" % self._sequence_length.get_shape())

    self._cell = cell
    self._embedding_fn = _get_embedding_fn(embedding)
    self._mask = mask
    self._time_major = time_major
    self.vocab_size = VOCABULARY_SIZE
    self._input_layer = input_layer if input_layer is not None else lambda x: x
    self._output_layer = output_layer

    self._input_size = embedding.shape[-1]
    if input_layer is not None:
        self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

    self._batch_size = array_ops.size(sequence_length)
    self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
    self._end_token = -1
    self._beam_width = beam_width
    self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                  initial_state, self._cell.state_size)
    self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                       depth=self._beam_width,
                                       on_value=False,
                                       off_value=True,
                                       dtype=dtypes.bool)

    # zero_mask is (batch, beam, vocab_size)
    self._zero_mask = _slice_mask(self._mask, slicing=[-1, 0, GO_ID, -1], squeeze=True, time_major=self._time_major)
    self._zero_mask = gen_array_ops.tile(array_ops.expand_dims(self._zero_mask, axis=1), [1, self._beam_width, 1])
    self._zero_inputs = \
        MaskedInputs(
            inputs=array_ops.zeros_like(
                self._split_batch_beams(
                    self._input_layer(self._embedding_fn(self._start_tokens)), self._input_size)),
            mask=self._zero_mask)