Example #1
        def get_next_inputs():
            """ Retrieves the inputs for the next time step """
            inputs_next_step = sample_ids
            inputs_emb_next_step = self._input_layer(
                self._embedding_fn(inputs_next_step))  # [bat, beam, in_sz]

            # Applying mask
            # inputs_one_hot:   (batch, beam,   1, VOC,   1)
            # mask_t:           (batch,    1,   1, VOC, VOC)
            # next_mask:        (batch, beam, VOC)
            inputs_one_hot = array_ops.one_hot(inputs_next_step,
                                               self.vocab_size)[:, :, None, :,
                                                                None]
            mask_t = sparse_ops.sparse_tensor_to_dense(
                _slice_mask(self._mask, [-1, next_time, -1, -1],
                            time_major=self._time_major))[:, None, :, :, :]
            mask_t.set_shape([None, 1, 1, self.vocab_size, self.vocab_size])
            next_mask = math_ops.reduce_sum(inputs_one_hot * mask_t,
                                            axis=[2, 3])
            next_mask = gen_math_ops.minimum(next_mask, 1.)

            # Prevents this branch from executing eagerly
            with ops.control_dependencies([inputs_emb_next_step, next_mask]):
                return MaskedInputs(
                    inputs=array_ops.identity(inputs_emb_next_step),
                    mask=array_ops.identity(next_mask))
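
A minimal NumPy sketch of the broadcast used above (made-up sizes, not the project's code): one-hot encoding the sampled token ids, multiplying against a per-step [VOC, VOC] transition mask, and reducing over the two inner axes yields the [batch, beam, VOC] mask for the next step.

import numpy as np

batch, beam, voc = 2, 3, 5
inputs_next_step = np.random.randint(0, voc, size=(batch, beam))

# One-hot of the sampled ids:        (batch, beam, 1, VOC, 1)
inputs_one_hot = np.eye(voc)[inputs_next_step][:, :, None, :, None]

# Dense transition mask for step t:  (batch, 1, 1, VOC, VOC)
mask_t = np.random.randint(0, 2, size=(batch, 1, 1, voc, voc)).astype(np.float64)

# Broadcast, then reduce over the two inner axes -> (batch, beam, VOC)
next_mask = np.minimum((inputs_one_hot * mask_t).sum(axis=(2, 3)), 1.)
assert next_mask.shape == (batch, beam, voc)
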
Example #2
    def step(self, time, inputs, cell_state):
        """ Performs a step using the beam search cell
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param cell_state: A (structure of) state tensors and TensorArrays.
            :return: `(cell_outputs, next_cell_state)`.
        """
        raw_inputs = inputs
        inputs, candidates, candidates_emb = raw_inputs.inputs, raw_inputs.candidates, raw_inputs.candidates_emb

        inputs = nest.map_structure(lambda inp: self._merge_batch_beams(inp, depth_shape=inp.shape[2:]), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(inputs, cell_state)                  # [batch * beam, out_sz]
        next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        # Splitting outputs and adding a bias dimension
        # cell_outputs is [batch, beam, cand_emb_size + 1]
        cell_outputs = self._output_layer(cell_outputs) if self._output_layer is not None else cell_outputs
        cell_outputs = nest.map_structure(lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
        cell_outputs = array_ops.pad(cell_outputs, [(0, 0), (0, 0), (0, 1)], constant_values=1.)

        # Computing candidates
        # cell_outputs is reshaped to   [batch, beam,        1, cand_emb_size + 1]
        # candidates_emb is reshaped to [batch,    1, max_cand, cand_emb_size + 1]
        # output_mask is                [batch,    1, max_cand]
        # cell_outputs is finally       [batch, beam, max_cand]
        cell_outputs = math_ops.reduce_sum(array_ops.expand_dims(cell_outputs, axis=2)
                                           * array_ops.expand_dims(candidates_emb, axis=1), axis=-1)
        output_mask = math_ops.cast(array_ops.expand_dims(gen_math_ops.greater(candidates, 0), axis=1), dtypes.float32)
        cell_outputs = gen_math_ops.add(cell_outputs, (1. - output_mask) * LARGE_NEGATIVE)

        # Returning
        return cell_outputs, next_cell_state
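
The reshape-and-reduce block above can be checked with a small NumPy sketch (hypothetical sizes and a stand-in LARGE_NEGATIVE, not the project's API): the bias-padded cell output is dotted against every candidate embedding, and PAD candidates (id 0) are pushed to a very large negative logit.

import numpy as np

LARGE_NEGATIVE = -1e20                                               # stand-in value
batch, beam, out_sz, max_cand = 2, 3, 4, 6

cell_outputs = np.random.randn(batch, beam, out_sz)
candidates = np.random.randint(0, 10, size=(batch, max_cand))        # 0 == PAD
candidates_emb = np.random.randn(batch, max_cand, out_sz + 1)

# Bias dimension: append a constant 1. so the last column of candidates_emb acts as a bias
cell_outputs = np.concatenate([cell_outputs, np.ones((batch, beam, 1))], axis=-1)

# (batch, beam, 1, out_sz + 1) * (batch, 1, max_cand, out_sz + 1) -> (batch, beam, max_cand)
logits = (cell_outputs[:, :, None, :] * candidates_emb[:, None, :, :]).sum(axis=-1)

# Mask out PAD candidates with a large negative value
output_mask = (candidates > 0).astype(np.float64)[:, None, :]        # (batch, 1, max_cand)
logits = logits + (1. - output_mask) * LARGE_NEGATIVE
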
Example #3
def greedy():
    """ Selecting greedy """
    argmax_id = math_ops.cast(math_ops.argmax(cell_outputs, axis=-1), dtypes.int32)
    nb_candidate = array_ops.shape(candidate)[1]
    candidate_ids = math_ops.reduce_sum(
        array_ops.one_hot(argmax_id, nb_candidate, dtype=dtypes.int32) * candidate, axis=-1)
    with ops.control_dependencies([candidate_ids]):
        return array_ops.identity(candidate_ids)
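
A short NumPy sketch of the greedy branch (hypothetical values): argmax gives an index into the candidate list, and the one-hot multiply-and-sum maps that index back to the candidate's vocabulary id.

import numpy as np

cell_outputs = np.array([[0.1, 2.0, -1.0, 0.4],      # candidate logits, (batch, nb_candidate)
                         [1.5, 0.2, -0.3, 0.0]])
candidate = np.array([[12, 7, 3, 0],                 # vocabulary ids per candidate slot (0 == PAD)
                      [5, 9, 0, 0]])

argmax_id = cell_outputs.argmax(axis=-1)             # index of the best candidate per row
nb_candidate = candidate.shape[1]
candidate_ids = (np.eye(nb_candidate, dtype=np.int64)[argmax_id] * candidate).sum(axis=-1)
# candidate_ids == [7, 5]
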
Example #4
def sample():
    """ Sampling """
    logits = cell_outputs if self._softmax_temperature is None else cell_outputs / self._softmax_temperature
    sample_id_sampler = categorical.Categorical(logits=logits)
    sample_ids = sample_id_sampler.sample(seed=self._seed)
    nb_candidate = array_ops.shape(candidate)[1]
    reduce_op = math_ops.reduce_sum(
        array_ops.one_hot(sample_ids, nb_candidate, dtype=dtypes.int32) * candidate, axis=-1)
    with ops.control_dependencies([reduce_op]):
        return array_ops.identity(reduce_op)
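
Below is a plain NumPy sketch of the same idea (a hypothetical helper, not the project's API): temper the candidate logits, sample a candidate index from the resulting distribution, then map it back to a vocabulary id with the one-hot trick.

import numpy as np

def sample_candidate_ids(cell_outputs, candidate, softmax_temperature=None, rng=None):
    """ Samples one candidate per batch row and returns its vocabulary id """
    rng = rng or np.random.default_rng(0)
    logits = cell_outputs if softmax_temperature is None else cell_outputs / softmax_temperature
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    sample_ids = np.array([rng.choice(len(row), p=row) for row in probs])
    nb_candidate = candidate.shape[1]
    return (np.eye(nb_candidate, dtype=np.int64)[sample_ids] * candidate).sum(axis=-1)
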
Example #5
    def step(self, time, inputs, state, name=None):
        """ Performs a decoding step
            :param time: scalar `int32` tensor.
            :param inputs: A (structure of) input tensors.  (** This is a CandidateInputs tuple **)
            :param state: A (structure of) state tensors and TensorArrays.
            :param name: Name scope for any created operations.
            :return: (outputs, next_state, next_inputs, finished)
        """
        assert isinstance(inputs, CandidateInputs), 'The inputs must be of type "CandidateInputs"'
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            inputs, candidates, candidates_emb = inputs.inputs, inputs.candidates, inputs.candidates_emb
            cell_outputs, cell_state = self._cell(inputs, state)
            cell_state_output = cell_outputs  # Corresponds to cell_state.h (before output layer)
            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            # Adding a bias dimension, then computing candidate logits and masking PAD_IDs
            cell_outputs = array_ops.pad(cell_outputs, [(0, 0), (0, 1)],
                                         constant_values=1.)
            cell_outputs = math_ops.reduce_sum(cell_outputs[:, None, :] *
                                               candidates_emb,
                                               axis=-1)
            output_mask = math_ops.cast(gen_math_ops.greater(candidates, 0),
                                        dtypes.float32)
            cell_outputs = gen_math_ops.add(cell_outputs, (1. - output_mask) *
                                            LARGE_NEGATIVE)

            # Sampling and computing next inputs
            sample_ids = self._helper.sample(time=time,
                                             outputs=(cell_outputs,
                                                      candidates),
                                             state=cell_state)
            (finished, next_inputs,
             next_state) = self._helper.next_inputs(time=time,
                                                    outputs=cell_outputs,
                                                    state=cell_state,
                                                    sample_ids=sample_ids)
        if self.extract_state:
            outputs = BasicDecoderWithStateOutput(cell_outputs,
                                                  cell_state_output,
                                                  sample_ids)
        else:
            outputs = seq2seq.BasicDecoderOutput(cell_outputs, sample_ids)
        return outputs, next_state, next_inputs, finished
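
The pad with constant_values=1. above is a bias trick: appending a constant 1. to the cell output makes the last column of each candidate embedding act as a per-candidate bias, so the reduce_sum computes a dot product plus bias in one shot. A quick NumPy check (made-up sizes):

import numpy as np

batch, out_sz, max_cand = 2, 4, 3
h = np.random.randn(batch, out_sz)                               # cell output
cand_w = np.random.randn(batch, max_cand, out_sz)                # candidate weight vectors
cand_b = np.random.randn(batch, max_cand)                        # candidate biases
candidates_emb = np.concatenate([cand_w, cand_b[..., None]], axis=-1)

h_padded = np.concatenate([h, np.ones((batch, 1))], axis=-1)     # pad with 1. (bias dimension)
logits = (h_padded[:, None, :] * candidates_emb).sum(axis=-1)    # (batch, max_cand)

assert np.allclose(logits, (cand_w * h[:, None, :]).sum(axis=-1) + cand_b)
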
Example #6
    def next_inputs(self, time, inputs, beam_search_output, beam_search_state):
        """ Computes the inputs at the next time step given the beam outputs
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param beam_search_output: The output of the beam search step
            :param beam_search_state: The state after the beam search step
            :return: `(beam_search_output, next_inputs)`
            :type beam_search_output: beam_search_decoder.BeamSearchDecoderOutput
            :type beam_search_state: beam_search_decoder.BeamSearchDecoderState
        """
        next_time = time + 1
        all_finished = math_ops.reduce_all(next_time >= self._sequence_length)

        # Sampling
        next_word_ids = beam_search_output.predicted_ids
        candidates = inputs.candidates
        nb_candidates = array_ops.shape(candidates)[1]
        sample_ids = math_ops.reduce_sum(array_ops.one_hot(next_word_ids, nb_candidates, dtype=dtypes.int32)
                                         * array_ops.expand_dims(candidates, axis=1), axis=-1)

        def get_next_inputs():
            """ Retrieves the inputs for the next time step """
            inputs_next_step = sample_ids
            inputs_emb_next_step = self._input_layer(self._order_embedding_fn(inputs_next_step))
            candidate_next_step = self._candidate_tas.read(next_time)
            candidate_emb_next_step = self._candidate_embedding_fn(candidate_next_step)

            # Prevents this branch from executing eagerly
            with ops.control_dependencies([inputs_emb_next_step, candidate_next_step, candidate_emb_next_step]):
                return CandidateInputs(inputs=array_ops.identity(inputs_emb_next_step),
                                       candidates=array_ops.identity(candidate_next_step),
                                       candidates_emb=array_ops.identity(candidate_emb_next_step))

        # Getting next inputs
        next_inputs = control_flow_ops.cond(all_finished,
                                            true_fn=lambda: self._zero_inputs,
                                            false_fn=get_next_inputs)

        # Rewriting beam search output with the correct sample ids
        beam_search_output = beam_search_decoder.BeamSearchDecoderOutput(scores=beam_search_output.scores,
                                                                         predicted_ids=sample_ids,
                                                                         parent_ids=beam_search_output.parent_ids)

        # Returning
        return beam_search_output, next_inputs
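
The cond / control_dependencies / identity pattern used in these snippets reduces to the sketch below (TensorFlow 1.x graph mode assumed; the function and argument names are illustrative, not the project's API). The returned tensors are wrapped in tf.identity under a control dependency on the ops that produce them, which the source comments describe as preventing the branch from executing eagerly.

import tensorflow as tf

def next_inputs_or_zero(all_finished, zero_inputs, compute_next_inputs):
    """ Mirrors the lazy-branch pattern above: zero inputs once finished, fresh inputs otherwise """
    def get_next_inputs():
        next_inputs = compute_next_inputs()
        # Same trick as in the snippets: identity under a control dependency
        with tf.control_dependencies([next_inputs]):
            return tf.identity(next_inputs)
    return tf.cond(all_finished,
                   true_fn=lambda: zero_inputs,
                   false_fn=get_next_inputs)
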