def test_step(self):
        dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=nn_ops.log_softmax(
                array_ops.ones([self.batch_size, self.beam_width])),
            lengths=constant_op.constant(
                2,
                shape=[self.batch_size, self.beam_width],
                dtype=dtypes.int32),
            finished=array_ops.zeros([self.batch_size, self.beam_width],
                                     dtype=dtypes.bool))

        logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                          0.0001)
        logits_[0, 0, 2] = 1.9
        logits_[0, 0, 3] = 2.1
        logits_[0, 1, 3] = 3.1
        logits_[0, 1, 4] = 0.9
        logits_[1, 0, 1] = 0.5
        logits_[1, 1, 2] = 2.7
        logits_[1, 2, 2] = 10.0
        logits_[1, 2, 3] = 0.2
        logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight)

        with self.test_session() as sess:
            outputs_, next_state_, state_, log_probs_ = sess.run(
                [outputs, next_beam_state, beam_state, log_probs])

        np.testing.assert_array_equal(outputs_.predicted_ids,
                                      [[3, 3, 2], [2, 2, 1]])
        np.testing.assert_array_equal(outputs_.parent_ids,
                                      [[1, 0, 0], [2, 1, 0]])
        np.testing.assert_array_equal(next_state_.lengths,
                                      [[3, 3, 3], [3, 3, 3]])
        np.testing.assert_array_equal(
            next_state_.finished,
            [[False, False, False], [False, False, False]])

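        # Each expected beam score is the parent beam's previous log-prob
        # (gathered with the parent_ids checked above) plus the log-softmax
        # score of the token that beam chose.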
        expected_log_probs = []
        expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
        expected_log_probs.append(state_.log_probs[1][[2, 1, 0]])  # 0 --> 1
        expected_log_probs[0][0] += log_probs_[0, 1, 3]
        expected_log_probs[0][1] += log_probs_[0, 0, 3]
        expected_log_probs[0][2] += log_probs_[0, 0, 2]
        expected_log_probs[1][0] += log_probs_[1, 2, 2]
        expected_log_probs[1][1] += log_probs_[1, 1, 2]
        expected_log_probs[1][2] += log_probs_[1, 0, 1]
        np.testing.assert_array_equal(next_state_.log_probs,
                                      expected_log_probs)
Example No. 2
    def step(self, time, inputs, state, name=None):
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
            cell_state = state.cell_state
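            # The wrapped cell expects a flat batch, so inputs and state are
            # merged from [batch_size, beam_width, ...] to
            # [batch_size * beam_width, ...] before the cell call and split
            # back into per-beam shape afterwards.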
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
            cell_state = nest.map_structure(
                self._maybe_merge_batch_beams,
                cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
            next_cell_state = nest.map_structure(
                self._maybe_split_batch_beams,
                next_cell_state, self._cell.state_size)

            ### context vector
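            # Dot-product attention: score every encoder time step against each
            # beam's new cell state, softmax over time, and take the weighted
            # sum of encoder outputs as a per-beam context vector.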
            K = next_cell_state
            Q = self.encoder_ouputs
            V = self.encoder_ouputs
            outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))  # bxtxc bxcxbeam => bxtxbeam
            attens = tf.nn.softmax(outputs, axis=1) # bxtxbeam
            context_vec = tf.expand_dims(attens, 3)*tf.expand_dims(V, 2) # bxtxbeamx1 bxtx1xc => bxtxbeamxc
            context_vec = tf.reduce_sum(context_vec, axis=1)  #  bxtxbeamxc => bxbeamxc
            ### end context vector
            ### cell_outputs vector
            cell_outputs = array_ops.concat([cell_outputs, context_vec], -1)
            ### end cell_outputs vector

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

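            # _beam_search_step expands each live beam over the vocabulary,
            # applies the length penalty, and keeps the beam_width best
            # continuations per batch entry.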
            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = control_flow_ops.cond(
                math_ops.reduce_all(finished), lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids))

            ### next_inputs vector
            next_inputs = array_ops.concat([next_inputs, self.z], -1) # bxbeamx[e+c+c]=bx5x640
            # attens = tf.transpose(attens, [0,2,1]) # bxbeamxt
            # beam_search_output = BeamSearchDecoderOutput(attens, beam_search_output[1], beam_search_output[2], beam_search_output[3])
            ### next_inputs vector
        
        return (beam_search_output, beam_search_state, next_inputs, finished)
  def test_step(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=constant_op.constant(
            2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
        finished=array_ops.zeros(
            [self.batch_size, self.beam_width], dtype=dtypes.bool),
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[False, False, False], [False, False, False]])

    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[2, 1, 0]])  # 0 --> 1
    expected_log_probs[0][0] += log_probs_[0, 1, 3]
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 2, 2]
    expected_log_probs[1][1] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
Example No. 4
    def test_step_with_eos(self):
        dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=nn_ops.log_softmax(
                array_ops.ones([self.batch_size, self.beam_width])),
            lengths=ops.convert_to_tensor([[2, 1, 2], [2, 2, 1]],
                                          dtype=dtypes.int64),
            finished=ops.convert_to_tensor(
                [[False, True, False], [False, False, True]],
                dtype=dtypes.bool),
            accumulated_attention_probs=())

        logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                          0.0001)
        logits_[0, 0, 2] = 1.9
        logits_[0, 0, 3] = 2.1
        logits_[0, 1, 3] = 3.1
        logits_[0, 1, 4] = 0.9
        logits_[1, 0, 1] = 0.5
        logits_[1, 1, 2] = 5.7  # why does this not work when it's 2.7?
        logits_[1, 2, 2] = 1.0
        logits_[1, 2, 3] = 0.2
        logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight,
            coverage_penalty_weight=self.coverage_penalty_weight)

        with self.cached_session() as sess:
            outputs_, next_state_, state_, log_probs_ = sess.run(
                [outputs, next_beam_state, beam_state, log_probs])

        self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
        self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
        self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
        self.assertAllEqual(next_state_.finished,
                            [[True, False, False], [False, True, False]])

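        # Beams that were already finished keep their length and log-prob and
        # simply re-emit the end token; only live beams add the log-prob of
        # the token they picked.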
        expected_log_probs = []
        expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
        expected_log_probs.append(state_.log_probs[1][[1, 2, 0]])
        expected_log_probs[0][1] += log_probs_[0, 0, 3]
        expected_log_probs[0][2] += log_probs_[0, 0, 2]
        expected_log_probs[1][0] += log_probs_[1, 1, 2]
        expected_log_probs[1][2] += log_probs_[1, 0, 1]
        self.assertAllEqual(next_state_.log_probs, expected_log_probs)
    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.
        Args:
        time: scalar `int32` tensor.
        inputs: A (structure of) input tensors.
        state: A (structure of) state tensors and TensorArrays.
        name: Name scope for any created operations.
        Returns:
        `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with ops.name_scope(name, "BeamSearchDecoderStep",
                            (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]),
                inputs)
            cell_state = nest.map_structure(self._maybe_merge_batch_beams,
                                            cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]),
                cell_outputs)
            next_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                 next_cell_state,
                                                 self._cell.state_size)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
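            # Unlike the stock decoder, the next inputs concatenate
            # self.cnn_inputs onto the embedded predicted ids along the
            # feature axis.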
            next_inputs = control_flow_ops.cond(
                math_ops.reduce_all(finished), lambda: self._start_inputs,
                lambda: tf.concat(
                    [self._embedding_fn(sample_ids), self.cnn_inputs], 2))

        return (beam_search_output, beam_search_state, next_inputs, finished)
Example No. 6
  def test_step_with_eos(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=ops.convert_to_tensor(
            [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int32),
        finished=ops.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=dtypes.bool))

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 5.7  # why does this not work when it's 2.7?
    logits_[1, 2, 2] = 1.0
    logits_[1, 2, 3] = 0.2
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    np.testing.assert_array_equal(outputs_.predicted_ids, [[0, 3, 2], [2, 0,
                                                                       1]])
    np.testing.assert_array_equal(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    np.testing.assert_array_equal(next_state_.finished, [[True, False, False],
                                                         [False, True, False]])

    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[1, 2, 0]])
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
Example No. 7
    def step(self, time, inputs, state, name=None):
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with tf.name_scope(name, "BeamSearchDecoderStep",
                           (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]),
                inputs)
            cell_state = nest.map_structure(self._maybe_merge_batch_beams,
                                            cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]),
                cell_outputs)
            next_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                 next_cell_state,
                                                 self._cell.state_size)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = tf.cond(tf.reduce_all(finished),
                                  lambda: self._start_inputs,
                                  lambda: self._embedding_fn(sample_ids))

            next_inputs = tf.concat([next_inputs, self.z], -1)

        return (beam_search_output, beam_search_state, next_inputs, finished)
    def test_step(self):
        def get_probs():
            """this simulates the initialize method in BeamSearchDecoder."""
            log_prob_mask = array_ops.one_hot(array_ops.zeros(
                [self.batch_size], dtype=dtypes.int32),
                                              depth=self.beam_width,
                                              on_value=True,
                                              off_value=False,
                                              dtype=dtypes.bool)

            log_prob_zeros = array_ops.zeros(
                [self.batch_size, self.beam_width], dtype=dtypes.float32)
            log_prob_neg_inf = array_ops.ones(
                [self.batch_size, self.beam_width],
                dtype=dtypes.float32) * -np.Inf

            log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
                                        log_prob_neg_inf)
            return log_probs

        log_probs = get_probs()
        dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

        # pylint: disable=invalid-name
        _finished = array_ops.one_hot(array_ops.zeros([self.batch_size],
                                                      dtype=dtypes.int32),
                                      depth=self.beam_width,
                                      on_value=False,
                                      off_value=True,
                                      dtype=dtypes.bool)
        _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
        _lengths[:, 0] = 2
        _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

        beam_state = beam_search_decoder.BeamSearchDecoderState(
            cell_state=dummy_cell_state,
            log_probs=log_probs,
            lengths=_lengths,
            finished=_finished)

        logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                          0.0001)
        logits_[0, 0, 2] = 1.9
        logits_[0, 0, 3] = 2.1
        logits_[0, 1, 3] = 3.1
        logits_[0, 1, 4] = 0.9
        logits_[1, 0, 1] = 0.5
        logits_[1, 1, 2] = 2.7
        logits_[1, 2, 2] = 10.0
        logits_[1, 2, 3] = 0.2
        logits = constant_op.constant(logits_, dtype=dtypes.float32)
        log_probs = nn_ops.log_softmax(logits)

        outputs, next_beam_state = beam_search_decoder._beam_search_step(
            time=2,
            logits=logits,
            next_cell_state=dummy_cell_state,
            beam_state=beam_state,
            batch_size=ops.convert_to_tensor(self.batch_size),
            beam_width=self.beam_width,
            end_token=self.end_token,
            length_penalty_weight=self.length_penalty_weight)

        with self.test_session() as sess:
            outputs_, next_state_, _, _ = sess.run(
                [outputs, next_beam_state, beam_state, log_probs])

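        # Only beam 0 starts with a finite log-prob, so a single step can
        # produce at most vocab_size finite-score candidates; with beam_width
        # larger than vocab_size the trailing three beam slots stay at -inf
        # with length 0, which the assertions below check.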
        self.assertEqual(outputs_.predicted_ids[0, 0], 3)
        self.assertEqual(outputs_.predicted_ids[0, 1], 2)
        self.assertEqual(outputs_.predicted_ids[1, 0], 1)
        neg_inf = -np.Inf
        self.assertAllEqual(
            next_state_.log_probs[:, -3:],
            [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
        self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
        self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
        self.assertAllEqual(next_state_.lengths[:, -3:],
                            [[0, 0, 0], [0, 0, 0]])
    def step(self, time, inputs, state, name=None):
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with ops.name_scope(name, "BeamSearchDecoderStep",
                            (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]),
                inputs)
            cell_state = nest.map_structure(self._maybe_merge_batch_beams,
                                            cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]),
                cell_outputs)
            next_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                 next_cell_state,
                                                 self._cell.state_size)

            if self._output_layer is not None:
                # My modification
                if isinstance(self._output_layer,
                              taware_layer.JointDenseLayer):
                    reshaped_inputs = tf.reshape(
                        inputs, [-1, beam_width, inputs.shape[-1]])
                    if self._current_context is not None:
                        msg_attention, _ = tf.split(self._current_context,
                                                    num_or_size_splits=2,
                                                    axis=1)
                        msg_attention = tf.reshape(
                            msg_attention,
                            [-1, beam_width, msg_attention.shape[-1]])
                        cell_outputs = self._output_layer(
                            cell_outputs,
                            input=reshaped_inputs,
                            context=msg_attention)
                    else:
                        cell_outputs = self._output_layer(
                            cell_outputs, input=reshaped_inputs)
                else:
                    cell_outputs = self._output_layer(cell_outputs)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight,
                coverage_penalty_weight=0.0)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = control_flow_ops.cond(
                math_ops.reduce_all(finished), lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids))

            # My modification
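            # Cache this step's attention vector; on the next step it is split
            # in two and its first half (msg_attention) is passed as extra
            # context to the JointDenseLayer output layer above.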
            self._current_context = cell_state.attention

        return (beam_search_output, beam_search_state, next_inputs, finished)
Example No. 10
  def test_step(self):

    def get_probs():
      """this simulates the initialize method in BeamSearchDecoder."""
      log_prob_mask = array_ops.one_hot(
          array_ops.zeros([self.batch_size], dtype=dtypes.int32),
          depth=self.beam_width,
          on_value=True,
          off_value=False,
          dtype=dtypes.bool)

      log_prob_zeros = array_ops.zeros(
          [self.batch_size, self.beam_width], dtype=dtypes.float32)
      log_prob_neg_inf = array_ops.ones(
          [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf

      log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
                                  log_prob_neg_inf)
      return log_probs

    log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    _finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished,
        accumulated_attention_probs=())

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight,
        coverage_penalty_weight=self.coverage_penalty_weight)

    with self.cached_session() as sess:
      outputs_, next_state_, _, _ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.Inf
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
    self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
Example No. 11
    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight

        with ops.name_scope(name, "BeamSearchDecoderStep",
                            (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]),
                inputs)
            cell_state = nest.map_structure(self._maybe_merge_batch_beams,
                                            cell_state, self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)

            # finished = tf.Print(state.finished, [state.finished, 'finished', time], summarize=100)
            # not_finished = tf.Print(not_finished, [not_finished, 'not_finished', time], summarize=100)
            # cell_state.last_choice shape = [batch_size * beam_width]
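            # Look ahead in lookup_table: map each beam's last choice to its
            # next token, and flag still-live beams whose token after that is
            # the end token, i.e. beams that will finish on the next step.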
            next_choices = gen_array_ops.gather_v2(self.lookup_table,
                                                   cell_state.last_choice,
                                                   axis=0)
            not_finished = tf.not_equal(next_choices[:, 0], end_token)
            next_next_choices = gen_array_ops.gather_v2(self.lookup_table,
                                                        next_choices[:, 0],
                                                        axis=0)
            will_finish = tf.logical_and(
                not_finished, tf.equal(next_next_choices[:, 0], end_token))

            def move(will_finish, last_choice, cell_outputs):
                # cell_outputs = tf.Print(cell_outputs, [cell_outputs, 'cell_outputs', time], summarize=1000)
                # will_finish = tf.Print(will_finish, [will_finish, 'will_finish', time], summarize=100)
                attention_score = self._step_method(last_choice)
                attention_score = attention_score + cell_outputs
                # final = tf.Print(final, [final, 'finalll', time], summarize=1000)
                return tf.where(will_finish, attention_score, cell_outputs)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)
                # will_finish = tf.Print(will_finish, [will_finish, 'will_finish, beam_search', time], summarize=100)
                cell_outputs = tf.cond(
                    tf.reduce_any(will_finish),
                    false_fn=lambda: cell_outputs,
                    true_fn=lambda: move(will_finish, cell_state.last_choice,
                                         cell_outputs))

            if self.hie:
                cell_outputs = self._mask_outputs_by_lable(
                    cell_outputs, cell_state.last_choice)

                # cell_state.last_choice shape = [batch_size*beam_width,]

            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]),
                cell_outputs)

            next_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                 next_cell_state,
                                                 self._cell.state_size)

            beam_search_output, beam_search_state = _beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight)

            finished = beam_search_state.finished

            # replace the father ids
            sample_ids = beam_search_output.predicted_ids
            next_cell_state = beam_search_state.cell_state
            next_cell_state = next_cell_state._replace(last_choice=sample_ids)
            beam_search_state = beam_search_state._replace(
                cell_state=next_cell_state)

            # sample_ids shape = [batch_size, beam_width]
            next_inputs = control_flow_ops.cond(
                math_ops.reduce_all(finished), lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids))

        return (beam_search_output, beam_search_state, next_inputs, finished)
Example No. 12
    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.

        Args:
                time: scalar `int32` tensor.
                inputs: A (structure of) input tensors.
                state: A (structure of) state tensors and TensorArrays.
                name: Name scope for any created operations.

        Returns:
                `(outputs, next_state, next_inputs, finished)`.
        """
        batch_size = self._batch_size
        beam_width = self._beam_width
        end_token = self._end_token
        length_penalty_weight = self._length_penalty_weight
        coverage_penalty_weight = self._coverage_penalty_weight

        with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
            cell_state = state.cell_state
            inputs = nest.map_structure(
                lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
            cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
                                            self._cell.state_size)
            cell_outputs, next_cell_state = self._cell(inputs, cell_state)
            cell_outputs = nest.map_structure(
                lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
            next_cell_state = nest.map_structure(
                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            if self._shrink_vocab > 0:
                self._skip_tokens_decoding += list(
                    range(self._shrink_vocab, cell_outputs.get_shape()[2]))
                self._skip_tokens_decoding = sorted(
                    set(self._skip_tokens_decoding))
                # Never skip _END token, no matter what
                if self._raw_end_token in self._skip_tokens_decoding:
                    self._skip_tokens_decoding.remove(self._raw_end_token)

            # Assign least possible logit for given list of tokens to avoid those tokens while decoding
            if len(self._skip_tokens_decoding) > 0:

                token_num = cell_outputs.get_shape()[2]
                minimum_activation = tf.reduce_min(cell_outputs) - 1
                blacklist = tf.sparse_to_dense(
                    self._skip_tokens_decoding,
                    output_shape=[cell_outputs.get_shape()[2]],
                    sparse_values=0.0,
                    default_value=1.0)
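                # blacklist is 0.0 at the skipped token ids and 1.0 elsewhere,
                # so skipped tokens are pushed to (minimum logit - 1) and can
                # never win the beam-search top-k.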
                cell_outputs = tf.add(tf.multiply(
                    cell_outputs, blacklist), minimum_activation * (1 - blacklist))

            beam_search_output, beam_search_state = beam_search_decoder._beam_search_step(
                time=time,
                logits=cell_outputs,
                next_cell_state=next_cell_state,
                beam_state=state,
                batch_size=batch_size,
                beam_width=beam_width,
                end_token=end_token,
                length_penalty_weight=length_penalty_weight,
                coverage_penalty_weight=coverage_penalty_weight)

            finished = beam_search_state.finished
            sample_ids = beam_search_output.predicted_ids
            next_inputs = control_flow_ops.cond(
                math_ops.reduce_all(finished), lambda: self._start_inputs,
                lambda: self._embedding_fn(sample_ids))

        return (beam_search_output, beam_search_state, next_inputs, finished)