Exemple #1
0
        def get_next_inputs():
            """ Builds the embedded inputs and the mask for the next time step

                Reads `self`, `sample_ids` and `next_time` from the enclosing scope.
                :return: A `MaskedInputs` holding the embedded sampled ids and the mask computed for the next step.
            """
            # Embedding the ids selected at this step -- [bat, beam, in_sz]
            embedded_ids = self._embedding_fn(sample_ids)
            next_step_emb = self._input_layer(embedded_ids)

            # Building the mask
            # one_hot_ids:  (batch, beam,   1, VOC,   1)
            # dense_mask:   (batch,    1,   1, VOC, VOC)
            # step_mask:    (batch, beam, VOC)
            one_hot_ids = array_ops.one_hot(sample_ids, self.vocab_size)
            one_hot_ids = one_hot_ids[:, :, None, :, None]

            sparse_mask = _slice_mask(self._mask, [-1, next_time, -1, -1], time_major=self._time_major)
            dense_mask = sparse_ops.sparse_tensor_to_dense(sparse_mask)
            dense_mask = dense_mask[:, None, :, :, :]
            dense_mask.set_shape([None, 1, 1, self.vocab_size, self.vocab_size])

            # Summing over the broadcast axes, then clipping so that the mask never exceeds 1
            step_mask = math_ops.reduce_sum(one_hot_ids * dense_mask, axis=[2, 3])
            step_mask = gen_math_ops.minimum(step_mask, 1.)

            # The identity ops are created under control dependencies to prevent this branch
            # from executing eagerly
            with ops.control_dependencies([next_step_emb, step_mask]):
                masked_inputs = MaskedInputs(inputs=array_ops.identity(next_step_emb),
                                             mask=array_ops.identity(step_mask))
            return masked_inputs
Exemple #2
0
    def __init__(self, cell, order_embedding, candidate_embedding, candidates, sequence_length, initial_state,
                 beam_width, input_layer=None, output_layer=None, time_major=False):
        """ Initialize the CustomBeamHelper

            Stores the cell, the embedding functions and the per-step candidates, then pre-computes the start
            tokens, the initial cell state, the initial `finished` flags and the zero inputs.

            :param cell: An `RNNCell` instance.
            :param order_embedding: The order embedding vector  - Size: (batch, ord_emb_size)
            :param candidate_embedding: The candidate embedding vector - Size: (batch, cand_emb_size)
            :param candidates: The candidates at each time step -- Size: (batch, nb_cand, max_candidates)
            :param sequence_length: The length of each sequence (batch,)
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
                               In this constructor it only controls whether `candidates` is transposed.
            :raises ValueError: If `sequence_length` is not a vector (rank 1).
            :raises AssertionError: If `beam_width` is not a Python integer.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)                                                                 # pylint: disable=protected-access
        # NOTE(review): `assert` statements are stripped when Python runs with -O, so this check is then skipped.
        assert isinstance(beam_width, int), 'beam_width should be a Python integer'

        # The sequence length must be a rank-1 tensor
        self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" % self._sequence_length.get_shape())

        # Candidates are passed through `_transpose_batch_time` unless they are already time-major
        # (presumably swaps the batch and time axes -- TODO confirm against the helper)
        candidates = ops.convert_to_tensor(candidates, name='candidates')
        candidates = nest.map_structure(_transpose_batch_time, candidates) if not time_major else candidates

        self._cell = cell
        self._order_embedding_fn = _get_embedding_fn(order_embedding)
        self._candidate_embedding_fn = _get_embedding_fn(candidate_embedding)
        # `next_inputs` calls `.read(next_time)` on this, so it presumably holds one entry per time step
        self._candidate_tas = nest.map_structure(_unstack_ta, candidates)
        # When no input layer is given, inputs are passed through unchanged
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        # Input size is the last dimension of the order embedding, or the output size of the input layer if provided
        self._input_size = order_embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

        # One GO_ID start token per (batch item, beam) pair
        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
        # NOTE(review): presumably -1 is used because it never matches a real token id -- confirm
        self._end_token = -1
        # The beam width must be set before the `_maybe_split_batch_beams` / `_split_batch_beams` calls below
        # (presumably read by those helpers, which are not visible here -- confirm)
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                      initial_state,
                                                      self._cell.state_size)
        # (batch, beam) boolean flags -- Only index 0 (the first beam) of each batch item starts as not finished,
        # all the other beams start as finished.
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # Compute input shape
        # The zero inputs are all-zero tensors shaped like the embedded start tokens (split by beam),
        # like the first slice of the candidates, and like the embedding of that first slice.
        self._zero_inputs = \
            CandidateInputs(inputs=
                            array_ops.zeros_like(self._split_batch_beams(
                                self._input_layer(self._order_embedding_fn(self._start_tokens)),
                                self._input_size)),
                            candidates=array_ops.zeros_like(candidates[0, :]),
                            candidates_emb=array_ops.zeros_like(self._candidate_embedding_fn(candidates[0, :])))
Exemple #3
0
 def greedy():
     """ Greedy selection -- Picks the candidate located at the argmax of the cell outputs

         Reads `cell_outputs` and `candidate` from the enclosing scope.
         :return: The candidate ids found at the argmax position of `cell_outputs`.
     """
     best_position = math_ops.argmax(cell_outputs, axis=-1)
     best_position = math_ops.cast(best_position, dtypes.int32)
     depth = array_ops.shape(candidate)[1]

     # The one-hot selector zeroes out every candidate except the one at the argmax position
     selector = array_ops.one_hot(best_position, depth, dtype=dtypes.int32)
     selected_ids = math_ops.reduce_sum(selector * candidate, axis=-1)

     with ops.control_dependencies([selected_ids]):
         result = array_ops.identity(selected_ids)
     return result
Exemple #4
0
 def sample():
     """ Stochastic selection -- Draws a position from a categorical distribution over the cell outputs

         Reads `self`, `cell_outputs` and `candidate` from the enclosing scope.
         The cell outputs are divided by the softmax temperature when one is set.
         :return: The candidate ids found at the sampled positions.
     """
     temperature = self._softmax_temperature
     if temperature is None:
         scaled_logits = cell_outputs
     else:
         scaled_logits = cell_outputs / temperature

     distribution = categorical.Categorical(logits=scaled_logits)
     drawn_position = distribution.sample(seed=self._seed)
     depth = array_ops.shape(candidate)[1]

     # The one-hot selector zeroes out every candidate except the one at the sampled position
     selector = array_ops.one_hot(drawn_position, depth, dtype=dtypes.int32)
     drawn_ids = math_ops.reduce_sum(selector * candidate, axis=-1)

     with ops.control_dependencies([drawn_ids]):
         result = array_ops.identity(drawn_ids)
     return result
Exemple #5
0
    def next_inputs(self, time, inputs, beam_search_output, beam_search_state):
        """ Builds the inputs of the following time step from the results of the beam search step
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors. Its `candidates` field maps predicted positions to ids.
            :param beam_search_output: The output of the beam search step
            :param beam_search_state: The state after the beam search step (not used by this method)
            :return: `(beam_search_output, next_inputs)` -- `predicted_ids` in the returned output holds the
                     candidate ids rather than the predicted positions.
            :type beam_search_output: beam_search_decoder.BeamSearchDecoderOutput
            :type beam_search_state: beam_search_decoder.BeamSearchDecoderState
        """
        step = time + 1
        is_done = math_ops.reduce_all(step >= self._sequence_length)

        # Converting the predicted positions into candidate ids
        positions = beam_search_output.predicted_ids
        step_candidates = inputs.candidates
        depth = array_ops.shape(step_candidates)[1]
        selector = array_ops.one_hot(positions, depth, dtype=dtypes.int32)
        sample_ids = math_ops.reduce_sum(selector * array_ops.expand_dims(step_candidates, axis=1), axis=-1)

        def _build_zero_inputs():
            """ Returns the pre-computed zero inputs (used once all sequences are finished) """
            return self._zero_inputs

        def _build_following_inputs():
            """ Embeds the selected ids and reads the candidates of the following time step """
            ids_emb = self._order_embedding_fn(sample_ids)
            ids_emb = self._input_layer(ids_emb)
            following_candidates = self._candidate_tas.read(step)
            following_candidates_emb = self._candidate_embedding_fn(following_candidates)

            # The identity ops are created under control dependencies to prevent this branch
            # from executing eagerly
            with ops.control_dependencies([ids_emb, following_candidates, following_candidates_emb]):
                branch_inputs = CandidateInputs(inputs=array_ops.identity(ids_emb),
                                                candidates=array_ops.identity(following_candidates),
                                                candidates_emb=array_ops.identity(following_candidates_emb))
            return branch_inputs

        # Zero inputs once every sequence is finished, otherwise the real inputs
        following_inputs = control_flow_ops.cond(is_done,
                                                 true_fn=_build_zero_inputs,
                                                 false_fn=_build_following_inputs)

        # Rebuilding the beam search output so that it reports the candidate ids
        rewritten_output = beam_search_decoder.BeamSearchDecoderOutput(scores=beam_search_output.scores,
                                                                       predicted_ids=sample_ids,
                                                                       parent_ids=beam_search_output.parent_ids)
        return rewritten_output, following_inputs
Exemple #6
0
            def get_next_inputs():
                """ Builds the embedded inputs and the mask for the next time step

                    Reads `self`, `next_time` and `sample_ids` from the enclosing scope.
                    :return: A `MaskedInputs` with the embedded next-step inputs and the associated mask.
                """
                def _read_stored_inputs():
                    """ Reads the inputs stored at `next_time` in `self._input_tas` """
                    stored_inputs = self._input_tas.read(next_time)
                    with ops.control_dependencies([stored_inputs]):
                        return array_ops.identity(stored_inputs)

                def _reuse_sample_ids():
                    """ Feeds back the ids selected at the current step (greedy / sample decoders) """
                    return sample_ids

                # Stored inputs for the training decoder (also the default), sample ids for greedy / sample
                is_training = gen_math_ops.equal(self._decoder_type, TRAINING_DECODER)
                is_greedy = gen_math_ops.equal(self._decoder_type, GREEDY_DECODER)
                is_sample = gen_math_ops.equal(self._decoder_type, SAMPLE_DECODER)
                next_ids = control_flow_ops.case([(is_training, _read_stored_inputs),
                                                  (is_greedy, _reuse_sample_ids),
                                                  (is_sample, _reuse_sample_ids)],
                                                 default=_read_stored_inputs)
                next_ids_emb = self._input_layer(self._embedding_fn(next_ids))

                # Building the mask
                # one_hot_ids:  (b, 1, VOC, 1)
                # sparse_mask:  (b, 1, VOC, VOC)
                # step_mask:    (b, VOC)        -- DenseTensor
                one_hot_ids = array_ops.one_hot(next_ids, self.vocab_size)
                one_hot_ids = one_hot_ids[:, None, :, None]
                sparse_mask = _slice_mask(self._mask, [-1, next_time, -1, -1], time_major=self._time_major)
                step_mask = sparse_ops.sparse_reduce_sum(one_hot_ids * sparse_mask, axis=[1, 2])

                # Clipping so that the mask never exceeds 1, then restoring the static shape
                step_mask = gen_math_ops.minimum(step_mask, 1.)
                step_mask.set_shape([None, self.vocab_size])

                # The identity ops are created under control dependencies to prevent this branch
                # from executing eagerly
                with ops.control_dependencies([next_ids_emb, step_mask]):
                    masked_inputs = MaskedInputs(inputs=array_ops.identity(next_ids_emb),
                                                 mask=array_ops.identity(step_mask))
                return masked_inputs
Exemple #7
0
    def __init__(self,
                 cell,
                 embedding,
                 mask,
                 sequence_length,
                 initial_state,
                 beam_width,
                 input_layer=None,
                 output_layer=None,
                 time_major=False):
        """ Initialize the CustomBeamHelper

            Stores the cell, the embedding function and the mask, then pre-computes the start tokens, the initial
            cell state, the initial `finished` flags, the zero mask and the zero inputs.

            :param cell: An `RNNCell` instance.
            :param embedding: The embedding vector
            :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
            :param sequence_length: The length of each input (b,)
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
                               It is stored and forwarded to `_slice_mask`.
            :raises ValueError: If `sequence_length` is not a vector (rank 1).
            :raises AssertionError: If `mask` is not a SparseTensor or `beam_width` is not a Python integer.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)  # pylint: disable=protected-access
        # NOTE(review): `assert` statements are stripped when Python runs with -O, so these checks are then skipped.
        assert isinstance(mask,
                          SparseTensor), 'The mask must be a SparseTensor'
        assert isinstance(beam_width,
                          int), 'beam_width should be a Python integer'

        # The sequence length must be a rank-1 tensor
        self._sequence_length = ops.convert_to_tensor(sequence_length,
                                                      name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" %
                             self._sequence_length.get_shape())

        self._cell = cell
        self._embedding_fn = _get_embedding_fn(embedding)
        self._mask = mask
        self._time_major = time_major
        # The vocabulary size comes from the VOCABULARY_SIZE constant, not from the mask or the embedding
        self.vocab_size = VOCABULARY_SIZE
        # When no input layer is given, inputs are passed through unchanged
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        # Input size is the last dimension of the embedding, or the output size of the input layer if provided
        self._input_size = embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape(
                [None, self._input_size])[-1]

        # One GO_ID start token per (batch item, beam) pair
        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill(
            [self._batch_size * beam_width], GO_ID)
        # NOTE(review): presumably -1 is used because it never matches a real token id -- confirm
        self._end_token = -1
        # The beam width must be set before the `_maybe_split_batch_beams` / `_split_batch_beams` calls below
        # (presumably read by those helpers, which are not visible here -- confirm)
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state,
            self._cell.state_size)
        # (batch, beam) boolean flags -- Only index 0 (the first beam) of each batch item starts as not finished,
        # all the other beams start as finished.
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size],
                                                           dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # zero_mask is (batch, beam, vocab_size)
        # It is the mask slice at time step 0 for the GO_ID token, repeated `beam_width` times on a new axis 1
        self._zero_mask = _slice_mask(self._mask,
                                      slicing=[-1, 0, GO_ID, -1],
                                      squeeze=True,
                                      time_major=self._time_major)
        self._zero_mask = gen_array_ops.tile(
            array_ops.expand_dims(self._zero_mask, axis=1),
            [1, self._beam_width, 1])
        # The zero inputs are all-zero tensors shaped like the embedded start tokens (split by beam),
        # paired with the zero mask
        self._zero_inputs = \
            MaskedInputs(
                inputs=array_ops.zeros_like(
                    self._split_batch_beams(
                        self._input_layer(self._embedding_fn(self._start_tokens)), self._input_size)),
                mask=self._zero_mask)