def body(i, inputs_tuple, all_cache_key, all_cache_value, past_length, initial_id):
    """Body of the top-k top-p while-loop decoder.

    Args:
        i (tf.Tensor): loop iterator (an int)
        inputs_tuple (tuple): model inputs, ordered as in self.input_name_list
        all_cache_key (tf.Tensor): cached attention keys (K)
        all_cache_value (tf.Tensor): cached attention values (V)
        past_length (tf.Tensor, 1 x batch_size): tracks positional ids
        initial_id (tf.Tensor): keeps track of the ids concatenated in each
            iteration; this is the main output (the decoded ids)

    Returns:
        List of tensors: the updated loop variables
    """
    inputs = {}
    for k in range(len(self.input_name_list)):
        inputs[self.input_name_list[k]] = inputs_tuple[k]

    inputs["all_cache_key"] = all_cache_key
    inputs["all_cache_value"] = all_cache_value
    inputs["past_length"] = past_length

    model_outputs = self.model(inputs)
    model_logits = model_outputs["last_token_logits"]

    if self.top_k > 0:
        model_logits = top_k_logits(model_logits, k=self.top_k)
    if self.top_p > 0:
        model_logits = top_p_logits(model_logits, p=self.top_p)

    if self.do_sample:
        prediction_ids = tf.random.categorical(model_logits, num_samples=1)
        input_ids = tf.cast(prediction_ids, tf.int32)
    else:
        prediction_ids = tf.argmax(model_logits, axis=1)
        input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)

    # Rebuild the ordered inputs tuple for the next iteration
    inputs_tuple = [None] * len(self.input_name_list)
    for index, name in self.input_name_map.items():
        if name == "input_ids":
            inputs_tuple[index] = input_ids
            continue
        if name == "input_type_ids":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
            continue
        if name == "input_mask":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
            continue
    # Convert to tuple
    inputs_tuple = tuple(inputs_tuple)

    return [
        i + 1,
        inputs_tuple,
        model_outputs["all_cache_key"],
        model_outputs["all_cache_value"],
        model_outputs["past_length"],
        tf.concat([initial_id, input_ids], axis=1),
    ]
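# ---------------------------------------------------------------------------
# `top_k_logits` and `top_p_logits` are defined elsewhere in this module. The
# standalone sketch below shows the standard filtering they are assumed to
# perform (the helper names, and the -1e10 masking value, are illustrative):
# everything outside the top-k ids, or outside the smallest nucleus whose
# cumulative probability reaches p, is pushed to a large negative logit so
# that tf.random.categorical never samples it.
import tensorflow as tf


def _sketch_top_k_logits(logits, k):
    # Keep the k highest logits per row; mask the rest with -1e10.
    values, _ = tf.math.top_k(logits, k=k)
    min_values = values[:, -1, tf.newaxis]
    return tf.where(logits < min_values, tf.ones_like(logits) * -1e10, logits)


def _sketch_top_p_logits(logits, p):
    # Nucleus filtering: sort, keep the smallest prefix whose cumulative
    # probability covers p (at least one token), mask everything below it.
    sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
    cumulative = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    cutoff_index = tf.maximum(
        tf.reduce_sum(tf.cast(cumulative <= p, tf.int32), axis=-1) - 1, 0)
    cutoff = tf.gather(sorted_logits, cutoff_index, batch_dims=1)
    return tf.where(logits < cutoff[:, tf.newaxis],
                    tf.ones_like(logits) * -1e10, logits)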
def call_top_k_top_p(inputs):
    """The main function to perform top-k top-p (nucleus) decoding.

    Args:
        inputs (dict): dict of tf.Tensors (model inputs)
    """
    input_ids_orig = inputs["input_ids"]
    batch_size = tf.shape(inputs["input_ids"])[0]
    max_sequence_length = tf.shape(inputs["input_ids"])[1]

    if self.max_iterations is None:
        iterations = tf.squeeze(inputs["iterations"])
    else:
        iterations = self.max_iterations

    # Repeat each input `num_return_sequences` times
    model_inputs = {}
    for input_key, input_value in inputs.items():
        if input_key == "iterations":
            continue
        model_inputs[input_key] = tf.repeat(
            input_value, [self.num_return_sequences], axis=0)

    # Updated batch size
    batch_size_updated = tf.shape(model_inputs["input_ids"])[0]

    # Pre-initialize additional inputs
    zero_entry = tf.zeros((
        self.num_hidden_layers,
        batch_size_updated,
        self.num_attention_heads,
        max_sequence_length,
        self.attention_state,
    ))
    all_cache_key = zero_entry
    all_cache_value = zero_entry
    # past_length keeps track of positional ids
    past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)

    # Iterator to keep track of the loop
    i = tf.constant([[0]])
    initial_id = tf.ones(shape=(batch_size_updated, 1), dtype=tf.int32)

    # Add remaining model inputs
    model_inputs["all_cache_key"] = all_cache_key
    model_inputs["all_cache_value"] = all_cache_value
    model_inputs["past_length"] = past_length
    if "input_type_ids" in self.input_name_list:
        model_inputs["input_type_ids"] = (
            tf.ones_like(model_inputs["input_ids"]) * self.input_type_ids)
    if "input_mask" in self.input_name_list:
        model_inputs["input_mask"] = (
            tf.ones_like(model_inputs["input_ids"]) * self.input_mask_ids)

    # First pass through the model
    model_outputs = self.model(model_inputs)
    model_logits = model_outputs["last_token_logits"]

    if self.top_k > 0:
        model_logits = top_k_logits(model_logits, k=self.top_k)
    if self.top_p > 0:
        model_logits = top_p_logits(model_logits, p=self.top_p)

    if self.do_sample:
        prediction_ids = tf.random.categorical(model_logits, num_samples=1)
        input_ids = tf.cast(prediction_ids, tf.int32)
    else:
        prediction_ids = tf.argmax(model_logits, axis=1)
        input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)

    inputs_tuple = [None] * len(self.input_name_list)
    input_shapes_tuple = [tf.TensorShape([None, None])] * len(self.input_name_list)
    for index, name in self.input_name_map.items():
        if name == "input_ids":
            inputs_tuple[index] = input_ids
            continue
        if name == "input_type_ids":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
            continue
        if name == "input_mask":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
            continue
    inputs_tuple = tuple(inputs_tuple)
    input_shapes_tuple = tuple(input_shapes_tuple)

    # Concatenate the first predicted ids
    initial_id = tf.concat([initial_id, input_ids], axis=1)

    # On step 0, zero out cache positions that correspond to -1 padding
    masks = tf.cast(tf.not_equal(model_inputs["input_ids"], -1), tf.float32)
    masks = tf.reshape(
        masks,
        (1, batch_size_updated, 1, tf.shape(model_inputs["input_ids"])[1], 1),
    )
    all_cache_key = model_outputs["all_cache_key"]
    all_cache_value = model_outputs["all_cache_value"]
    all_cache_key = all_cache_key * masks
    all_cache_value = all_cache_value * masks

    results = tf.while_loop(
        cond,
        body,
        maximum_iterations=iterations - 1,
        loop_vars=[
            i,
            inputs_tuple,
            all_cache_key,
            all_cache_value,
            model_outputs["past_length"],
            initial_id,
        ],
        shape_invariants=[
            i.get_shape(),
            input_shapes_tuple,
            tf.TensorShape([
                self.num_hidden_layers,
                None,
                self.num_attention_heads,
                None,
                self.attention_state,
            ]),
            tf.TensorShape([
                self.num_hidden_layers,
                None,
                self.num_attention_heads,
                None,
                self.attention_state,
            ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
        ],
    )

    results_dict = {}
    results_dict["iterations"] = results[0]
    results_dict["input_ids"] = input_ids_orig
    # Skip the initial placeholder ids
    results_dict["predicted_ids"] = results[-1][:, 1:]
    results_dict["predicted_ids"] = tf.reshape(
        results_dict["predicted_ids"],
        (batch_size, self.num_return_sequences, -1),
    )
    matched_positions = (
        tf.squeeze(
            tf.reshape(
                tf.argmax(
                    tf.cast(
                        tf.equal(self.eos_id, results_dict["predicted_ids"]),
                        tf.int32,
                    ),
                    axis=2,
                ),
                (-1, batch_size * self.num_return_sequences),
            ),
            [0],
        )
        - 1)
    # Positions with no eos match will be 0; replace them with -1
    eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
    matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask
    results_dict["matched_eos_pos"] = matched_positions
    return results_dict
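# ---------------------------------------------------------------------------
# The `matched_eos_pos` computation above relies on a quirk of tf.argmax: it
# returns the index of the *first* match, but also returns 0 when there is no
# match at all, which is why the 0/no-match case is post-processed into -1.
# A small standalone illustration (eos_id=3 is just an example value):
import tensorflow as tf

predicted = tf.constant([[5, 3, 7], [5, 7, 9]])   # second row has no eos
first_eos = tf.argmax(tf.cast(tf.equal(3, predicted), tf.int32), axis=1)
print(first_eos)  # tf.Tensor([1 0], ...): 0 is ambiguous, hence the -1 fixup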
def call_beam(inputs):
    """The main function to perform beam search.

    Args:
        inputs (dict): dict of tf.Tensors (model inputs)
    """
    input_ids_orig = inputs["input_ids"]
    # We keep 2x beams
    beams_to_keep = 2 * self.beam_size

    # Original batch size and sequence length
    batch_size = tf.shape(inputs["input_ids"])[0]
    max_sequence_length = tf.shape(inputs["input_ids"])[1]

    # Repeat for beam search (we need batch_size x beam_size)
    model_inputs = {}
    for input_key, input_value in inputs.items():
        if input_key == "iterations":
            continue
        model_inputs[input_key] = tf.repeat(input_value, [self.beam_size], axis=0)

    # New batch size
    batch_size_updated = tf.shape(model_inputs["input_ids"])[0]

    # Pre-initialize additional inputs
    zero_entry = tf.zeros((
        self.num_hidden_layers,
        batch_size_updated,
        self.num_attention_heads,
        max_sequence_length,
        self.attention_state,
    ))
    all_cache_key = zero_entry
    all_cache_value = zero_entry
    # past_length keeps track of positional ids
    past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)

    # Iterator to keep track of the loop
    i = tf.constant([[0]])

    if self.max_iterations is None:
        iterations = tf.squeeze(inputs["iterations"])
    else:
        iterations = self.max_iterations

    # Add remaining model inputs
    model_inputs["all_cache_key"] = all_cache_key
    model_inputs["all_cache_value"] = all_cache_value
    model_inputs["past_length"] = past_length
    if "input_type_ids" in self.input_name_list:
        model_inputs["input_type_ids"] = (
            tf.ones_like(model_inputs["input_ids"]) * self.input_type_ids)
    if "input_mask" in self.input_name_list:
        model_inputs["input_mask"] = (
            tf.ones_like(model_inputs["input_ids"]) * self.input_mask_ids)

    # We need this to re-order beams and keep track of the best -log(prob)
    alive_log_probs = -np.inf * tf.ones((batch_size, self.beam_size - 1))
    alive_log_probs = tf.concat([tf.zeros([batch_size, 1]), alive_log_probs], axis=1)
    alive_seq = tf.zeros((batch_size, self.beam_size, 1))

    # First pass through the model
    model_outputs = self.model(model_inputs)
    model_logits = model_outputs["last_token_logits"] / self.temperature

    # Update the iterator
    i = i + 1
    all_cache_key = model_outputs["all_cache_key"]
    all_cache_value = model_outputs["all_cache_value"]
    past_length = model_outputs["past_length"]

    if self.top_k > 0:
        model_logits = top_k_logits(model_logits, k=self.top_k)
    if self.top_p > 0:
        model_logits = top_p_logits(model_logits, p=self.top_p)

    # Vocab size
    vocab_size = tf.shape(model_logits)[1]
    logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))

    # Convert logits to normalized log probs
    candidate_log_probs = _log_prob_from_logits(logits)

    # Calculate new log probabilities if each of the alive sequences were
    # extended by the candidate IDs.
    # Shape [batch_size, beam_size, vocab_size]
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    # Add length penalty
    length_penalty = tf.pow(((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), self.alpha)
    log_probs = log_probs / length_penalty

    # Each batch item has beam_size * vocab_size candidate sequences. For each
    # batch item, get the k candidates with the highest log probabilities.
    flat_log_probs = tf.reshape(log_probs, [-1, self.beam_size * vocab_size])

    if self.do_sample:
        next_tokens = tf.random.categorical(
            flat_log_probs, dtype=tf.int32,
            num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)
        # Compute next scores
        next_scores = tf.gather(flat_log_probs, next_tokens,
                                batch_dims=1)  # (batch_size, 2 * num_beams)
        # Sort the sampled vector to make sure that the first num_beams
        # samples are the best
        next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
        next_scores = tf.gather(next_scores, next_scores_indices,
                                batch_dims=1)  # (batch_size, num_beams * 2)
        next_tokens = tf.gather(next_tokens, next_scores_indices,
                                batch_dims=1)  # (batch_size, num_beams * 2)
        topk_log_probs = next_scores
        topk_indices = next_tokens
    else:
        topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

    topk_beam_indices = topk_indices // vocab_size
    topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                          batch_size, beams_to_keep)
    topk_seq = tf.cast(topk_seq, tf.int32)
    topk_ids = topk_indices % vocab_size
    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

    topk_alive_seq = topk_seq[:, :self.beam_size, :]
    alive_log_probs = topk_log_probs[:, :self.beam_size]
    input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])
    alive_seq = topk_alive_seq

    inputs_tuple = [None] * len(self.input_name_list)
    input_shapes_tuple = [tf.TensorShape([None, None])] * len(self.input_name_list)
    for index, name in self.input_name_map.items():
        if name == "input_ids":
            inputs_tuple[index] = input_ids
            continue
        if name == "input_type_ids":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
            continue
        if name == "input_mask":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
            continue
    inputs_tuple = tuple(inputs_tuple)
    input_shapes_tuple = tuple(input_shapes_tuple)

    # On step 0, zero out cache positions that correspond to -1 padding
    masks = tf.cast(tf.not_equal(model_inputs["input_ids"], -1), tf.float32)
    masks = tf.reshape(
        masks,
        (1, batch_size_updated, 1, tf.shape(model_inputs["input_ids"])[1], 1),
    )
    all_cache_key = all_cache_key * masks
    all_cache_value = all_cache_value * masks

    all_cache_key, all_cache_value = self.reorder_past_batches(
        all_cache_key, all_cache_value, coordinates, self.beam_size)

    results = tf.while_loop(
        cond,
        body,
        maximum_iterations=iterations - 1,
        loop_vars=[
            i,
            inputs_tuple,
            all_cache_key,
            all_cache_value,
            past_length,
            alive_log_probs,
            alive_seq,
        ],
        shape_invariants=[
            i.get_shape(),
            input_shapes_tuple,
            tf.TensorShape([
                self.num_hidden_layers,
                None,
                self.num_attention_heads,
                None,
                self.attention_state,
            ]),
            tf.TensorShape([
                self.num_hidden_layers,
                None,
                self.num_attention_heads,
                None,
                self.attention_state,
            ]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None, None]),
        ],
    )

    results_dict = {}
    results_dict["iterations"] = results[0]
    results_dict["input_ids"] = input_ids_orig
    # Remove the initial 0 id
    results_dict["predicted_ids"] = results[-1][:, :, 1:]
    matched_positions = (
        tf.squeeze(
            tf.reshape(
                tf.argmax(
                    tf.cast(
                        tf.equal(self.eos_id, results_dict["predicted_ids"]),
                        tf.int32,
                    ),
                    axis=2,
                ),
                (-1, batch_size * self.beam_size),
            ),
            [0],
        )
        - 1)
    # Positions with no eos match will be 0; replace them with -1
    eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
    matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask
    results_dict["matched_eos_pos"] = matched_positions
    return results_dict
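# ---------------------------------------------------------------------------
# `_log_prob_from_logits` is defined elsewhere in this module; it is assumed
# to be the usual log-softmax over the vocab axis, and the length penalty
# used above is the GNMT-style ((5 + step) / 6) ** alpha. A minimal sketch of
# both (the helper names here are illustrative):
import tensorflow as tf


def _sketch_log_prob_from_logits(logits):
    # Log-softmax over the last (vocab) axis.
    return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)


def _sketch_length_penalty(step, alpha):
    # alpha = 0.0 disables the penalty (it becomes 1.0 for every step).
    return tf.pow((5.0 + tf.cast(step, tf.float32) + 1.0) / 6.0, alpha)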
def body(
    i,
    inputs_tuple,
    all_cache_key,
    all_cache_value,
    past_length,
    alive_log_probs,
    alive_seq,
):
    """Body of the beam-search while-loop decoder.

    Args:
        i (tf.Tensor): loop iterator (an int)
        inputs_tuple (tuple): model inputs, ordered as in self.input_name_list
        all_cache_key (tf.Tensor): cached attention keys (K)
        all_cache_value (tf.Tensor): cached attention values (V)
        past_length (tf.Tensor, 1 x batch_size): tracks positional ids
        alive_log_probs (tf.Tensor): log probs of the alive beams
        alive_seq (tf.Tensor): decoded ids of the alive beams; this is the
            main output

    Returns:
        List of tensors: the updated loop variables
    """
    inputs = {}
    for k in range(len(self.input_name_list)):
        inputs[self.input_name_list[k]] = inputs_tuple[k]

    inputs["all_cache_key"] = all_cache_key
    inputs["all_cache_value"] = all_cache_value
    inputs["past_length"] = past_length

    beams_to_keep = 2 * self.beam_size

    model_outputs = self.model(inputs)
    model_logits = model_outputs["last_token_logits"]
    all_cache_key = model_outputs["all_cache_key"]
    all_cache_value = model_outputs["all_cache_value"]
    past_length = model_outputs["past_length"]

    if self.top_k > 0:
        model_logits = top_k_logits(model_logits, k=self.top_k)
    if self.top_p > 0:
        model_logits = top_p_logits(model_logits, p=self.top_p)

    vocab_size = tf.shape(model_logits)[1]
    batch_size = tf.shape(inputs["input_ids"])[0] // self.beam_size
    logits = tf.reshape(model_logits, (batch_size, self.beam_size, -1))

    # Convert logits to normalized log probs
    candidate_log_probs = _log_prob_from_logits(logits)

    # Calculate new log probabilities if each of the alive sequences were
    # extended by the candidate IDs.
    # Shape [batch_size, beam_size, vocab_size]
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    # Add length penalty
    length_penalty = tf.pow(((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), self.alpha)
    log_probs = log_probs / length_penalty

    # Each batch item has beam_size * vocab_size candidate sequences. For each
    # batch item, get the k candidates with the highest log probabilities.
    flat_log_probs = tf.reshape(log_probs, [-1, self.beam_size * vocab_size])

    if self.do_sample:
        next_tokens = tf.random.categorical(
            flat_log_probs, dtype=tf.int32,
            num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)
        # Compute next scores
        next_scores = tf.gather(flat_log_probs, next_tokens,
                                batch_dims=1)  # (batch_size, 2 * num_beams)
        # Sort the sampled vector to make sure that the first num_beams
        # samples are the best
        next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
        next_scores = tf.gather(next_scores, next_scores_indices,
                                batch_dims=1)  # (batch_size, num_beams * 2)
        next_tokens = tf.gather(next_tokens, next_scores_indices,
                                batch_dims=1)  # (batch_size, num_beams * 2)
        topk_log_probs = next_scores
        topk_indices = next_tokens
    else:
        topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)

    topk_beam_indices = topk_indices // vocab_size
    topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                          batch_size, beams_to_keep)
    topk_seq = tf.cast(topk_seq, tf.int32)
    topk_ids = topk_indices % vocab_size
    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

    topk_alive_seq = topk_seq[:, :self.beam_size, :]
    alive_log_probs = topk_log_probs[:, :self.beam_size]
    input_ids = tf.reshape(topk_ids[:, :self.beam_size], [-1, 1])
    alive_seq = topk_alive_seq

    # Rebuild the ordered inputs tuple for the next iteration
    inputs_tuple = [None] * len(self.input_name_list)
    for index, name in self.input_name_map.items():
        if name == "input_ids":
            inputs_tuple[index] = input_ids
            continue
        if name == "input_type_ids":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_type_ids
            continue
        if name == "input_mask":
            inputs_tuple[index] = tf.ones_like(input_ids) * self.input_mask_ids
            continue
    # Convert to tuple
    inputs_tuple = tuple(inputs_tuple)

    all_cache_key, all_cache_value = self.reorder_past_batches(
        all_cache_key, all_cache_value, coordinates, self.beam_size)
    model_outputs["all_cache_key"] = all_cache_key
    model_outputs["all_cache_value"] = all_cache_value

    return [
        i + 1,
        inputs_tuple,
        model_outputs["all_cache_key"],
        model_outputs["all_cache_value"],
        model_outputs["past_length"],
        alive_log_probs,
        alive_seq,
    ]
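# ---------------------------------------------------------------------------
# `_gather_beams` is defined elsewhere in this module. The sketch below shows
# the assumed idea (the helper name is illustrative): build (batch, beam)
# coordinate pairs from the chosen beam indices and gather_nd, so the alive
# sequences (and, via the same coordinates, the K/V caches) follow the beams
# that survived this step.
import tensorflow as tf


def _sketch_gather_beams(nested, beam_indices, batch_size, new_beam_size):
    # beam_indices: (batch_size, new_beam_size) values in [0, beam_size)
    batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
    batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
    # coordinates[b, k] = (b, beam_indices[b, k])
    coordinates = tf.stack([batch_pos, beam_indices], axis=2)
    return tf.gather_nd(nested, coordinates), coordinates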
def top_k_top_p(
    self,
    tokenized_input_dict,
    max_iterations,
    top_k=0,
    top_p=0,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=1,
    eos_id=-100,
):
    # We need this to return
    input_ids_original = tokenized_input_dict["encoder_input_ids"]
    batch_size = tf.shape(tokenized_input_dict["encoder_input_ids"])[0]

    # Repeat each input `num_return_sequences` times
    tokenized_input_dict_ragged = {}
    for input_key, input_value in tokenized_input_dict.items():
        tokenized_input_dict_ragged[input_key] = tf.repeat(
            input_value, [num_return_sequences], axis=0)

    batch_size_updated = tf.shape(
        tokenized_input_dict_ragged["encoder_input_ids"])[0]
    decoder_start_sequence_length = 1

    # Initialize with zeros
    encoder_sequence_length = tf.shape(
        tokenized_input_dict_ragged["encoder_input_ids"])[1]
    encoder_hidden_states = tf.zeros(
        (batch_size_updated, encoder_sequence_length, self.embedding_size))
    all_cache_key = tf.zeros((
        self.decoder_num_hidden_layers,
        batch_size_updated,
        self.decoder_num_attention_heads,
        decoder_start_sequence_length,
        self.decoder_attention_state,
    ))
    all_cache_value = tf.zeros((
        self.decoder_num_hidden_layers,
        batch_size_updated,
        self.decoder_num_attention_heads,
        decoder_start_sequence_length,
        self.decoder_attention_state,
    ))

    # Inputs ready
    tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
        tf.ones(shape=(batch_size_updated, 1)) * self.decode_start_token_id,
        tf.int32,
    )
    tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
    tokenized_input_dict_ragged["decoder_all_cache_value"] = all_cache_value
    tokenized_input_dict_ragged["encoder_hidden_states"] = encoder_hidden_states
    if self.decoder_input_type_ids > -1:
        tokenized_input_dict_ragged["decoder_input_type_ids"] = (
            tf.ones_like(tokenized_input_dict_ragged["decoder_input_ids"])
            * self.decoder_input_type_ids)

    all_predictions = []
    matched_positions = tf.constant([-1] * batch_size_updated)

    # Iterate
    for i in range(max_iterations):
        result = self.model_fn(tokenized_input_dict_ragged)
        model_logits = result["last_token_logits"]
        all_cache_key = result["decoder_all_cache_key"]
        all_cache_value = result["decoder_all_cache_value"]
        encoder_hidden_states = result["encoder_hidden_states"]

        model_logits = model_logits / temperature
        if top_k > 0:
            model_logits = top_k_logits(model_logits, k=top_k)
        if top_p > 0:
            model_logits = top_p_logits(model_logits, p=top_p)

        if do_sample:
            prediction_ids = tf.random.categorical(model_logits, num_samples=1)
            input_ids = tf.cast(prediction_ids, tf.int32)
        else:
            prediction_ids = tf.argmax(model_logits, axis=1)
            input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)
        all_predictions.append(input_ids)

        tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(input_ids, tf.int32)
        tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged["decoder_all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged["encoder_hidden_states"] = encoder_hidden_states

        # Stop once every sequence has produced at least one eos
        eos_check = tf.greater(
            tf.reduce_prod(
                tf.reduce_sum(
                    tf.cast(
                        tf.equal(tf.concat(all_predictions, axis=1), eos_id),
                        tf.int32),
                    axis=[1],
                )),
            0,
        )
        if eos_check:
            break

    matched_positions = (
        tf.reshape(
            tf.argmax(
                tf.cast(
                    tf.equal(eos_id, tf.concat(all_predictions, axis=1)),
                    tf.int32),
                axis=1,
            ),
            -1,
        )
        - 1)
    # Positions with no eos match will be 0; replace them with -1
    eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
    matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

    all_predictions = tf.reshape(
        tf.concat(all_predictions, axis=1),
        (batch_size, num_return_sequences, -1))
    return {
        "iterations": i + 1,
        "input_ids": input_ids_original,
        "predicted_ids": all_predictions,
        "matched_eos_pos": matched_positions,
    }
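# ---------------------------------------------------------------------------
# The early-stop test above reduce_sums the eos hits per row and then
# reduce_prods the counts across rows, so the product is > 0 only when every
# row has produced eos at least once. A standalone illustration (eos_id=2 is
# just an example value):
import tensorflow as tf

preds = tf.constant([[4, 2, 5], [7, 8, 2]])
hits_per_row = tf.reduce_sum(tf.cast(tf.equal(preds, 2), tf.int32), axis=[1])
print(tf.greater(tf.reduce_prod(hits_per_row), 0))   # tf.Tensor(True, ...)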
def beam_decode(
    self,
    tokenized_input_dict,
    beam_size,
    max_iterations,
    temperature=1.0,
    alpha=0.0,
    top_k=0,
    top_p=0,
    do_sample=False,
    eos_id=-100,
):
    """Supports variable-batch beam decoding.

    Args:
        tokenized_input_dict: dict of model inputs (tf.Tensors)
        beam_size (int): number of beams
        max_iterations (int): number of steps to decode
        temperature (float): softmax temperature applied to the logits
        alpha (float): length-penalty exponent
        top_k (int): top-k logit filtering (0 disables it)
        top_p (float): nucleus logit filtering (0 disables it)
        do_sample (bool): sample from a multinomial distribution over the
            most likely words; still uses beams
        eos_id (int): id to consider as the decoder stop token
    """
    # We need this to return
    input_ids_original = tokenized_input_dict["encoder_input_ids"]
    batch_size = tf.shape(tokenized_input_dict["encoder_input_ids"])[0]

    # Repeat for beam search
    tokenized_input_dict_ragged = {}
    for input_key, input_value in tokenized_input_dict.items():
        tokenized_input_dict_ragged[input_key] = tf.repeat(
            input_value, [beam_size], axis=0)

    # We keep 2x beams
    beams_to_keep = 2 * beam_size
    batch_size_updated = tf.shape(
        tokenized_input_dict_ragged["encoder_input_ids"])[0]
    decoder_start_sequence_length = 1

    # Initialize with zeros
    encoder_sequence_length = tf.shape(
        tokenized_input_dict_ragged["encoder_input_ids"])[1]
    encoder_hidden_states = tf.zeros(
        (batch_size_updated, encoder_sequence_length, self.embedding_size))
    all_cache_key = tf.zeros((
        self.decoder_num_hidden_layers,
        batch_size_updated,
        self.decoder_num_attention_heads,
        decoder_start_sequence_length,
        self.decoder_attention_state,
    ))
    all_cache_value = tf.zeros((
        self.decoder_num_hidden_layers,
        batch_size_updated,
        self.decoder_num_attention_heads,
        decoder_start_sequence_length,
        self.decoder_attention_state,
    ))

    # Inputs ready
    tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(
        tf.ones(shape=(batch_size_updated, 1)) * self.decode_start_token_id,
        tf.int32,
    )
    tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
    tokenized_input_dict_ragged["decoder_all_cache_value"] = all_cache_value
    tokenized_input_dict_ragged["encoder_hidden_states"] = encoder_hidden_states
    if self.decoder_input_type_ids > -1:
        tokenized_input_dict_ragged["decoder_input_type_ids"] = (
            tf.ones_like(tokenized_input_dict_ragged["decoder_input_ids"])
            * self.decoder_input_type_ids)

    matched_positions = tf.constant([-1] * batch_size_updated)

    alive_log_probs = -np.inf * tf.ones((batch_size, beam_size - 1))
    alive_log_probs = tf.concat([tf.zeros([batch_size, 1]), alive_log_probs], axis=1)
    alive_seq = tf.zeros((batch_size, beam_size, 1))

    for i in range(max_iterations):
        result = self.model_fn(tokenized_input_dict_ragged)
        model_logits = result["last_token_logits"]
        all_cache_key = result["decoder_all_cache_key"]
        all_cache_value = result["decoder_all_cache_value"]
        encoder_hidden_states = result["encoder_hidden_states"]

        model_logits = model_logits / temperature
        if top_k > 0:
            model_logits = top_k_logits(model_logits, k=top_k)
        if top_p > 0:
            model_logits = top_p_logits(model_logits, p=top_p)

        vocab_size = tf.shape(model_logits)[1]
        logits = tf.reshape(model_logits, (batch_size, beam_size, -1))

        # Convert logits to normalized log probs
        candidate_log_probs = _log_prob_from_logits(logits)

        # Calculate new log probabilities if each of the alive sequences were
        # extended by the candidate IDs.
        # Shape [batch_size, beam_size, vocab_size]
        log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

        # Add length penalty
        length_penalty = tf.pow(((5.0 + (tf.cast(i, tf.float32) + 1.0)) / 6.0), alpha)
        log_probs = log_probs / length_penalty

        # Each batch item has beam_size * vocab_size candidate sequences. For
        # each batch item, get the k candidates with the highest log
        # probabilities.
        flat_log_probs = tf.reshape(log_probs, [-1, beam_size * vocab_size])

        if do_sample:
            next_tokens = tf.random.categorical(
                flat_log_probs, dtype=tf.int32,
                num_samples=beams_to_keep)  # (batch_size, 2 * num_beams)
            # Compute next scores
            next_scores = tf.gather(flat_log_probs, next_tokens,
                                    batch_dims=1)  # (batch_size, 2 * num_beams)
            # Sort the sampled vector to make sure that the first num_beams
            # samples are the best
            next_scores_indices = tf.argsort(next_scores,
                                             direction="DESCENDING", axis=1)
            next_scores = tf.gather(next_scores, next_scores_indices,
                                    batch_dims=1)  # (batch_size, num_beams * 2)
            next_tokens = tf.gather(next_tokens, next_scores_indices,
                                    batch_dims=1)  # (batch_size, num_beams * 2)
            topk_log_probs = next_scores
            topk_indices = next_tokens
        else:
            topk_log_probs, topk_indices = tf.nn.top_k(
                flat_log_probs, k=beams_to_keep)  # (batch_size, beams_to_keep)

        topk_beam_indices = topk_indices // vocab_size
        topk_seq, coordinates = _gather_beams(alive_seq, topk_beam_indices,
                                              batch_size, beams_to_keep)
        topk_seq = tf.cast(topk_seq, tf.int32)
        topk_ids = topk_indices % vocab_size
        topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

        topk_alive_seq = topk_seq[:, :beam_size, :]
        alive_log_probs = topk_log_probs[:, :beam_size]
        input_ids = tf.reshape(topk_ids[:, :beam_size], [-1, 1])
        alive_seq = topk_alive_seq

        all_cache_key, all_cache_value = self.reorder_past_batches(
            all_cache_key, all_cache_value, coordinates, beam_size)

        tokenized_input_dict_ragged["decoder_input_ids"] = tf.cast(input_ids, tf.int32)
        tokenized_input_dict_ragged["decoder_all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged["decoder_all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged["encoder_hidden_states"] = encoder_hidden_states

        # Stop once every alive beam has produced at least one eos
        eos_check = tf.greater(
            tf.reduce_prod(
                tf.reduce_sum(
                    tf.cast(tf.equal(topk_alive_seq, eos_id), tf.int32),
                    axis=[2],
                )),
            0,
        )
        if eos_check:
            break

    matched_positions = (
        tf.reshape(
            tf.argmax(tf.cast(tf.equal(eos_id, topk_alive_seq), tf.int32), axis=2),
            -1,
        )
        - 1)
    # Positions with no eos match will be 0; replace them with -1
    eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
    matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

    return {
        "iterations": i + 1,
        "input_ids": input_ids_original,
        "predicted_ids": topk_alive_seq[:, :, 1:],  # remove the initial 0
        "matched_eos_pos": matched_positions - 1,
    }
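# ---------------------------------------------------------------------------
# `reorder_past_batches` is defined elsewhere in this module. The sketch below
# shows the assumed operation (the helper name is illustrative): the K/V
# caches are laid out as (layers, batch * beam, heads, seq, state) with beams
# stored contiguously per batch item, so after each step the batch axis is
# re-gathered to follow the surviving beams.
import tensorflow as tf


def _sketch_reorder_past_batches(all_cache_key, all_cache_value, coordinates,
                                 beam_size):
    # coordinates: (batch, beams_to_keep, 2) pairs of (batch, beam) indices;
    # only the first `beam_size` survivors are kept.
    survivors = coordinates[:, :beam_size, :]
    flat_indices = survivors[:, :, 0] * beam_size + survivors[:, :, 1]
    flat_indices = tf.reshape(flat_indices, [-1])  # (batch * beam_size,)
    all_cache_key = tf.gather(all_cache_key, flat_indices, axis=1)
    all_cache_value = tf.gather(all_cache_value, flat_indices, axis=1)
    return all_cache_key, all_cache_value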
def top_k_top_p(
    self,
    tokenized_input_dict,
    max_iterations,
    top_k=0,
    top_p=0,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=1,
    eos_id=-100,
):
    # We need this to return
    input_ids_original = tokenized_input_dict["input_ids"]
    batch_size = tf.shape(input_ids_original)[0]
    max_sequence_length = tf.shape(input_ids_original)[1]

    # Repeat each input `num_return_sequences` times
    tokenized_input_dict_ragged = {}
    for input_key, input_value in tokenized_input_dict.items():
        tokenized_input_dict_ragged[input_key] = tf.repeat(
            input_value, [num_return_sequences], axis=0)

    batch_size_updated = tokenized_input_dict_ragged["input_ids"].shape[0]

    # Initialize with zeros
    zero_entry = tf.zeros((
        self.num_hidden_layers,
        batch_size_updated,
        self.num_attention_heads,
        max_sequence_length,
        self.attention_state,
    ))
    all_cache_key = zero_entry
    all_cache_value = zero_entry
    past_length = tf.expand_dims(tf.zeros(batch_size_updated, dtype=tf.int32), 0)

    # Inputs ready
    tokenized_input_dict_ragged["all_cache_key"] = all_cache_key
    tokenized_input_dict_ragged["all_cache_value"] = all_cache_value
    tokenized_input_dict_ragged["past_length"] = past_length
    if self.input_type_ids > -1:
        tokenized_input_dict_ragged["input_type_ids"] = (
            tf.ones_like(tokenized_input_dict_ragged["input_ids"])
            * self.input_type_ids)
    if self.input_mask_ids > -1:
        tokenized_input_dict_ragged["input_mask"] = (
            tf.ones_like(tokenized_input_dict_ragged["input_ids"])
            * self.input_mask_ids)

    all_predictions = []
    matched_positions = tf.constant([-1] * batch_size_updated)

    # Iterate
    for i in range(max_iterations):
        result = self.model_fn(tokenized_input_dict_ragged)
        model_logits = result["last_token_logits"]
        all_cache_key = result["all_cache_key"]
        all_cache_value = result["all_cache_value"]
        past_length = result["past_length"]

        # Apply temperature (mirrors the encoder-decoder variant above, where
        # this parameter is actually used)
        model_logits = model_logits / temperature
        if top_k > 0:
            model_logits = top_k_logits(model_logits, k=top_k)
        if top_p > 0:
            model_logits = top_p_logits(model_logits, p=top_p)

        if do_sample:
            prediction_ids = tf.random.categorical(model_logits, num_samples=1)
            input_ids = tf.cast(prediction_ids, tf.int32)
        else:
            prediction_ids = tf.argmax(model_logits, axis=1)
            input_ids = tf.cast(tf.expand_dims(prediction_ids, axis=1), tf.int32)
        all_predictions.append(input_ids)

        if i == 0:
            # On step 0, zero out cache positions that correspond to -1 padding
            masks = tf.cast(
                tf.not_equal(tokenized_input_dict_ragged["input_ids"], -1),
                tf.float32,
            )
            masks = tf.reshape(
                masks, (1, batch_size_updated, 1, max_sequence_length, 1))
            all_cache_key = all_cache_key * masks
            all_cache_value = all_cache_value * masks

        tokenized_input_dict_ragged["input_ids"] = tf.cast(input_ids, tf.int32)
        tokenized_input_dict_ragged["all_cache_key"] = all_cache_key
        tokenized_input_dict_ragged["all_cache_value"] = all_cache_value
        tokenized_input_dict_ragged["past_length"] = past_length
        if self.input_type_ids > -1:
            tokenized_input_dict_ragged["input_type_ids"] = (
                tf.ones_like(tokenized_input_dict_ragged["input_ids"])
                * self.input_type_ids)
        if self.input_mask_ids > -1:
            tokenized_input_dict_ragged["input_mask"] = (
                tf.ones_like(tokenized_input_dict_ragged["input_ids"])
                * self.input_mask_ids)

        if eos_id:
            temp_m = tf.concat(all_predictions, axis=1)
            # Stop once every sequence has produced at least one eos
            eos_check = tf.greater(
                tf.reduce_prod(
                    tf.reduce_sum(
                        tf.cast(tf.equal(temp_m, eos_id), tf.int32), axis=[1])),
                0,
            )
            if eos_check:
                matched_positions = tf.argmax(
                    tf.cast(tf.equal(eos_id, temp_m), tf.int32), axis=1)
                break
    matched_positions = (
        tf.reshape(
            tf.argmax(
                tf.cast(
                    tf.equal(eos_id, tf.concat(all_predictions, axis=1)),
                    tf.int32),
                axis=1,
            ),
            -1,
        )
        - 1)
    # Positions with no eos match will be 0; replace them with -1
    eos_pos_mask = tf.cast(tf.equal(matched_positions, 0), tf.int32) * -1
    matched_positions = tf.cast(matched_positions, tf.int32) + eos_pos_mask

    all_predictions = tf.reshape(
        tf.concat(all_predictions, axis=1),
        (batch_size, num_return_sequences, -1))
    return {
        "iterations": i + 1,
        "input_ids": input_ids_original,
        "predicted_ids": all_predictions,
        "matched_eos_pos": matched_positions,
    }
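# ---------------------------------------------------------------------------
# One way a caller might consume these outputs: trim each returned sequence
# at its matched eos position. This is an illustrative sketch, not part of
# the module; it assumes the matched position is treated as the number of
# tokens to keep, with -1 meaning "no eos, keep everything".
import tensorflow as tf

predicted = tf.constant([[9, 4, 2, 7], [5, 6, 8, 1]])
matched = tf.constant([2, -1])
lengths = tf.where(
    matched < 0,
    tf.fill(tf.shape(matched), tf.shape(predicted)[1]),  # keep full row
    matched,
)
trimmed = tf.RaggedTensor.from_tensor(predicted, lengths=lengths)
print(trimmed)  # <tf.RaggedTensor [[9, 4], [5, 6, 8, 1]]>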