Example #1
    def step(self, time, inputs, cell_state):
        """ Performs a step using the beam search cell
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param cell_state: A (structure of) state tensors and TensorArrays.
            :return: `(cell_outputs, next_cell_state)`.
        """
        raw_inputs = inputs
        inputs, candidates, candidates_emb = raw_inputs.inputs, raw_inputs.candidates, raw_inputs.candidates_emb

        inputs = nest.map_structure(lambda inp: self._merge_batch_beams(inp, depth_shape=inp.shape[2:]), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(inputs, cell_state)                  # [batch * beam, out_sz]
        next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        # Splitting outputs and adding a bias dimension
        # cell_outputs is [batch, beam, cand_emb_size + 1]
        cell_outputs = self._output_layer(cell_outputs) if self._output_layer is not None else cell_outputs
        cell_outputs = nest.map_structure(lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
        cell_outputs = array_ops.pad(cell_outputs, [(0, 0), (0, 0), (0, 1)], constant_values=1.)

        # Computing candidates
        # cell_outputs is reshaped to   [batch, beam,        1, cand_emb_size + 1]
        # candidates_emb is reshaped to [batch,    1, max_cand, cand_emb_size + 1]
        # output_mask is                [batch,    1, max_cand]
        # cell_outputs is finally       [batch, beam, max_cand]
        cell_outputs = math_ops.reduce_sum(array_ops.expand_dims(cell_outputs, axis=2)
                                           * array_ops.expand_dims(candidates_emb, axis=1), axis=-1)
        output_mask = math_ops.cast(array_ops.expand_dims(gen_math_ops.greater(candidates, 0), axis=1), dtypes.float32)
        cell_outputs = gen_math_ops.add(cell_outputs, (1. - output_mask) * LARGE_NEGATIVE)

        # Returning
        return cell_outputs, next_cell_state
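
The broadcast in the candidate-scoring block is easy to misread, so here is a minimal, self-contained sketch of the same pattern using the public `tf` API (shapes, values, and the LARGE_NEGATIVE constant are assumptions, not from the source):

import tensorflow as tf

LARGE_NEGATIVE = -1e20                                      # assumed value for the source's constant
batch, beam, max_cand, emb = 2, 3, 4, 5
cell_outputs = tf.random.normal([batch, beam, emb + 1])     # after the pad: output + bias dim
candidates = tf.constant([[1, 2, 3, 0],
                          [4, 5, 0, 0]])                    # [batch, max_cand], 0 = padding
candidates_emb = tf.random.normal([batch, max_cand, emb + 1])

# Dot every beam output against every candidate embedding.
scores = tf.reduce_sum(cell_outputs[:, :, None, :] * candidates_emb[:, None, :, :],
                       axis=-1)                             # [batch, beam, max_cand]

# Push padded candidate slots towards -inf so they can never be selected.
mask = tf.cast(candidates > 0, tf.float32)[:, None, :]      # [batch, 1, max_cand]
scores += (1. - mask) * LARGE_NEGATIVE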
Example #2
    def get_next_memory_and_attn():
        """ Gets the next memory and attention """
        next_memory = array_ops.concat(
            [
                state.memory,                                           # [b, t, mem_size]
                array_ops.expand_dims(self._input_fn(inputs), axis=1)   # [b, 1, mem_size]
            ],
            axis=1)
        next_attention = self._compute_attention(inputs, next_memory)

        # Making sure both tensors are computed before returning
        with ops.control_dependencies([next_memory, next_attention]):
            return array_ops.identity(next_memory), array_ops.identity(next_attention)
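
The concat simply appends one projected input per step along the time axis. A toy, standalone version of that step (hypothetical shapes, public `tf` API; `step_input` stands in for `self._input_fn(inputs)`):

import tensorflow as tf

b, t, mem_size = 2, 5, 8
memory = tf.zeros([b, t, mem_size])                  # memory accumulated so far
step_input = tf.ones([b, mem_size])                  # stand-in for self._input_fn(inputs)
next_memory = tf.concat([memory, step_input[:, None, :]], axis=1)
print(next_memory.shape)                             # (2, 6, 8) -- one extra time step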
Example #3
    def next_inputs(self, time, inputs, beam_search_output, beam_search_state):
        """ Computes the inputs at the next time step given the beam outputs
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param beam_search_output: The output of the beam search step
            :param beam_search_state: The state after the beam search step
            :return: `(beam_search_output, next_inputs)`
            :type beam_search_output: beam_search_decoder.BeamSearchDecoderOutput
            :type beam_search_state: beam_search_decoder.BeamSearchDecoderState
        """
        next_time = time + 1
        all_finished = math_ops.reduce_all(next_time >= self._sequence_length)

        # Sampling
        next_word_ids = beam_search_output.predicted_ids
        candidates = inputs.candidates
        nb_candidates = array_ops.shape(candidates)[1]
        sample_ids = math_ops.reduce_sum(array_ops.one_hot(next_word_ids, nb_candidates, dtype=dtypes.int32)
                                         * array_ops.expand_dims(candidates, axis=1), axis=-1)

        def get_next_inputs():
            """ Retrieves the inputs for the next time step """
            inputs_next_step = sample_ids
            inputs_emb_next_step = self._input_layer(self._order_embedding_fn(inputs_next_step))
            candidate_next_step = self._candidate_tas.read(next_time)
            candidate_emb_next_step = self._candidate_embedding_fn(candidate_next_step)

            # Prevents this branch from executing eagerly
            with ops.control_dependencies([inputs_emb_next_step, candidate_next_step, candidate_emb_next_step]):
                return CandidateInputs(inputs=array_ops.identity(inputs_emb_next_step),
                                       candidates=array_ops.identity(candidate_next_step),
                                       candidates_emb=array_ops.identity(candidate_emb_next_step))

        # Getting next inputs
        next_inputs = control_flow_ops.cond(all_finished,
                                            true_fn=lambda: self._zero_inputs,
                                            false_fn=get_next_inputs)

        # Rewriting beam search output with the correct sample ids
        beam_search_output = beam_search_decoder.BeamSearchDecoderOutput(scores=beam_search_output.scores,
                                                                         predicted_ids=sample_ids,
                                                                         parent_ids=beam_search_output.parent_ids)

        # Returning
        return beam_search_output, next_inputs
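
The one-hot/reduce_sum pair in the sampling block is effectively a batched gather: it maps beam-local candidate indices back to vocabulary ids. A tiny sketch with made-up values:

import tensorflow as tf

candidates = tf.constant([[12, 7, 99]])                  # [batch=1, nb_candidates=3]
predicted_ids = tf.constant([[2, 0]])                    # [batch, beam] indices into candidates
one_hot = tf.one_hot(predicted_ids, 3, dtype=tf.int32)   # [batch, beam, nb_candidates]
sample_ids = tf.reduce_sum(one_hot * candidates[:, None, :], axis=-1)
print(sample_ids.numpy())                                # [[99 12]]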
Example #4
    def _compute_attention(self, alignments, memory):
        """Computes the attention and alignments for a given attention_mechanism."""
        # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
        expanded_alignments = array_ops.expand_dims(alignments, 1)

        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # alignments shape is  [batch_size, 1, memory_time]
        # memory is [batch_size, memory_time, memory_size]
        # the batched matmul is over memory_time, so the output shape is [batch_size, 1, memory_size].
        # we then squeeze out the singleton dim.
        context = math_ops.matmul(expanded_alignments, memory)
        context = array_ops.squeeze(context, [1])
        attn_layer = lambda x: x
        if self._attention_layer_size != self._memory_size:
            attn_layer = core.Dense(self._attention_layer_size,
                                    name='attn_layer',
                                    use_bias=False,
                                    dtype=context.dtype)
        attention = attn_layer(context)
        return attention, alignments
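
The matmul/squeeze pair is a batched weighted sum over the memory time axis; the einsum below makes that explicit (hypothetical shapes, public `tf` API):

import tensorflow as tf

batch, time, mem_size = 2, 4, 8
alignments = tf.nn.softmax(tf.random.normal([batch, time]))   # each row sums to 1
memory = tf.random.normal([batch, time, mem_size])

context = tf.squeeze(tf.matmul(alignments[:, None, :], memory), axis=1)   # [batch, mem_size]
same = tf.einsum('bt,btm->bm', alignments, memory)                        # identical result
print(float(tf.reduce_max(tf.abs(context - same))))                       # ~0.0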
Example #5
    def _compute_attention(self, query, memory):
        """ Computes the attention and alignments for the Bahdanau attention mechanism .
            :param query: The query (inputs) to use to compute attention. Size [b, input_size]
            :param memory: The memory (previous outputs) used to compute attention [b, time_step, memory_size]
            :return: The attention. Size [b, attn_size]
        """
        assert len(memory.shape) == 3, 'Memory needs to be [batch, time, memory_size]'
        memory_time = array_ops.shape(memory)[1]
        memory_size = memory.shape[2]
        num_units = self._num_units
        assert self._memory_size == memory_size, \
            'Expected mem size of %s - Got %s' % (self._memory_size, memory_size)

        # Query, memory, and attention layers
        query_layer = core.Dense(num_units,
                                 name='query_layer',
                                 use_bias=False,
                                 dtype=self._dtype)
        memory_layer = lambda x: x
        if memory_size != self._num_units:
            memory_layer = core.Dense(num_units,
                                      name='memory_layer',
                                      use_bias=False,
                                      dtype=self._dtype)
        attn_layer = lambda x: x
        if self._attention_layer_size is not None and memory_size != self._attention_layer_size:
            attn_layer = core.Dense(self._attention_layer_size,
                                    name='attn_layer',
                                    use_bias=False,
                                    dtype=self._dtype)

        # Masking memory
        sequence_length = gen_math_ops.minimum(memory_time, self._sequence_length)
        sequence_mask = array_ops.sequence_mask(sequence_length,
                                                maxlen=memory_time,
                                                dtype=dtypes.float32)[..., None]
        values = memory * sequence_mask
        keys = memory_layer(values)

        # Computing scores
        processed_query = query_layer(query)
        scores = _bahdanau_score(processed_query, keys, self._normalize)

        # Getting alignments
        masked_scores = _maybe_mask_score(scores, sequence_length, self._score_mask_value)
        alignments = self._wrapped_probability_fn(masked_scores, None)      # [batch, time]

        # Getting attention
        expanded_alignments = array_ops.expand_dims(alignments, 1)          # [batch, 1, time]
        context = math_ops.matmul(expanded_alignments, memory)              # [batch, 1, memory_size]
        context = array_ops.squeeze(context, [1])                           # [batch, memory_size]
        attention = attn_layer(context)                                     # [batch, attn_size]

        # Returning attention
        return attention
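
The masking step zeroes memory positions past each sequence's true length before the key projection. A standalone sketch of just that step (hypothetical sizes and lengths):

import tensorflow as tf

memory = tf.random.normal([2, 5, 4])                  # [batch, time, memory_size]
sequence_length = tf.constant([3, 5])
mask = tf.sequence_mask(sequence_length, maxlen=5,
                        dtype=tf.float32)[..., None]  # [batch, time, 1]
values = memory * mask
print(values[0, 3:].numpy())                          # all zeros past the shorter length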
Example #6
    def __init__(self,
                 cell,
                 embedding,
                 mask,
                 sequence_length,
                 initial_state,
                 beam_width,
                 input_layer=None,
                 output_layer=None,
                 time_major=False):
        """ Initialize the CustomBeamHelper
            :param cell: An `RNNCell` instance.
            :param embedding: The embedding vector
            :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
            :param sequence_length: The length of each input (b,)
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs
            :param output_layer: Optional. An instance of `tf.layers.Layer` (e.g. `tf.layers.Dense`) to apply
                                 to the RNN output prior to storing the result or sampling.
            :param time_major: If True, the first dimension of inputs is time; otherwise it is batch.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)  # pylint: disable=protected-access
        assert isinstance(mask, SparseTensor), 'The mask must be a SparseTensor'
        assert isinstance(beam_width, int), 'beam_width should be a Python integer'

        self._sequence_length = ops.convert_to_tensor(sequence_length,
                                                      name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" %
                             self._sequence_length.get_shape())

        self._cell = cell
        self._embedding_fn = _get_embedding_fn(embedding)
        self._mask = mask
        self._time_major = time_major
        self.vocab_size = VOCABULARY_SIZE
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        self._input_size = embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape(
                [None, self._input_size])[-1]

        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill(
            [self._batch_size * beam_width], GO_ID)
        self._end_token = -1
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state,
            self._cell.state_size)
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size],
                                                           dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # zero_mask is (batch, beam, vocab_size)
        self._zero_mask = _slice_mask(self._mask,
                                      slicing=[-1, 0, GO_ID, -1],
                                      squeeze=True,
                                      time_major=self._time_major)
        self._zero_mask = gen_array_ops.tile(
            array_ops.expand_dims(self._zero_mask, axis=1),
            [1, self._beam_width, 1])
        self._zero_inputs = \
            MaskedInputs(
                inputs=array_ops.zeros_like(
                    self._split_batch_beams(
                        self._input_layer(self._embedding_fn(self._start_tokens)), self._input_size)),
                mask=self._zero_mask)
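
The `_finished` one-hot above starts only beam 0 as active and marks the remaining beams finished, so the first decoding step cannot expand them. A minimal reproduction (hypothetical sizes, public `tf` API):

import tensorflow as tf

batch_size, beam_width = 2, 3
finished = tf.one_hot(tf.zeros([batch_size], dtype=tf.int32),
                      depth=beam_width, on_value=False, off_value=True,
                      dtype=tf.bool)
print(finished.numpy())
# [[False  True  True]
#  [False  True  True]] -- only the first beam starts unfinished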