Beispiel #1
0
    def sample(self, time, outputs, state, name=None):
        """ Picks the id to feed at the next time step (or -1 for teacher forcing).

            The strategy is chosen in-graph from ``self._decoder_type``:
              - TRAINING_DECODER: -1 for every entry along dim 0 of ``outputs`` (teacher forcing)
              - GREEDY_DECODER: index of the largest value of ``outputs`` along the last axis (int32)
              - SAMPLE_DECODER: an index drawn from a Categorical built on ``outputs``, divided by
                ``self._softmax_temperature`` when that attribute is not None
            Any other decoder type falls back to the TRAINING_DECODER behaviour.

            :param time: The current time step (only used as a name-scope value).
            :param outputs: The cell outputs (used as logits) for this time step.
            :param state: The current decoder state (only used as a name-scope value).
            :param name: Optional name for the name scope.
            :return: The ids selected by the active branch.
        """
        with ops.name_scope(name, 'CustomHelperSample', [time, outputs, state]):

            def _teacher_forcing_ids():
                """ Teacher forcing: -1 for every entry along dim 0 of outputs """
                batch_size = array_ops.shape(outputs)[0]
                minus_ones = gen_array_ops.fill([batch_size], -1)
                with ops.control_dependencies([minus_ones]):
                    return array_ops.identity(minus_ones)

            def _greedy_ids():
                """ Greedy: position of the largest value along the last axis """
                best_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
                with ops.control_dependencies([best_ids]):
                    return array_ops.identity(best_ids)

            def _sampled_ids():
                """ Sampling: one draw from a Categorical over the (optionally tempered) logits """
                if self._softmax_temperature is None:
                    logits = outputs
                else:
                    logits = outputs / self._softmax_temperature
                drawn_ids = categorical.Categorical(logits=logits).sample(seed=self._seed)
                with ops.control_dependencies([drawn_ids]):
                    return array_ops.identity(drawn_ids)

            # Order matters for control_flow_ops.case: predicates are checked in this order
            branch_table = [(TRAINING_DECODER, _teacher_forcing_ids),
                            (GREEDY_DECODER, _greedy_ids),
                            (SAMPLE_DECODER, _sampled_ids)]
            predicate_pairs = [(gen_math_ops.equal(self._decoder_type, decoder_type), branch_fn)
                               for decoder_type, branch_fn in branch_table]
            return control_flow_ops.case(predicate_pairs, default=_teacher_forcing_ids)
Beispiel #2
0
    def initialize(self, name=None):
        """ Builds the (finished, next_inputs) pair used to start decoding.

            - ``finished`` is ``0 == self._sequence_length``, evaluated element-wise.
            - If every element of ``finished`` is True, ``next_inputs`` is ``self._zero_inputs``.
            - Otherwise ``next_inputs`` is a ``CandidateInputs`` holding:
                * the first inputs, passed through ``self._order_embedding_fn`` and then
                  ``self._input_layer``. These come from element 0 of ``self._input_tas`` for
                  TRAINING_DECODER (and unknown types), or from ``self._start_inputs`` (GO_ID)
                  for GREEDY_DECODER / SAMPLE_DECODER;
                * the candidates read at index 0 of ``self._candidate_tas``;
                * those candidates passed through ``self._candidate_embedding_fn``.

            :param name: Optional name for the name scope.
            :return: A tuple (finished, next_inputs).
        """
        with ops.name_scope(name, 'CustomHelperInitialize'):
            finished = gen_math_ops.equal(0, self._sequence_length)
            all_finished = math_ops.reduce_all(finished)
            first_candidates = self._candidate_tas.read(0)

            def _embedded_first_target():
                """ Teacher forcing: embeds element 0 of self._input_tas """
                embedded = self._order_embedding_fn(self._input_tas.read(0))
                with ops.control_dependencies([embedded]):
                    return array_ops.identity(embedded)

            def _embedded_go_token():
                """ Free-running decoding: embeds the GO_ID start inputs """
                embedded = self._order_embedding_fn(self._start_inputs)
                with ops.control_dependencies([embedded]):
                    return array_ops.identity(embedded)

            # Predicates are checked in this order; unknown types use teacher forcing
            branch_table = [(TRAINING_DECODER, _embedded_first_target),
                            (GREEDY_DECODER, _embedded_go_token),
                            (SAMPLE_DECODER, _embedded_go_token)]
            first_inputs = control_flow_ops.case(
                [(gen_math_ops.equal(self._decoder_type, decoder_type), branch_fn)
                 for decoder_type, branch_fn in branch_table],
                default=_embedded_first_target)

            def _when_all_finished():
                """ Nothing to decode: feed the zero inputs """
                return self._zero_inputs

            def _when_some_remaining():
                """ Feed the projected first inputs together with the first candidates """
                projected_inputs = self._input_layer(first_inputs)
                candidate_embeddings = self._candidate_embedding_fn(first_candidates)
                return CandidateInputs(inputs=projected_inputs,
                                       candidates=first_candidates,
                                       candidates_emb=candidate_embeddings)

            next_inputs = control_flow_ops.cond(all_finished, _when_all_finished, _when_some_remaining)
            return finished, next_inputs
Beispiel #3
0
            def get_next_inputs():
                """ Builds the CandidateInputs fed at ``next_time``.

                    Raw inputs come from ``self._input_tas`` at ``next_time`` for TRAINING_DECODER
                    (and unknown decoder types), or from ``sample_ids`` for GREEDY_DECODER and
                    SAMPLE_DECODER. They go through ``self._order_embedding_fn`` and then
                    ``self._input_layer``. The candidates at ``next_time`` are read from
                    ``self._candidate_tas`` and embedded with ``self._candidate_embedding_fn``.
                    ``next_time`` and ``sample_ids`` come from the enclosing scope.
                """
                def _read_target_inputs():
                    """ Teacher forcing: reads the stored inputs at next_time """
                    target_inputs = self._input_tas.read(next_time)
                    with ops.control_dependencies([target_inputs]):
                        return array_ops.identity(target_inputs)

                def _reuse_sample_ids():
                    """ Greedy / sampling: feeds back sample_ids from the enclosing scope """
                    return sample_ids

                # Predicates are checked in this order; unknown types use teacher forcing
                branch_table = [(TRAINING_DECODER, _read_target_inputs),
                                (GREEDY_DECODER, _reuse_sample_ids),
                                (SAMPLE_DECODER, _reuse_sample_ids)]
                raw_inputs = control_flow_ops.case(
                    [(gen_math_ops.equal(self._decoder_type, decoder_type), branch_fn)
                     for decoder_type, branch_fn in branch_table],
                    default=_read_target_inputs)

                embedded_inputs = self._order_embedding_fn(raw_inputs)
                projected_inputs = self._input_layer(embedded_inputs)
                next_candidates = self._candidate_tas.read(next_time)
                next_candidates_emb = self._candidate_embedding_fn(next_candidates)

                # Prevents this branch from executing eagerly
                with ops.control_dependencies([projected_inputs, next_candidates, next_candidates_emb]):
                    inputs_out = array_ops.identity(projected_inputs)
                    candidates_out = array_ops.identity(next_candidates)
                    candidates_emb_out = array_ops.identity(next_candidates_emb)
                    return CandidateInputs(inputs=inputs_out,
                                           candidates=candidates_out,
                                           candidates_emb=candidates_emb_out)
Beispiel #4
0
            def get_next_inputs():
                """ Retrieves the inputs for the next time step

                    Reads ``next_time`` and ``sample_ids`` from the enclosing scope and returns a
                    ``MaskedInputs`` with:
                      - inputs: the ids for ``next_time`` passed through ``self._embedding_fn`` and
                        then ``self._input_layer``. The ids come from ``self._input_tas`` for
                        TRAINING_DECODER (also the default), or are ``sample_ids`` for
                        GREEDY_DECODER / SAMPLE_DECODER.
                      - mask: a (b, VOC) dense tensor. Each row is the ``next_time`` slice of
                        ``self._mask`` selected by that row's id (see the shape notes below), with
                        every entry capped at 1.
                """
                def get_training_inputs():
                    """ Selecting training inputs (teacher forcing: stored inputs at next_time) """
                    read_op = self._input_tas.read(next_time)
                    # Same control-dependency + identity pattern as the final return of this function
                    with ops.control_dependencies([read_op]):
                        return array_ops.identity(read_op)

                def get_sample_inputs():
                    """ Selecting greedy/sample inputs (sample_ids from the enclosing scope) """
                    return sample_ids

                # Raw ids for next_time, chosen in-graph from self._decoder_type;
                # unknown decoder types fall back to the training (teacher-forcing) inputs
                inputs_next_step = control_flow_ops.case(
                    [(gen_math_ops.equal(self._decoder_type, TRAINING_DECODER),
                      get_training_inputs),
                     (gen_math_ops.equal(self._decoder_type,
                                         GREEDY_DECODER), get_sample_inputs),
                     (gen_math_ops.equal(self._decoder_type,
                                         SAMPLE_DECODER), get_sample_inputs)],
                    default=get_training_inputs)
                # Embed the ids, then apply the input layer
                inputs_emb_next_step = self._input_layer(
                    self._embedding_fn(inputs_next_step))

                # Applying mask
                # inputs_one_hot:   (b, 1, VOC, 1)
                # mask_t:           (b, 1, VOC, VOC)
                # next_mask:        (b, VOC)        -- DenseTensor
                # One-hot encode the ids over the vocabulary, then insert singleton axes 1 and 3
                inputs_one_hot = array_ops.one_hot(inputs_next_step,
                                                   self.vocab_size)[:, None, :,
                                                                    None]
                # NOTE(review): _slice_mask is defined elsewhere; presumably it keeps only time step
                # next_time of self._mask (as a size-1 axis), honouring time_major — confirm
                mask_t = _slice_mask(self._mask, [-1, next_time, -1, -1],
                                     time_major=self._time_major)
                # The product feeds sparse_reduce_sum, so it (and hence mask_t) is presumably a
                # SparseTensor — keep the operand order unless verified. Summing over axes 1 and 2
                # keeps, for each row, the mask row at the position selected by the one-hot id.
                next_mask = sparse_ops.sparse_reduce_sum(inputs_one_hot *
                                                         mask_t,
                                                         axis=[1, 2])
                # Cap every entry at 1 (presumably to keep a 0/1 mask if summed entries overlap)
                next_mask = gen_math_ops.minimum(next_mask, 1.)
                # Restore the static (None, vocab_size) shape lost through the sparse reduction
                next_mask.set_shape([None, self.vocab_size])

                # Prevents this branch from executing eagerly
                with ops.control_dependencies(
                    [inputs_emb_next_step, next_mask]):
                    return MaskedInputs(
                        inputs=array_ops.identity(inputs_emb_next_step),
                        mask=array_ops.identity(next_mask))
Beispiel #5
0
    def sample(self, time, outputs, state, name=None):
        """ Samples the id for the next time step (or -1 for teacher forcing)

            Note: outputs is a tuple of (cell_outputs, candidate)

            The strategy is chosen in-graph from ``self._decoder_type``:
              - TRAINING_DECODER (and unknown types): -1 for every entry along dim 0 of
                ``cell_outputs``.
              - GREEDY_DECODER: the position of the largest value of ``cell_outputs`` (last axis)
                selects the returned id from ``candidate``.
              - SAMPLE_DECODER: a position drawn from a Categorical over ``cell_outputs`` (divided
                by ``self._softmax_temperature`` when it is not None) selects the id from
                ``candidate``.

            :param time: The current time step (only used as a name-scope value).
            :param outputs: A tuple (cell_outputs, candidate).
            :param state: The current decoder state (only used as a name-scope value).
            :param name: Optional name for the name scope.
            :return: The selected ids (int32).
        """
        cell_outputs, candidate = outputs

        with ops.name_scope(name, 'CustomHelperSample', [time, outputs, state]):

            def _lookup_candidate(positions):
                """ Returns, for each row, ``candidate[row, positions[row]]`` via a one-hot mask.

                    A position outside [0, shape(candidate)[1]) yields an all-zero one-hot row,
                    so the result is 0 for that row. Assumes ``candidate`` is int32 (it is
                    multiplied by an int32 one-hot tensor).
                """
                # Called from inside each branch so the ops stay in that branch
                nb_candidate = array_ops.shape(candidate)[1]
                position_mask = array_ops.one_hot(positions, nb_candidate, dtype=dtypes.int32)
                return math_ops.reduce_sum(position_mask * candidate, axis=-1)

            def _teacher_forcing():
                """ Selecting training / teacher forcing """
                fill_op = gen_array_ops.fill([array_ops.shape(cell_outputs)[0]], -1)
                with ops.control_dependencies([fill_op]):
                    return array_ops.identity(fill_op)

            def _greedy():
                """ Selecting greedy """
                # output_type=int32 directly (no int64 + cast), consistent with the plain sample helper
                argmax_id = math_ops.argmax(cell_outputs, axis=-1, output_type=dtypes.int32)
                candidate_ids = _lookup_candidate(argmax_id)
                with ops.control_dependencies([candidate_ids]):
                    return array_ops.identity(candidate_ids)

            def _sample():
                """ Sampling """
                if self._softmax_temperature is None:
                    logits = cell_outputs
                else:
                    logits = cell_outputs / self._softmax_temperature
                sample_ids = categorical.Categorical(logits=logits).sample(seed=self._seed)
                candidate_ids = _lookup_candidate(sample_ids)
                with ops.control_dependencies([candidate_ids]):
                    return array_ops.identity(candidate_ids)

            # Predicates are checked in order; unknown decoder types fall back to teacher forcing
            return control_flow_ops.case([(gen_math_ops.equal(self._decoder_type, TRAINING_DECODER), _teacher_forcing),
                                          (gen_math_ops.equal(self._decoder_type, GREEDY_DECODER), _greedy),
                                          (gen_math_ops.equal(self._decoder_type, SAMPLE_DECODER), _sample)],
                                         default=_teacher_forcing)