def get_initial_loop_state(self) -> BeamSearchLoopState:
    """Construct the initial loop state for the beam search decoder.

    During the construction, the body function of the underlying decoder
    is called once to retrieve the initial log probabilities of the
    first token.

    The values are initialized as follows:

    - ``search_state``
        - ``logprob_sum`` - For each sentence in the batch, the log
          probability sum of the first hypothesis in the beam is set to
          zero while the others are set to negative infinity.
        - ``prev_logprobs`` - The log softmax over the logits from the
          initial decoder step.
        - ``lengths`` - All zeros.
        - ``finished`` - All false.

    - ``search_results``
        - ``scores`` - A (1, batch, beam)-sized tensor of zeros.
        - ``token_ids`` - A (1, batch, beam)-sized tensor filled with
          indices of decoder-specific initial input symbols (usually
          start symbol IDs).
        - ``parent_ids`` - A (1, batch, beam)-sized tensor of zeros.

    - ``decoder_loop_state`` - The loop state of the underlying
      autoregressive decoder, as returned from the initial call to the
      body function.

    Returns:
        A populated ``BeamSearchLoopState`` structure.
    """
    # Get the initial loop state of the underlying decoder. Then, expand
    # the tensors from the loop state to (batch * beam) and inject them
    # back into the decoder loop state.
    dec_init_ls = self.parent_decoder.get_initial_loop_state()

    feedables = tf.contrib.framework.nest.map_structure(
        self.expand_to_beam, dec_init_ls.feedables)
    histories = tf.contrib.framework.nest.map_structure(
        lambda x: self.expand_to_beam(x, dim=1), dec_init_ls.histories)

    # constants = tf.constant(0)
    # if dec_init_ls.constants:
    #     constants = tf.contrib.framework.nest.map_structure(
    #         self.expand_to_beam, dec_init_ls.constants)

    dec_init_ls = dec_init_ls._replace(feedables=feedables,
                                       histories=histories)
    # constants=constants)

    # Call the decoder body function with the expanded loop state to get
    # the log probabilities of the possible first tokens.
    decoder_body = self.parent_decoder.get_body(False)
    dec_next_ls = decoder_body(*dec_init_ls)

    # Construct the initial loop state of the beam search decoder. To
    # allow ensembling, the values are replaced with placeholders with a
    # default value. Although this is necessary only for the variables
    # that grow in time, the placeholder replacement is done on the
    # whole structures, as you can see below.
    search_state = SearchState(
        logprob_sum=tf.tile(
            tf.expand_dims([0.0] + [-INF] * (self.beam_size - 1), 0),
            [self.batch_size, 1],
            name="bs_logprob_sum"),
        prev_logprobs=tf.reshape(
            tf.nn.log_softmax(dec_next_ls.feedables.prev_logits),
            [self.batch_size, self.beam_size, len(self.vocabulary)]),
        lengths=tf.zeros(
            [self.batch_size, self.beam_size], dtype=tf.int32,
            name="bs_lengths"),
        finished=tf.zeros(
            [self.batch_size, self.beam_size], dtype=tf.bool))

    # We add the input_symbol to token_ids during the search_results
    # initialization to keep the beam_body implementation simpler.
    search_results = SearchResults(
        scores=tf.zeros(
            shape=[1, self.batch_size, self.beam_size],
            dtype=tf.float32,
            name="beam_scores"),
        token_ids=tf.reshape(
            feedables.input_symbol,
            [1, self.batch_size, self.beam_size],
            name="beam_tokens"),
        parent_ids=tf.zeros(
            shape=[1, self.batch_size, self.beam_size],
            dtype=tf.int32,
            name="parent_ids"))

    # In structures that contain tensors that grow in time, we replace
    # the tensors with placeholders with loosened shape constraints in
    # the time dimension.
    dec_next_ls = tf.contrib.framework.nest.map_structure(
        lambda x: tf.placeholder_with_default(
            x, get_state_shape_invariants(x)),
        dec_next_ls)

    search_results = tf.contrib.framework.nest.map_structure(
        lambda x: tf.placeholder_with_default(
            x, get_state_shape_invariants(x)),
        search_results)

    return BeamSearchLoopState(search_state=search_state,
                               search_results=search_results,
                               decoder_loop_state=dec_next_ls)
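# For reference, the expansion above relies on a helper that tiles each
# loop state tensor ``beam_size`` times along a given dimension, so that
# a (batch, ...)-shaped tensor becomes (batch * beam, ...)-shaped. The
# sketch below only illustrates that contract; the name
# ``_expand_to_beam_sketch`` and the exact tiling strategy are
# assumptions, not the class's actual ``expand_to_beam`` implementation.
def _expand_to_beam_sketch(tensor: tf.Tensor, beam_size: int,
                           dim: int = 0) -> tf.Tensor:
    """Tile ``tensor`` ``beam_size`` times along ``dim`` (illustrative)."""
    if tensor.shape.ndims == 0:
        return tensor  # Scalars carry no batch dimension to expand.
    # Insert a fresh axis right after ``dim``, tile it ``beam_size``
    # times, and merge the two axes back together, so that dimension
    # ``dim`` grows from ``d`` to ``d * beam_size``.
    expanded = tf.expand_dims(tensor, dim + 1)
    multiples = [1] * expanded.shape.ndims
    multiples[dim + 1] = beam_size
    tiled = tf.tile(expanded, multiples)
    shape = tf.shape(tensor)
    new_shape = tf.concat(
        [shape[:dim], [shape[dim] * beam_size], shape[dim + 1:]], axis=0)
    return tf.reshape(tiled, new_shape)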
def get_initial_loop_state(self) -> BeamSearchLoopState:
    """Construct the initial loop state for the beam search decoder.

    During the construction, the body function of the underlying decoder
    is called once to retrieve the initial log probabilities of the
    first token.

    The values are initialized as follows:

    - ``search_state``
        - ``logprob_sum`` - For each sentence in the batch, the log
          probability sum of the first hypothesis in the beam is set to
          zero while the others are set to negative infinity.
        - ``prev_logprobs`` - The log softmax over the logits from the
          initial decoder step.
        - ``lengths`` - All zeros.
        - ``finished`` - All false.

    - ``search_results``
        - ``scores`` - A (batch, beam)-sized tensor of zeros.
        - ``token_ids`` - A (1, batch, beam)-sized tensor filled with
          indices of decoder-specific initial input symbols (usually
          start symbol IDs).

    - ``decoder_loop_state`` - The loop state of the underlying
      autoregressive decoder, as returned from the initial call to the
      body function.

    Returns:
        A populated ``BeamSearchLoopState`` structure.
    """
    # Get the initial loop state of the underlying decoder. Then, expand
    # the tensors from the loop state to (batch * beam) and inject them
    # back into the decoder loop state.
    dec_init_ls = self.parent_decoder.get_initial_loop_state()

    feedables = tf.contrib.framework.nest.map_structure(
        self.expand_to_beam, dec_init_ls.feedables)
    histories = tf.contrib.framework.nest.map_structure(
        lambda x: self.expand_to_beam(x, dim=1), dec_init_ls.histories)

    constants = tf.constant(0)
    if dec_init_ls.constants:
        constants = tf.contrib.framework.nest.map_structure(
            self.expand_to_beam, dec_init_ls.constants)

    dec_init_ls = dec_init_ls._replace(
        feedables=feedables,
        histories=histories,
        constants=constants)

    # Call the decoder body function with the expanded loop state to get
    # the log probabilities of the possible first tokens.
    decoder_body = self.parent_decoder.get_body(False)
    dec_next_ls = decoder_body(*dec_init_ls)

    # Construct the initial loop state of the beam search decoder. To
    # allow ensembling, the values are replaced with placeholders with a
    # default value. Although this is necessary only for the variables
    # that grow in time, the placeholder replacement is done on the
    # whole structures, as you can see below.

    # Take the logits recorded for the last (and so far only) step in
    # the decoder histories, i.e. the initial decoder step.
    logits = dec_next_ls.histories.logits[-1, :, :]
    search_state = SearchState(
        logprob_sum=tf.tile(
            tf.expand_dims([0.0] + [-INF] * (self.beam_size - 1), 0),
            [self.batch_size, 1],
            name="bs_logprob_sum"),
        prev_logprobs=tf.reshape(
            tf.nn.log_softmax(logits),
            [self.batch_size, self.beam_size, len(self.vocabulary)]),
        lengths=tf.zeros(
            [self.batch_size, self.beam_size], dtype=tf.int32,
            name="bs_lengths"),
        finished=tf.zeros(
            [self.batch_size, self.beam_size], dtype=tf.bool))

    # We add the initial input symbols to token_ids during the
    # search_results initialization to keep the beam_body implementation
    # simpler.
    input_symbols = dec_next_ls.histories.output_symbols[-1, :]
    search_results = SearchResults(
        scores=tf.zeros(
            shape=[self.batch_size, self.beam_size],
            dtype=tf.float32,
            name="beam_scores"),
        token_ids=tf.reshape(
            input_symbols,
            [1, self.batch_size, self.beam_size],
            name="beam_tokens"))

    # In structures that contain tensors that grow in time, we replace
    # the tensors with placeholders with loosened shape constraints in
    # the time dimension.
    dec_next_ls = tf.contrib.framework.nest.map_structure(
        lambda x: tf.placeholder_with_default(
            x, get_state_shape_invariants(x)),
        dec_next_ls)

    search_results = tf.contrib.framework.nest.map_structure(
        lambda x: tf.placeholder_with_default(
            x, get_state_shape_invariants(x)),
        search_results)

    return BeamSearchLoopState(
        search_state=search_state,
        search_results=search_results,
        decoder_loop_state=dec_next_ls)
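# Both versions above loosen the shapes of the growing structures via
# ``tf.placeholder_with_default``, which keeps the graph runnable with
# the default tensors while letting an ensembling session (or the next
# loop iteration) feed tensors of a different length. The sketch below
# is an illustrative guess at what a shape-invariant helper like
# ``get_state_shape_invariants`` provides, not its actual definition: it
# keeps the static rank but marks every dimension except the last
# (feature or vocabulary) dimension as unknown.
def _state_shape_invariants_sketch(tensor: tf.Tensor) -> tf.TensorShape:
    """Keep the rank but loosen all but the last dimension (illustrative)."""
    invariant = tensor.shape.as_list()
    invariant[:-1] = [None] * (len(invariant) - 1)
    return tf.TensorShape(invariant)

# Usage sketch: a (time=1, batch, beam) history tensor keeps its rank,
# but its time and batch dimensions become unconstrained:
#     history = tf.zeros([1, 4, 8])
#     relaxed = tf.placeholder_with_default(
#         history, _state_shape_invariants_sketch(history))
#     relaxed.shape  # => (?, ?, 8)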