Ejemplo n.º 1
0
    def test_step_with_eos(self):
        """Check one beam search step when one beam is already finished.

        The starting state gives every beam the same log probability and
        marks beam 1 as finished.  The logits are near-uniform except for
        token 2, which dominates every row (1.1 for beam 0, 1.0 for the
        others).  After the step the test verifies the predicted ids, the
        parent beams, the lengths, the finished flags and the accumulated
        log probabilities.
        """
        width = self.config.beam_width

        start_state = beam_search.BeamSearchState(
            log_probs=tf.nn.log_softmax(tf.ones(width)),
            lengths=tf.convert_to_tensor([2, 1, 2], dtype=tf.int32),
            finished=tf.constant([False, True, False], dtype=tf.bool))

        # Near-uniform logits with token 2 boosted in each of the three beams.
        raw_logits = np.full([width, self.config.vocab_size], 0.0001)
        for beam, boosted in enumerate((1.1, 1.0, 1.0)):
            raw_logits[beam, 2] = boosted
        logits = tf.convert_to_tensor(raw_logits, dtype=tf.float32)
        token_log_probs = tf.nn.log_softmax(logits)

        step_outputs, stepped_state = beam_search.beam_search_step(
            time_=2,
            logits=logits,
            beam_state=start_state,
            config=self.config)

        fetches = [step_outputs, stepped_state, start_state, token_log_probs]
        with self.test_session() as sess:
            out_val, stepped_val, start_val, token_lp_val = sess.run(fetches)

        # Slot 0 continues the finished beam (parent 1) and emits id 0
        # (presumably the EOS token -- confirm against the config).
        np.testing.assert_array_equal(out_val.predicted_ids, [0, 2, 2])
        np.testing.assert_array_equal(out_val.beam_parent_ids, [1, 0, 2])
        np.testing.assert_array_equal(stepped_val.lengths, [1, 3, 3])
        np.testing.assert_array_equal(
            stepped_val.finished, [True, False, False])

        # Start from each slot's parent log prob; only the two live beams
        # add the log prob of the token they picked.  Slot 0 is left as is.
        want_log_probs = start_val.log_probs[out_val.beam_parent_ids]
        want_log_probs[1] += token_lp_val[0, 2]
        want_log_probs[2] += token_lp_val[2, 2]
        np.testing.assert_array_equal(stepped_val.log_probs, want_log_probs)
Ejemplo n.º 2
0
    def test_step_with_eos(self):
        """Check a beam search step when one beam's history ends in id 0.

        The starting state is at time 2 with three beams whose histories are
        padded with -1.  The second history ends in 0 (presumably the EOS
        token -- confirm against the config).  Token 2 receives the dominant
        logit in every row.  The test verifies the updated id histories and
        that the first resulting beam keeps the first beam's previous log
        probability.
        """
        width = self.config.beam_width
        histories = [[1, 2, -1, -1, -1],
                     [3, 0, -1, -1, -1],
                     [5, 6, -1, -1, -1]]

        start_state = beam_search.BeamState(
            time=tf.constant(2),
            log_probs=tf.nn.log_softmax(tf.ones(width)),
            scores=tf.nn.log_softmax(tf.ones(width)),
            predicted_ids=tf.convert_to_tensor(histories),
            beam_parent_ids=tf.zeros(width))

        # Dense logits: 1.0 at column 2 of every row, 0.0001 elsewhere.
        logits = tf.sparse_to_dense(
            [[0, 2], [1, 2], [2, 2]],
            output_shape=[width, self.config.vocab_size],
            sparse_values=[1.0, 1.0, 1.0],
            default_value=0.0001)

        stepped_state = beam_search.beam_search_step(
            logits=logits, beam_state=start_state, config=self.config)

        with self.test_session() as sess:
            stepped = sess.run(stepped_state)
            prior_log_probs = sess.run(start_state.log_probs)

        want_ids = np.array([[3, 0, 0, -1, -1],
                             [1, 2, 2, -1, -1],
                             [5, 6, 2, -1, -1]])
        np.testing.assert_array_equal(stepped.predicted_ids, want_ids)
        np.testing.assert_array_equal(stepped.log_probs[0],
                                      prior_log_probs[0])
Ejemplo n.º 3
0
  def step(self, time, inputs, state, name=None):
    """Run one decoding step and pick the next tokens with beam search.

    Args:
      time: Current step index; used to slice `inputs` along axis 1.
      inputs: Rank-3 tensor that is sliced along its second axis.
      state: Pair `(enc_output, beam_state)`.
      name: Not used by this method.

    Returns:
      A tuple `(outputs, (enc_output, beam_state), next_inputs, finished)`
      where `outputs` is a `BeamDecoderOutput` whose `logits` field is an
      all-zero `[beam_width, vocab_size]` tensor.
    """
    width = self.config.beam_width
    depth = inputs.get_shape().as_list()[-1]

    # Everything decoded so far, and the untouched tail after the slot
    # that this step is going to fill.
    history = inputs[:, 0:time + 1, :]
    tail = inputs[:, time + 2:, :]
    history_with_pos = self.add_position_embedding(history, time)

    enc_output, prev_beam_state = state
    logits = self.infer_conv_block(enc_output, history_with_pos)

    search_out, new_beam_state = beam_search.beam_search_step(
        time_=time,
        logits=logits,
        beam_state=prev_beam_state,
        config=self.config)

    finished, chosen = self.next_inputs(sample_ids=search_out.predicted_ids)
    chosen = tf.reshape(chosen, [width, 1, depth])

    # Rebuild the full-length input: history + new step + remaining tail.
    grown = tf.concat([history, chosen], axis=1)
    next_inputs = tf.concat([grown, tail], axis=1)
    next_inputs.set_shape([width, self.params['max_decode_length'], depth])

    outputs = BeamDecoderOutput(
        logits=tf.zeros([width, self.config.vocab_size]),
        predicted_ids=search_out.predicted_ids,
        log_probs=new_beam_state.log_probs,
        scores=search_out.scores,
        beam_parent_ids=search_out.beam_parent_ids)
    return outputs, (enc_output, new_beam_state), next_inputs, finished
    def step(self, time, inputs, state, name=None):
        """Decode a single step and select continuations via beam search.

        Args:
            time: Index of the current step; slices `inputs` on axis 1.
            inputs: Rank-3 tensor indexed as `[:, steps, :]`.
            state: Tuple of `(enc_output, beam_state)`.
            name: Ignored.

        Returns:
            `(outputs, (enc_output, beam_state), next_inputs, finished)`.
            `outputs` is a `BeamDecoderOutput`; its `logits` entry is a
            zero tensor of shape `[beam_width, vocab_size]`.
        """
        enc_output, old_beam_state = state
        beam_width = self.config.beam_width
        feature_dim = inputs.get_shape().as_list()[-1]

        # Split the inputs into the decoded prefix and the trailing part
        # that follows the position being written in this step.
        prefix = inputs[:, 0:time + 1, :]
        suffix = inputs[:, time + 2:, :]
        prefix_pos = self.add_position_embedding(prefix, time)

        logits = self.infer_conv_block(enc_output, prefix_pos)

        bs_result, beam_state = beam_search.beam_search_step(
            time_=time,
            logits=logits,
            beam_state=old_beam_state,
            config=self.config)

        finished, step_inputs = self.next_inputs(
            sample_ids=bs_result.predicted_ids)
        step_inputs = tf.reshape(step_inputs, [beam_width, 1, feature_dim])

        # Stitch prefix, the freshly chosen step and the suffix together.
        extended = tf.concat([prefix, step_inputs], axis=1)
        next_inputs = tf.concat([extended, suffix], axis=1)
        full_shape = [
            beam_width, self.params['max_decode_length'], feature_dim
        ]
        next_inputs.set_shape(full_shape)

        outputs = BeamDecoderOutput(
            logits=tf.zeros([beam_width, self.config.vocab_size]),
            predicted_ids=bs_result.predicted_ids,
            log_probs=beam_state.log_probs,
            scores=bs_result.scores,
            beam_parent_ids=bs_result.beam_parent_ids)
        return outputs, (enc_output, beam_state), next_inputs, finished
Ejemplo n.º 5
0
    def step(self, time_, inputs, state, name=None):
        """Step the wrapped decoder once and reorder results by beam search.

        Args:
            time_: Current step, forwarded to the wrapped decoder, to the
                beam search step and to the helper.
            inputs: Inputs forwarded to the wrapped decoder.
            state: Tuple of `(decoder_state, beam_state)`.
            name: Not used by this method.

        Returns:
            `(outputs, next_state, next_inputs, finished)` where `outputs`
            is a `BeamDecoderOutput` and `next_state` is the state returned
            by the wrapped decoder's helper.
        """
        inner_state, prev_beam_state = state

        # Let the wrapped decoder produce logits for this step.
        inner_output, inner_state, _, _ = self.decoder.step(
            time_, inputs, inner_state)

        # Let beam search choose tokens and parent beams.
        search_out, new_beam_state = beam_search.beam_search_step(
            time_=time_,
            logits=inner_output.logits,
            beam_state=prev_beam_state,
            config=self.config)

        def pick_parents(tensor):
            """Gather the entries of `tensor` selected by the parent ids."""
            return tf.gather(tensor, search_out.beam_parent_ids)

        # Reorder decoder state and output to follow the surviving beams.
        # For an attention wrapper state only the cell state is regathered.
        if isinstance(inner_state, AttentionWrapperState):
            regathered_cell = nest.map_structure(pick_parents,
                                                 inner_state.cell_state)
            inner_state = inner_state.clone(cell_state=regathered_cell)
        else:
            inner_state = nest.map_structure(pick_parents, inner_state)
        inner_output = nest.map_structure(pick_parents, inner_output)

        next_state = (inner_state, new_beam_state)

        outputs = BeamDecoderOutput(
            logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
            predicted_ids=search_out.predicted_ids,
            log_probs=new_beam_state.log_probs,
            scores=search_out.scores,
            beam_parent_ids=search_out.beam_parent_ids,
            original_outputs=inner_output)

        finished, next_inputs, next_state = self.decoder.helper.next_inputs(
            time=time_,
            outputs=inner_output,
            state=next_state,
            sample_ids=search_out.predicted_ids)
        next_inputs.set_shape([self.batch_size, None])

        return (outputs, next_state, next_inputs, finished)
    def step(self, time_, cell_output, cell_state, loop_state):
        """Run one step of the wrapped decoder under beam search.

        When `cell_output` is None this is treated as the initial call: the
        cell state is tiled `beam_width` times, an initial beam state is
        created, and `self.output_shapes()` is used as the step outputs.
        Otherwise the previous beam state is unwrapped from `loop_state`,
        one beam search step is applied to the wrapped decoder's logits and
        the next cell state is gathered by the parent beam ids.

        Args:
            time_: Current step, forwarded to the wrapped decoder.
            cell_output: Previous cell output, or None on the initial call.
            cell_state: Cell state forwarded to the wrapped decoder.
            loop_state: Loop state; on non-initial calls it carries the
                beam state together with the wrapped decoder's loop state.

        Returns:
            A `DecoderStepOutput` with `outputs`, `next_cell_state` and
            `next_loop_state` filled in.
        """
        width = self.config.beam_width

        if cell_output is not None:
            prev_beam_state, inner_loop_state = self._unwrap_loop_state(
                loop_state)

            # Step the wrapped decoder with its own loop state.
            inner = self.decoder.step(time_, cell_output, cell_state,
                                      inner_loop_state)

            # Advance the beam search using the decoder's logits.
            beam_state = beam_search.beam_search_step(
                logits=inner.outputs.logits,
                beam_state=prev_beam_state,
                config=self.config)
            beam_state.predicted_ids.set_shape(
                [None, self.decoder.max_decode_length])
            next_loop_state = self._wrap_loop_state(beam_state,
                                                    inner.next_loop_state)

            # Column `time_ - 1` of the id history is reported as this
            # step's prediction.
            step_ids = tf.to_int64(beam_state.predicted_ids[:, time_ - 1])
            outputs = BeamDecoderOutput(
                logits=tf.zeros([width, self.config.vocab_size]),
                predicted_ids=step_ids,
                log_probs=beam_state.log_probs,
                scores=beam_state.scores,
                beam_parent_ids=beam_state.beam_parent_ids,
                original_outputs=inner.outputs)

            # Beam search reorders beams, so the cell states must follow.
            next_cell_state = beam_search.nest_map(
                inner.next_cell_state,
                lambda x: tf.gather(x, beam_state.beam_parent_ids))
        else:
            # NOTE(review): this value is assigned but not read again in
            # this branch; kept so the graph construction is unchanged.
            cell_output = tf.zeros([width, self.cell.output_size])

            # All beams start out equal, so replicate the cell state once
            # per beam.
            next_cell_state = beam_search.nest_map(
                cell_state, lambda x: tf.tile(x, [width, 1]))

            # Step the wrapped decoder with no previous cell output.
            inner = self.decoder.step(time_, None, cell_state, loop_state)

            beam_state = beam_search.create_initial_beam_state(
                config=self.config, max_time=self.decoder.max_decode_length)
            next_loop_state = self._wrap_loop_state(beam_state,
                                                    inner.next_loop_state)

            outputs = self.output_shapes()

        return DecoderStepOutput(outputs=outputs,
                                 next_cell_state=next_cell_state,
                                 next_loop_state=next_loop_state)