Ejemplo n.º 1
0
    def get_initial_loop_state(self) -> BeamSearchLoopState:
        """Construct the initial loop state for the beam search decoder.

        During the construction, the body function of the underlying decoder
        is called once to retrieve the initial log probabilities of the first
        token.

        The values are initialized as follows:

        - ``search_state``
            - ``logprob_sum`` - For each sentence in batch, logprob sum of the
              first hypothesis in the beam is set to zero while the others are
              set to negative infinity.
            - ``prev_logprobs`` - This is the log-softmax over the logits from
              the initial decoder step.
            - ``lengths`` - All zeros.
            - ``finished`` - All false.

        - ``search_results``
            - ``scores`` - A (1, batch, beam)-sized tensor of zeros.
            - ``token_ids`` - A (1, batch, beam)-sized tensor filled with
              indices of decoder-specific initial input symbols (usually start
              symbol IDs).

        - ``decoder_loop_state`` - The loop state of the underlying
            autoregressive decoder, as returned from the initial call to the
            body function.

        Returns:
            A populated ``BeamSearchLoopState`` structure.
        """
        # Get the initial loop state of the underlying decoder. Then, expand
        # the tensors from the loop state to (batch * beam) and inject them
        # back into the decoder loop state.

        dec_init_ls = self.parent_decoder.get_initial_loop_state()

        feedables = tf.contrib.framework.nest.map_structure(
            self.expand_to_beam, dec_init_ls.feedables)
        histories = tf.contrib.framework.nest.map_structure(
            lambda x: self.expand_to_beam(x, dim=1), dec_init_ls.histories)

        dec_init_ls = dec_init_ls._replace(feedables=feedables,
                                           histories=histories)

        # Call the decoder body function with the expanded loop state to get
        # the log probabilities of the possible first tokens.

        decoder_body = self.parent_decoder.get_body(False)
        dec_next_ls = decoder_body(*dec_init_ls)

        # Construct the initial loop state of the beam search decoder. To allow
        # ensembling, the values are replaced with placeholders with a default
        # value. Despite this being necessary only for variables that grow in
        # time, the placeholder replacement is done on the whole structures, as
        # you can see below.

        search_state = SearchState(
            logprob_sum=tf.tile(tf.expand_dims([0.0] + [-INF] *
                                               (self.beam_size - 1), 0),
                                [self.batch_size, 1],
                                name="bs_logprob_sum"),
            prev_logprobs=tf.reshape(
                tf.nn.log_softmax(dec_next_ls.feedables.prev_logits),
                [self.batch_size, self.beam_size,
                 len(self.vocabulary)]),
            lengths=tf.zeros([self.batch_size, self.beam_size],
                             dtype=tf.int32,
                             name="bs_lengths"),
            finished=tf.zeros([self.batch_size, self.beam_size],
                              dtype=tf.bool))

        # We add the input_symbol to token_ids during search_results
        # initialization for simpler beam_body implementation

        search_results = SearchResults(
            scores=tf.zeros(shape=[1, self.batch_size, self.beam_size],
                            dtype=tf.float32,
                            name="beam_scores"),
            token_ids=tf.reshape(feedables.input_symbol,
                                 [1, self.batch_size, self.beam_size],
                                 name="beam_tokens"),
            parent_ids=tf.zeros(shape=[1, self.batch_size, self.beam_size],
                                dtype=tf.int32,
                                name="parent_ids"))

        # In structures that contain tensors that grow in time, we replace
        # tensors with placeholders with loosened shape constraints in the time
        # dimension.

        dec_next_ls = tf.contrib.framework.nest.map_structure(
            lambda x: tf.placeholder_with_default(
                x, get_state_shape_invariants(x)), dec_next_ls)

        search_results = tf.contrib.framework.nest.map_structure(
            lambda x: tf.placeholder_with_default(
                x, get_state_shape_invariants(x)), search_results)

        return BeamSearchLoopState(search_state=search_state,
                                   search_results=search_results,
                                   decoder_loop_state=dec_next_ls)
Ejemplo n.º 2
0
    def get_initial_loop_state(self) -> BeamSearchLoopState:
        """Build the initial loop state for the beam search decoder.

        The body function of the underlying decoder is executed once so that
        the log probabilities of the very first token are available.

        Initialization summary:

        - ``search_state``
            - ``logprob_sum`` - Per sentence, zero for the first hypothesis
              in the beam, negative infinity for all other hypotheses.
            - ``prev_logprobs`` - Log-softmax of the logits produced by the
              initial decoder step.
            - ``lengths`` - All zeros.
            - ``finished`` - All false.

        - ``search_results``
            - ``scores`` - A (batch, beam)-sized tensor of zeros.
            - ``token_ids`` - A (1, batch, beam)-sized tensor holding the
              decoder-specific initial input symbol indices (typically the
              start symbol IDs).

        - ``decoder_loop_state`` - The underlying autoregressive decoder's
            loop state, as produced by the initial body-function call.

        Returns:
            A populated ``BeamSearchLoopState`` structure.
        """
        nest_map = tf.contrib.framework.nest.map_structure

        # Take the underlying decoder's initial loop state, tile its tensors
        # to (batch * beam), and put the tiled versions back in.
        initial_dls = self.parent_decoder.get_initial_loop_state()

        tiled_feedables = nest_map(
            self.expand_to_beam, initial_dls.feedables)
        tiled_histories = nest_map(
            lambda t: self.expand_to_beam(t, dim=1), initial_dls.histories)

        tiled_constants = tf.constant(0)
        if initial_dls.constants:
            tiled_constants = nest_map(
                self.expand_to_beam, initial_dls.constants)

        initial_dls = initial_dls._replace(
            feedables=tiled_feedables,
            histories=tiled_histories,
            constants=tiled_constants)

        # Run one step of the decoder body on the expanded loop state; this
        # yields the logits over the possible first tokens.
        body_fn = self.parent_decoder.get_body(False)
        next_dls = body_fn(*initial_dls)

        # Assemble the beam search loop state. For ensembling support, the
        # values are later swapped for placeholders with defaults; although
        # only time-growing variables strictly need it, whole structures are
        # wrapped below.

        first_logits = next_dls.histories.logits[-1, :, :]
        search_state = SearchState(
            logprob_sum=tf.tile(
                tf.expand_dims([0.0] + [-INF] * (self.beam_size - 1), 0),
                [self.batch_size, 1],
                name="bs_logprob_sum"),
            prev_logprobs=tf.reshape(
                tf.nn.log_softmax(first_logits),
                [self.batch_size, self.beam_size, len(self.vocabulary)]),
            lengths=tf.zeros(
                [self.batch_size, self.beam_size], dtype=tf.int32,
                name="bs_lengths"),
            finished=tf.zeros(
                [self.batch_size, self.beam_size], dtype=tf.bool))

        # Seeding token_ids with the initial input symbol here keeps the
        # beam_body implementation simpler.
        first_symbols = next_dls.histories.output_symbols[-1, :]
        search_results = SearchResults(
            scores=tf.zeros(
                shape=[self.batch_size, self.beam_size],
                dtype=tf.float32,
                name="beam_scores"),
            token_ids=tf.reshape(
                first_symbols,
                [1, self.batch_size, self.beam_size],
                name="beam_tokens"))

        # Replace tensors in time-growing structures with placeholders whose
        # shape constraint is loosened along the time dimension.
        def loosen(t):
            return tf.placeholder_with_default(
                t, get_state_shape_invariants(t))

        next_dls = nest_map(loosen, next_dls)
        search_results = nest_map(loosen, search_results)

        return BeamSearchLoopState(
            search_state=search_state,
            search_results=search_results,
            decoder_loop_state=next_dls)