Esempio n. 1
0
    def test_step(self):
        """Single beam-search step in which no beam has finished yet.

        Every hypothesis starts with the same log-prob (log_softmax of ones)
        and length 2, so the new beams are ranked purely by the hand-picked
        logits below.
        """
        grid = [self.batch_size, self.beam_width]
        dummy_cell_state = tf.zeros(grid)
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=tf.nn.log_softmax(tf.ones(grid)),
            lengths=tf.constant(2, shape=grid, dtype=tf.int64),
            finished=tf.zeros(grid, dtype=tf.bool),
            accumulated_attention_probs=(),
        )

        # (batch, beam, token) -> logit; every other entry keeps 0.0001.
        peaks = {
            (0, 0, 2): 1.9,
            (0, 0, 3): 2.1,
            (0, 1, 3): 3.1,
            (0, 1, 4): 0.9,
            (1, 0, 1): 0.5,
            (1, 1, 2): 2.7,
            (1, 2, 2): 10.0,
            (1, 2, 3): 0.2,
        }
        logits_ = np.full(grid + [self.vocab_size], 0.0001)
        for position, value in peaks.items():
            logits_[position] = value
        logits = tf.convert_to_tensor(logits_, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=tf.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight,
            coverage_penalty_weight=self.coverage_penalty_weight,
        )

        with self.cached_session() as sess:
            outputs_, next_state_, state_, log_probs_ = sess.run(
                [outputs, next_beam_state, beam_state, log_probs])

        expected_ids = [[3, 3, 2], [2, 2, 1]]
        expected_parents = [[1, 0, 0], [2, 1, 0]]
        self.assertAllEqual(outputs_.predicted_ids, expected_ids)
        self.assertAllEqual(outputs_.parent_ids, expected_parents)
        self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
        self.assertAllEqual(next_state_.finished,
                            [[False, False, False], [False, False, False]])

        # Each new hypothesis scores its parent's previous log-prob plus the
        # step log-prob of the token it appended; list indexing gathers both
        # by (parent, token) pair.
        expected_log_probs = [
            state_.log_probs[row][parents] + log_probs_[row, parents, ids]
            for row, (parents, ids) in enumerate(
                zip(expected_parents, expected_ids))
        ]
        self.assertAllEqual(next_state_.log_probs, expected_log_probs)
Esempio n. 2
0
    def test_step_with_eos(self):
        """Beam step where some hypotheses have already emitted the end token.

        Beam 1 of batch 0 and beam 2 of batch 1 start out finished with
        length 1; their descendants must stay finished, keep length 1 and
        keep their accumulated log-prob unchanged.
        """
        grid = [self.batch_size, self.beam_width]
        dummy_cell_state = array_ops.zeros(grid)
        start_log_probs = nn_ops.log_softmax(array_ops.ones(grid))
        start_lengths = ops.convert_to_tensor([[2, 1, 2], [2, 2, 1]],
                                              dtype=dtypes.int64)
        start_finished = ops.convert_to_tensor(
            [[False, True, False], [False, False, True]],
            dtype=dtypes.bool)
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=start_log_probs,
            lengths=start_lengths,
            finished=start_finished,
            accumulated_attention_probs=())

        # (batch, beam, token) -> logit; every other entry keeps 0.0001.
        # The 5.7 (rather than 2.7) at (1, 1, 2) must make beam 1's extension
        # outscore the already-finished beam 2, which keeps log(1/3) ~ -1.10.
        # With 2.7 the extension's length-penalised score comes out at roughly
        # -1.13 (NOTE(review): assumes a ((5 + len) / 6) ** weight length
        # penalty -- confirm), so the finished beam would rank first and the
        # expected parent/predicted ids below would be swapped.
        peaks = (
            ((0, 0, 2), 1.9),
            ((0, 0, 3), 2.1),
            ((0, 1, 3), 3.1),
            ((0, 1, 4), 0.9),
            ((1, 0, 1), 0.5),
            ((1, 1, 2), 5.7),
            ((1, 2, 2), 1.0),
            ((1, 2, 3), 0.2),
        )
        logits_ = np.full(grid + [self.vocab_size], 0.0001)
        for position, value in peaks:
            logits_[position] = value
        logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight,
            coverage_penalty_weight=self.coverage_penalty_weight)

        with self.cached_session() as sess:
            outputs_, next_state_, state_, log_probs_ = sess.run(
                [outputs, next_beam_state, beam_state, log_probs])

        expected_parents = [[1, 0, 0], [1, 2, 0]]
        self.assertAllEqual(outputs_.parent_ids, expected_parents)
        self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
        self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
        self.assertAllEqual(next_state_.finished,
                            [[True, False, False], [False, True, False]])

        # Start every hypothesis from its parent's previous score (list
        # indexing returns a copy, so the in-place adds below are safe).
        expected_log_probs = [
            state_.log_probs[row][parents]
            for row, parents in enumerate(expected_parents)
        ]
        # Hypotheses extending a live beam gain that beam's step log-prob;
        # [0][0] and [1][1] descend from already-finished beams and keep
        # their score unchanged.
        extensions = {
            (0, 1): (0, 3),
            (0, 2): (0, 2),
            (1, 0): (1, 2),
            (1, 2): (0, 1),
        }
        for (row, col), (parent, token) in extensions.items():
            expected_log_probs[row][col] += log_probs_[row, parent, token]
        self.assertAllEqual(next_state_.log_probs, expected_log_probs)
Esempio n. 3
0
    def step(
        self,
        time: tf.Tensor,
        inputs: list[tf.Tensor],
        state: tf.Tensor,
        training: Optional[bool] = None,
        name: Optional[str] = None,
    ) -> list[tf.Tensor]:
        """Advance beam search by one decoding step.

        Runs the wrapped cell on ``inputs``, splits its outputs per beam,
        optionally projects them through the output layer, and delegates
        candidate selection to ``_beam_search_step``.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: The beam-search state. Its ``cell_state`` is fed to the
            cell, and the whole object is forwarded to ``_beam_search_step``
            as ``beam_state``.
          training: Python boolean. Indicates whether the layer should
              behave in training mode or in inference mode. Only relevant
              when `dropout` or `recurrent_dropout` is used.
          name: Name scope for any created operations; defaults to
            "BeamSearchDecoderStep".

        Returns:
          A list ``[outputs, next_state, next_inputs, finished]``, where
          ``finished`` is read from the new beam-search state.
        """
        with tf.name_scope(name or "BeamSearchDecoderStep"):
            raw_outputs, new_cell_state = self._cell(
                inputs, state.cell_state, training=training)

            def split_beams(tensor):
                # Every dimension after the first is passed as the shape
                # argument of the split helper.
                return self._split_batch_beams(tensor, tensor.shape[1:])

            logits = tf.nest.map_structure(split_beams, raw_outputs)
            if self._output_layer is not None:
                logits = self._output_layer(logits)

            step_output, new_state = _beam_search_step(
                time=time,
                logits=logits,
                next_cell_state=new_cell_state,
                beam_state=state,
                batch_size=self._batch_size,
                beam_width=self._beam_width,
                end_token=self._end_token,
                length_penalty_weight=self._length_penalty_weight,
                coverage_penalty_weight=self._coverage_penalty_weight,
                output_all_scores=self._output_all_scores,
            )

            finished = new_state.finished
            predicted_ids = step_output.predicted_ids
            merged_ids = self._merge_batch_beams(
                predicted_ids, s=predicted_ids.shape[2:])
            next_inputs = self._next_inputs(inputs, merged_ids)

        return [step_output, new_state, next_inputs, finished]
Esempio n. 4
0
    def test_step(self):
        """Beam step where only beam 0 of every batch entry is alive.

        Mirrors the state the decoder's initialize method produces (per the
        helper below): beam 0 has log-prob 0, is unfinished and has length 2;
        every other beam has log-prob -inf, is finished and has length 0.
        Expanding the single live beam leaves the trailing beams unreachable,
        so the test expects the last three beams to stay at -inf / length 0.
        NOTE(review): this presumes beam_width - vocab_size == 3 for the
        fixture (e.g. beam_width=8, vocab_size=5) -- confirm.
        """
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        neg_inf = -np.inf

        def get_probs():
            """Simulate initialize: log-prob 0 for beam 0, -inf elsewhere."""
            log_prob_mask = tf.one_hot(tf.zeros([self.batch_size],
                                                dtype=tf.int32),
                                       depth=self.beam_width,
                                       on_value=True,
                                       off_value=False,
                                       dtype=tf.bool)
            log_prob_zeros = tf.zeros([self.batch_size, self.beam_width],
                                      dtype=tf.float32)
            log_prob_neg_inf = tf.ones([self.batch_size, self.beam_width],
                                       dtype=tf.float32) * neg_inf
            return tf.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

        initial_log_probs = get_probs()
        dummy_cell_state = tf.zeros([self.batch_size, self.beam_width])

        # Beam 0 is live (not finished, length 2); the others are finished
        # padding with length 0.
        finished = tf.one_hot(tf.zeros([self.batch_size], dtype=tf.int32),
                              depth=self.beam_width,
                              on_value=False,
                              off_value=True,
                              dtype=tf.bool)
        lengths_np = np.zeros([self.batch_size, self.beam_width],
                              dtype=np.int64)
        lengths_np[:, 0] = 2
        lengths = tf.constant(lengths_np, dtype=tf.int64)

        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=initial_log_probs,
            lengths=lengths,
            finished=finished,
            accumulated_attention_probs=())

        logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                          0.0001)
        logits_[0, 0, 2] = 1.9
        logits_[0, 0, 3] = 2.1
        logits_[0, 1, 3] = 3.1
        logits_[0, 1, 4] = 0.9
        logits_[1, 0, 1] = 0.5
        logits_[1, 1, 2] = 2.7
        logits_[1, 2, 2] = 10.0
        logits_[1, 2, 3] = 0.2
        logits = tf.constant(logits_, dtype=tf.float32)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=tf.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight,
            coverage_penalty_weight=self.coverage_penalty_weight)

        with self.cached_session() as sess:
            outputs_, next_state_ = sess.run([outputs, next_beam_state])

        # Only beam 0's logits matter: its best tokens are 3 then 2 (batch 0)
        # and 1 (batch 1).
        self.assertEqual(outputs_.predicted_ids[0, 0], 3)
        self.assertEqual(outputs_.predicted_ids[0, 1], 2)
        self.assertEqual(outputs_.predicted_ids[1, 0], 1)
        self.assertAllEqual(
            next_state_.log_probs[:, -3:],
            [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
        self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
        self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
        self.assertAllEqual(next_state_.lengths[:, -3:],
                            [[0, 0, 0], [0, 0, 0]])
Esempio n. 5
0
def test_large_beam_step():
    """Beam step with beam_width (8) > vocab_size (5) and one live beam.

    Only beam 0 of each batch entry starts alive (log-prob 0, length 2); the
    other beams start finished with log-prob -inf and length 0. Expanding the
    single live beam yields only vocab_size candidates, so the test expects
    the last three beams to remain at -inf with length 0.
    """
    batch_size = 2
    beam_width = 8
    vocab_size = 5
    end_token = 0
    length_penalty_weight = 0.6
    coverage_penalty_weight = 0.0
    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    neg_inf = -np.inf

    def get_probs():
        """Simulate initialize: log-prob 0 for beam 0, -inf elsewhere."""
        log_prob_mask = tf.one_hot(
            tf.zeros([batch_size], dtype=tf.int32),
            depth=beam_width,
            on_value=True,
            off_value=False,
            dtype=tf.bool,
        )
        log_prob_zeros = tf.zeros([batch_size, beam_width], dtype=tf.float32)
        log_prob_neg_inf = tf.ones([batch_size, beam_width],
                                   dtype=tf.float32) * neg_inf
        return tf.where(log_prob_mask, log_prob_zeros, log_prob_neg_inf)

    initial_log_probs = get_probs()
    dummy_cell_state = tf.zeros([batch_size, beam_width])

    # Beam 0 is live (not finished, length 2); the rest are finished padding.
    finished = tf.one_hot(
        tf.zeros([batch_size], dtype=tf.int32),
        depth=beam_width,
        on_value=False,
        off_value=True,
        dtype=tf.bool,
    )
    lengths_np = np.zeros([batch_size, beam_width], dtype=np.int64)
    lengths_np[:, 0] = 2
    lengths = tf.constant(lengths_np, dtype=tf.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=initial_log_probs,
        lengths=lengths,
        finished=finished,
        accumulated_attention_probs=(),
    )

    logits_ = np.full([batch_size, beam_width, vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = tf.constant(logits_, dtype=tf.float32)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(batch_size),
        beam_width=beam_width,
        end_token=end_token,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
    )

    # Only beam 0's logits matter: best tokens 3 then 2 (batch 0), 1 (batch 1).
    predicted_ids = outputs.predicted_ids.numpy()
    assert predicted_ids[0, 0] == 3
    assert predicted_ids[0, 1] == 2
    assert predicted_ids[1, 0] == 1

    next_log_probs = next_beam_state.log_probs.numpy()
    next_lengths = next_beam_state.lengths.numpy()
    np.testing.assert_equal(next_log_probs[:, -3:],
                            np.full([batch_size, 3], neg_inf))
    assert np.all(next_log_probs[:, :-3] > neg_inf)
    assert np.all(next_lengths[:, :-3] > 0)
    np.testing.assert_equal(next_lengths[:, -3:],
                            np.zeros([batch_size, 3], dtype=np.int64))
Esempio n. 6
0
def test_step_with_eos():
    """Beam step where some hypotheses have already emitted the end token.

    Beam 1 of batch 0 and beam 2 of batch 1 start out finished with length 1.
    Their descendants must stay finished, keep length 1 and keep their
    accumulated log-prob unchanged; live descendants add the step log-prob
    of the token they appended to their parent's score.
    """
    batch_size = 2
    beam_width = 3
    vocab_size = 5
    end_token = 0
    length_penalty_weight = 0.6
    coverage_penalty_weight = 0.0

    dummy_cell_state = tf.zeros([batch_size, beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=tf.nn.log_softmax(tf.ones([batch_size, beam_width])),
        lengths=tf.convert_to_tensor([[2, 1, 2], [2, 2, 1]], dtype=tf.int64),
        finished=tf.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=tf.bool),
        accumulated_attention_probs=(),
    )

    logits_ = np.full([batch_size, beam_width, vocab_size], 0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    # 5.7 (rather than 2.7) makes beam 1's extension outscore the finished
    # beam 2, which keeps log(1/3) ~ -1.10 at length 1. With 2.7 the
    # extension's length-penalised score comes out at roughly -1.13
    # (NOTE(review): assumes a ((5 + len) / 6) ** weight length penalty --
    # confirm), so the finished beam would rank first and the expected
    # parent/predicted ids below would be swapped.
    logits_[1, 1, 2] = 5.7
    logits_[1, 2, 2] = 1.0
    logits_[1, 2, 3] = 0.2
    logits = tf.convert_to_tensor(logits_, dtype=tf.float32)
    # Materialise once so the expected scores below use plain NumPy values
    # instead of mixing EagerTensor scalars into NumPy in-place additions.
    step_log_probs = tf.nn.log_softmax(logits).numpy()

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=tf.convert_to_tensor(batch_size),
        beam_width=beam_width,
        end_token=end_token,
        length_penalty_weight=length_penalty_weight,
        coverage_penalty_weight=coverage_penalty_weight,
    )

    expected_parent_ids = [[1, 0, 0], [1, 2, 0]]
    np.testing.assert_equal(outputs.parent_ids.numpy(),
                            np.asanyarray(expected_parent_ids))
    np.testing.assert_equal(outputs.predicted_ids.numpy(),
                            np.asanyarray([[0, 3, 2], [2, 0, 1]]))
    np.testing.assert_equal(next_beam_state.lengths.numpy(),
                            np.asanyarray([[1, 3, 3], [3, 1, 3]]))
    np.testing.assert_equal(
        next_beam_state.finished.numpy(),
        np.asanyarray([[True, False, False], [False, True, False]]),
    )

    # Each new hypothesis inherits the score of its *parent* beam, so the
    # previous log-probs are gathered by parent id first (as the graph-mode
    # variant of this test does). Without the gather the check would only
    # pass because the initial log-probs happen to be uniform. List indexing
    # also returns a copy, so the in-place additions below never write into
    # an array that may share memory with a tensor.
    initial_log_probs = beam_state.log_probs.numpy()
    expected_log_probs = [
        initial_log_probs[row][parents]
        for row, parents in enumerate(expected_parent_ids)
    ]
    # [0][0] and [1][1] descend from finished beams: their score is unchanged.
    expected_log_probs[0][1] += step_log_probs[0, 0, 3]
    expected_log_probs[0][2] += step_log_probs[0, 0, 2]
    expected_log_probs[1][0] += step_log_probs[1, 1, 2]
    expected_log_probs[1][2] += step_log_probs[1, 0, 1]
    np.testing.assert_equal(next_beam_state.log_probs.numpy(),
                            np.asanyarray(expected_log_probs))