def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
              tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))
  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings

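
# A minimal usage sketch (hypothetical, not part of the original code):
# length-aware concatenation of two padded batches. Assumes `tf`,
# `py_utils`, and `SequenceConcat` are in scope as above.
def _ExampleSequenceConcat():
  x = tf.constant([[1, 2, 0]])               # length 2, one <pad> slot
  x_paddings = tf.constant([[0., 0., 1.]])
  y = tf.constant([[3, 0]])                  # length 1
  y_paddings = tf.constant([[0., 1.]])
  xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
  # xy          -> [[1, 2, 3, 0, 0]]
  # xy_paddings -> [[0., 0., 0., 1., 1.]]
  return xy, xy_paddings
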
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension. This must be
      True for any sequence in `x` whose length is already `x_len_max`, or
      an invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max]
        ([batch_size, x_len_max + 1] if `extend` is True).
      - The new paddings, Tensor of the same shape as the new sequence.
  """
  batch_size = py_utils.GetShape(x)[0]
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  # Mask all invalid entries of `x` to 0.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
  # Append the <token> based on `x_len`.
  x += tf.scatter_nd(
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x))
  x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                    py_utils.GetShape(x)[1],
                                    x_paddings.dtype)
  return x, x_paddings

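
# A minimal usage sketch (hypothetical): appending an <eos> token (id 2
# here) to a batch that still has a free padded slot, so extend=False.
def _ExampleSequenceAppendToken():
  x = tf.constant([[1, 2, 0]])              # length 2, one free slot
  x_paddings = tf.constant([[0., 0., 1.]])
  new_x, new_paddings = SequenceAppendToken(x, x_paddings, token=2)
  # new_x        -> [[1, 2, 2]]
  # new_paddings -> [[0., 0., 0.]]
  return new_x, new_paddings
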
def testBpeTokenization(self):
  word_vocab = test_helper.test_src_dir_path(
      'core/ops/testdata/bpe_words.vocab')
  code_vocab = test_helper.test_src_dir_path(
      'core/ops/testdata/bpe_codes.vocab')
  sentences = [
      'GIVE ME A PENNY', 'THEY LIVED ALONE', 'THEY GIVE ME A PENNY ALONE'
  ]
  expected_sentences = [
      b'GIVE ME A PENNY </s> ',
      b'THEY LIVED ALONE </s> ',
      b'THEY GIVE ME A PENNY ',
  ]
  expected_token_ids = [
      [27, 9, 30, 14, 28, 14, 52, 11, 4, 6, 6, 10, 2, 2, 2],
      [16, 4, 10, 12, 9, 30, 24, 7, 12, 49, 14, 2, 2, 2, 2],
      [16, 4, 10, 27, 9, 30, 14, 28, 14, 52, 11, 4, 6, 6, 10],
  ]
  with self.session(use_gpu=False):
    label_tensor = tf.constant(sentences)
    _, token_ids, paddings = ops.bpe_words_to_ids(
        label_tensor, tokenization_filepath=word_vocab, maxlen=15)
    seq_lens = tf.cast(
        tf.round(tf.reduce_sum(1 - paddings, axis=-1)), tf.int32)
    target_string = ops.bpe_ids_to_words(
        token_ids, seq_lengths=seq_lens, vocab_filepath=code_vocab)
    self.assertEqual(expected_sentences, target_string.eval().tolist())
    self.assertEqual(expected_token_ids, token_ids.eval().tolist())

def ComputeMetrics(self, decoder_outs, input_batch, ids_to_strings_fn):
  """Computes metrics on output from decoder.

  Args:
    decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
      decode results.
    input_batch: A `NestedMap` of tensors representing the source, target,
      and other components of the input batch.
    ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
      shape [batch, length], lens has shape [batch], and strings has shape
      [batch].

  Returns:
    A dict of Tensors containing decoder output and metrics.
  """
  topk = self.GetTopK(decoder_outs, ids_to_strings_fn=ids_to_strings_fn)
  tgt_batch = tf.shape(topk.scores)[0]
  num_hyps_per_beam = tf.shape(topk.scores)[1]
  tgt = input_batch.tgt
  tgt_lens = tf.cast(tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1)),
                     tf.int32)
  tgt_lens = py_utils.HasShape(tgt_lens, [tgt_batch])
  transcripts = ids_to_strings_fn(tgt.labels, tgt_lens - 1)

  # Filter out all isolated '<noise>' tokens.
  noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
  filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern, ' ')
  filtered_hyps = tf.strings.regex_replace(topk.decoded, noise_pattern, ' ')
  # Compute translation quality scores for all hyps.
  filtered_refs = tf.tile(
      tf.reshape(filtered_refs, [-1, 1]), [1, num_hyps_per_beam])
  filtered_hyps = tf.reshape(filtered_hyps, [-1])
  filtered_refs = tf.reshape(filtered_refs, [-1])
  tf.logging.info('filtered_refs=%s', filtered_refs)
  norm_wer_errors, norm_wer_words = self.ComputeNormalizedWER(
      filtered_hyps, filtered_refs, num_hyps_per_beam)

  ret_dict = {
      'target_ids': tgt.ids,
      'target_labels': tgt.labels,
      'target_weights': tgt.weights,
      'target_paddings': tgt.paddings,
      'transcripts': transcripts,
      'topk_decoded': topk.decoded,
      'topk_ids': topk.ids,
      'topk_lens': topk.lens,
      'topk_scores': topk.scores,
      'norm_wer_errors': norm_wer_errors,
      'norm_wer_words': norm_wer_words,
  }

  if not py_utils.use_tpu() and 'sample_ids' in input_batch:
    ret_dict['utt_id'] = input_batch.sample_ids

  ret_dict.update(
      self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                              filtered_refs, input_batch,
                                              decoder_outs))
  return ret_dict

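
# An illustration (a hedged sketch, not from the original code) of how the
# noise_pattern above behaves: each isolated '<noise>' token collapses to a
# single space, so the surrounding words survive for WER scoring.
def _ExampleNoiseFilter():
  noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
  refs = tf.constant(['hello <noise> world', '<noise> hello', '<noise>'])
  filtered = tf.strings.regex_replace(refs, noise_pattern, ' ')
  # filtered -> [b'hello world', b' hello', b' ']
  return filtered
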
def _ComputeDecoderMetrics(self, decoder_outs, input_batch):
  """Computes metrics on output from decoder.

  Args:
    decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
      decode results.
    input_batch: A `NestedMap` of tensors representing the source, target,
      and other components of the input batch.

  Returns:
    A dict of Tensors containing decoder output and metrics.
  """
  p = self.params
  topk = self._GetTopK(decoder_outs)
  tgt = self._GetTargetForDecoderMetrics(input_batch)
  transcripts = self.input_generator.IdsToStrings(
      tgt.labels,
      tf.cast(
          tf.round(tf.reduce_sum(1.0 - tgt.paddings, 1) - 1.0), tf.int32))

  # Filter out all isolated '<noise>' tokens.
  noise_pattern = ' <noise> |^<noise> | <noise>$|^<noise>$'
  filtered_refs = tf.strings.regex_replace(transcripts, noise_pattern, ' ')
  filtered_hyps = tf.strings.regex_replace(topk.decoded, noise_pattern, ' ')
  # Compute translation quality scores for all hyps.
  filtered_refs = tf.tile(
      tf.reshape(filtered_refs, [-1, 1]),
      [1, p.decoder.beam_search.num_hyps_per_beam])
  filtered_hyps = tf.reshape(filtered_hyps, [-1])
  filtered_refs = tf.reshape(filtered_refs, [-1])
  norm_wer_errors, norm_wer_words = self._ComputeNormalizedWER(
      filtered_hyps, filtered_refs)

  ret_dict = {
      'target_ids': tgt.ids,
      'target_labels': tgt.labels,
      'target_weights': tgt.weights,
      'target_paddings': tgt.paddings,
      'transcripts': transcripts,
      'topk_decoded': topk.decoded,
      'topk_ids': topk.ids,
      'topk_lens': topk.lens,
      'topk_scores': topk.scores,
      'norm_wer_errors': norm_wer_errors,
      'norm_wer_words': norm_wer_words,
  }

  if not py_utils.use_tpu():
    ret_dict['utt_id'] = input_batch.sample_ids

  ret_dict.update(
      self.AddAdditionalDecoderMetricsToGraph(topk, filtered_hyps,
                                              filtered_refs, input_batch,
                                              decoder_outs))
  return ret_dict

def _BeamSearchDecode(self, input_batch):
  p = self.params
  with tf.name_scope('fprop'), tf.name_scope(p.name):
    encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
    encoder_outputs = self.dec.AddExtraDecodingInfo(encoder_outputs,
                                                    input_batch.tgt)
    decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)
    topk_hyps = decoder_outs.topk_hyps
    topk_ids = decoder_outs.topk_ids
    topk_lens = decoder_outs.topk_lens
    topk_scores = decoder_outs.topk_scores

    slen = tf.cast(
        tf.round(tf.reduce_sum(1 - input_batch.src.paddings, 1) - 1),
        tf.int32)
    srcs = self.input_generator.IdsToStrings(
        input_batch.src.ids, slen, self._GetTokenizerKeyToUse('src'))
    topk_decoded = self.input_generator.IdsToStrings(
        topk_ids, topk_lens - 1, self._GetTokenizerKeyToUse('tgt'))
    topk_decoded = tf.reshape(topk_decoded, tf.shape(topk_hyps))
    topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

    refs = self.input_generator.IdsToStrings(
        input_batch.tgt.labels,
        tf.cast(
            tf.round(
                tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1) - 1.0),
            tf.int32), self._GetTokenizerKeyToUse('tgt'))

    ret_dict = {
        'target_ids': input_batch.tgt.ids,
        'target_labels': input_batch.tgt.labels,
        'target_weights': input_batch.tgt.weights,
        'target_paddings': input_batch.tgt.paddings,
        'sources': srcs,
        'targets': refs,
        'topk_decoded': topk_decoded,
        'topk_lens': topk_lens,
        'topk_scores': topk_scores,
    }
    return ret_dict

def _ProcessLine(self, line):
  """A single-text-line processor.

  Gets a string tensor representing a line of text that has been read from
  the input file, and splits it into graphemes (characters). We use the
  original characters as the target labels, and the lowercased and
  punctuation-removed characters as the source labels.

  Args:
    line: a 1D string tensor.

  Returns:
    A list of tensors, in the expected order by __init__.
  """
  # Tokenize the input into integer ids.
  # tgt_ids has the start-of-sentence token prepended, and tgt_labels has
  # the end-of-sentence token appended.
  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
      tf.convert_to_tensor([line]))

  def Normalize(line):
    # Lowercase and remove punctuation.
    line = line.lower().translate(None, string.punctuation.encode('utf-8'))
    # Convert multiple consecutive spaces to a single one.
    line = b' '.join(line.split())
    return line

  normalized_line = tf.py_func(Normalize, [line], tf.string, stateful=False)
  _, src_labels, src_paddings = self.StringsToIds(
      tf.convert_to_tensor([normalized_line]), is_source=True)
  # The model expects the source without a start-of-sentence token.
  src_ids = src_labels

  # Compute the length for bucketing.
  bucket_key = tf.cast(
      tf.round(
          tf.maximum(tf.reduce_sum(1.0 - src_paddings),
                     tf.reduce_sum(1.0 - tgt_paddings))), tf.int32)
  tgt_weights = 1.0 - tgt_paddings

  # Return tensors in an order consistent with __init__.
  out_tensors = [
      src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights
  ]
  return [tf.squeeze(t, axis=0) for t in out_tensors], bucket_key

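
# A minimal sketch (hypothetical, not in the original file) showing what
# Normalize above does to a raw byte string, independent of TensorFlow.
# py_func hands Normalize the line as bytes, hence the byte-string handling.
def _ExampleNormalize():
  import string
  line = b'Hello,  World!'
  # lower() lowercases; translate(None, delete) drops the punctuation bytes;
  # split()/join() collapses whitespace runs into single spaces.
  line = line.lower().translate(None, string.punctuation.encode('utf-8'))
  line = b' '.join(line.split())
  return line  # -> b'hello world'
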
def _InferenceSubgraph_Default(self):
  """Default inference subgraph."""
  text = tf.placeholder(tf.string, shape=[None])
  # [batch, time]
  ids, labels, paddings = self.input_generator.StringsToIds(text)
  weights = 1. - paddings
  ids, paddings, labels, weights = self._TrimIfPossible(
      ids, paddings, labels, weights)
  lengths = tf.reduce_sum(tf.cast(1 - paddings, tf.int32), axis=1)
  tokens_from_labels = self.input_generator.IdsToStrings(labels, lengths)
  oovs = tf.equal(labels, self.input_generator.tokenizer.unk_id)
  num_oovs_per_sample = tf.cast(
      tf.round(
          tf.reduce_sum(tf.cast(oovs, tf.float32) * (1 - paddings), axis=1)),
      tf.int32)
  batch_size = tf.shape(ids)[0]
  xent_output, _ = self.lm.FPropDefaultTheta(
      inputs=ids,
      paddings=paddings,
      state0=self.lm.zero_state(self.theta.lm, batch_size),
      labels=py_utils.NestedMap(class_ids=labels, class_weights=weights))

  per_example_xent = py_utils.HasShape(xent_output.per_example_xent,
                                       tf.shape(ids))
  log_pplx_per_sample = tf.reduce_sum(
      per_example_xent * (1 - paddings), axis=1)
  fetches = {
      'log_pplx_per_token':  # [batch, time]
          per_example_xent,
      'paddings':  # [batch, time]
          paddings,
      'lengths':  # [batch]
          lengths,
      'log_pplx_per_sample':  # [batch]
          log_pplx_per_sample,
      'num_oovs_per_sample':  # [batch], int32
          num_oovs_per_sample,
      'tokens_from_labels':  # [batch], string
          tokens_from_labels,
      'ids':  # [batch, time], int32
          ids,
  }
  feeds = {
      'text': text,
      'ids': ids,
      'paddings': paddings,
      'labels': labels,
      'weights': weights,
  }
  return fetches, feeds

def SequenceLength(padding):
  """Computes the length of a sequence based on binary padding.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    seq_lens, A tensor of shape [batch] containing the non-padded length of
    each element of `padding` along the batch dimension.
  """
  seq_lens = tf.cast(tf.round(tf.reduce_sum(1 - padding, axis=1)), tf.int32)
  # Get rid of any extra dimensions.
  batch_size = tf.shape(padding)[0]
  seq_lens = tf.reshape(seq_lens, [batch_size], name='seq_lens')
  return seq_lens

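
# A minimal usage sketch (hypothetical): lengths recovered from a 0/1
# padding matrix, one length per batch element.
def _ExampleSequenceLength():
  padding = tf.constant([[0., 0., 1.],
                         [0., 1., 1.]])
  return SequenceLength(padding)  # -> [2, 1]
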
def _update_mask(self, weights, threshold):
  """Updates the mask for a given weight tensor.

  This function first computes the CDF of the weight tensor, and estimates
  the threshold value such that 'desired_sparsity' fraction of weights have
  magnitude less than the threshold.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: The current threshold value. The function will compute a new
      threshold and return the exponential moving average using the current
      value of threshold.

  Returns:
    new_threshold: The new value of the threshold based on weights, and
      sparsity at the current global_step.
    new_mask: A numpy array of the same size and shape as weights containing
      0 or 1 to indicate which of the values in weights falls below the
      threshold.

  Raises:
    ValueError: if sparsity is not defined.
  """
  if self._sparsity is None:
    raise ValueError('Sparsity variable undefined')

  sparsity = self._get_sparsity(weights.op.name)
  with tf.name_scope(weights.op.name + '_pruning_ops'):
    abs_weights = tf.abs(weights)
    k = tf.cast(
        tf.round(tf.cast(tf.size(abs_weights), tf.float32) * (1 - sparsity)),
        tf.int32)
    # Sort the entire array.
    values, _ = tf.nn.top_k(
        tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
    # Grab the (k-1)th value.
    current_threshold = tf.gather(values, k - 1)
    smoothed_threshold = tf.add_n([
        tf.multiply(current_threshold, 1 - self._spec.threshold_decay),
        tf.multiply(threshold, self._spec.threshold_decay)
    ])

    new_mask = tf.cast(
        tf.greater_equal(abs_weights, smoothed_threshold), tf.float32)
  return smoothed_threshold, new_mask

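
# A worked sketch (hypothetical values; threshold_decay assumed 0, so no
# moving-average smoothing) of the threshold computation above: with 50%
# target sparsity over four weights, k = round(4 * 0.5) = 2, so the
# threshold is the 2nd-largest magnitude and the two largest survive.
def _ExampleMagnitudeThreshold():
  abs_weights = tf.abs(tf.constant([0.1, -0.5, 0.3, -0.2]))
  sparsity = 0.5
  k = tf.cast(
      tf.round(tf.cast(tf.size(abs_weights), tf.float32) * (1 - sparsity)),
      tf.int32)
  values, _ = tf.nn.top_k(
      tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights))
  threshold = tf.gather(values, k - 1)                  # -> 0.3
  mask = tf.cast(tf.greater_equal(abs_weights, threshold), tf.float32)
  return threshold, mask                                # mask -> [0., 1., 1., 0.]
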
def _GetSequenceLength(self, example):
  """Returns sequence length for the example NestedMap from the dataset.

  This function is used by the TFDatasetBatchBySequenceLength DataSource to
  obtain the key used for bucketing. Bucketing separates examples into
  groups before batching, such that each batch contains only examples
  within a certain length.

  Args:
    example: A NestedMap containing an input example. Tensors in the
      example do not have a leading batch dimension.

  Returns:
    An integer sequence length for the example.
  """
  return tf.cast(
      tf.round(
          tf.maximum(tf.reduce_sum(1.0 - example.src.paddings),
                     tf.reduce_sum(1.0 - example.tgt.paddings))), tf.int32)

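
# A minimal sketch (hypothetical paddings, NestedMap elided): the bucket
# key is the longer of the source and target lengths, so a source of
# length 2 and a target of length 1 bucket under key 2.
def _ExampleBucketKey():
  src_paddings = tf.constant([0., 0., 1.])
  tgt_paddings = tf.constant([0., 1., 1.])
  return tf.cast(
      tf.round(
          tf.maximum(tf.reduce_sum(1.0 - src_paddings),
                     tf.reduce_sum(1.0 - tgt_paddings))), tf.int32)  # -> 2
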
def _InferenceSubgraph_Default(self):
  """Default inference subgraph.

  Returns:
    (fetches, feeds):

    - fetches: A dictionary of fetches, containing:

      - log_pplx_per_token: A matrix of shape [batch, time]. [i, j] is i-th
        input text's j-th token's log prob.
      - paddings: A matrix of shape [batch, time]. The padding mask.
      - log_pplx_per_sample: A vector of shape [batch]. [i] is i-th input
        text's log prob.
      - num_oovs_per_sample: A vector of shape [batch] counting the total
        number of out-of-vocabulary tokens in each input.
      - tokens_from_labels: A vector of shape [batch] returning the
        predicted tokens as a sequence after mapping them back to strings
        from ids using the vocabulary.
      - ids: A matrix of shape [batch, time]. [i, j] is i-th input text's
        j-th token's id.

    - feeds: A dictionary of feeds, containing:

      - text: A placeholder for a vector of strings.
  """
  text = tf.placeholder(tf.string, shape=[None])
  # [batch, time]
  ids, labels, paddings = self.input_generator.StringsToIds(text)
  lengths = tf.reduce_sum(tf.cast(1 - paddings, tf.int32), axis=1)
  tokens_from_labels = self.input_generator.IdsToStrings(labels, lengths)
  oovs = tf.equal(labels, self.input_generator.tokenizer.unk_id)
  num_oovs_per_sample = tf.cast(
      tf.round(
          tf.reduce_sum(tf.cast(oovs, tf.float32) * (1 - paddings), axis=1)),
      tf.int32)
  # [time, batch]
  ids, paddings, labels, weights = self._TrimIfPossibleThenTranspose(
      ids, paddings, labels, 1.0 - paddings)
  batch_size = tf.shape(ids)[1]
  xent_output, _ = self.lm.FPropDefaultTheta(
      inputs=ids,
      paddings=paddings,
      state0=self.lm.zero_state(self.theta.lm, batch_size),
      labels=py_utils.NestedMap(class_ids=labels, class_weights=weights))

  per_example_xent = py_utils.HasShape(xent_output.per_example_xent,
                                       tf.shape(ids))
  log_pplx_per_sample = tf.reduce_sum(
      per_example_xent * (1 - paddings), axis=0)
  fetches = {
      'log_pplx_per_token':  # [batch, time]
          tf.transpose(per_example_xent),
      'paddings':  # [batch, time]
          tf.transpose(paddings),
      'lengths':  # [batch]
          lengths,
      'log_pplx_per_sample':  # [batch]
          log_pplx_per_sample,
      'num_oovs_per_sample':  # [batch], int32
          num_oovs_per_sample,
      'tokens_from_labels':  # [batch], string
          tokens_from_labels,
      'ids':  # [batch, time], int32
          ids,
  }
  feeds = {'text': text}
  return fetches, feeds

def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please
      refer to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps, tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids,
                   unused_core_bs_states, unused_other_states_list):
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # TODO(rpang): avoid inspecting 'encoder_outputs'.
  source_paddings = encoder_outputs.padding
  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)), tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)

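
# A stripped-down sketch (hypothetical, not from the original code) of the
# tf.while_loop pattern used in BeamSearchDecode above: carrying an
# explicit `all_done` flag lets the loop exit early, before the `max_steps`
# budget is exhausted.
def _ExampleEarlyExitLoop(max_steps=10):
  cur_step = tf.constant(0)
  all_done = tf.constant(False)
  total = tf.constant(0.0)

  def LoopContinue(cur_step, all_done, unused_total):
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, total):
    total += 1.0
    return cur_step + 1, total >= 3.0, total  # done after three steps

  _, _, total = tf.while_loop(
      LoopContinue, LoopBody, loop_vars=(cur_step, all_done, total))
  return total  # -> 3.0, even though max_steps is 10
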
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  """Takes a tensor of strings and returns id/padding tensors.

  This generates `token_ids`, `target_ids`, and `paddings` in the format
  that is expected for tokenizers. This performs padding to a fixed length
  and appends the end-of-sentence token as appropriate.

  Args:
    strs: a string Tensor.
    max_length: a python integer. The second dimension of the returned
      arrays. All sequences are padded or truncated to that length.
    append_eos: a python bool. See `BaseTokenizer` for explanation.
    languages: A vector of strings with the same length as `strs`.

  Returns:
    A tuple of 3 tensors:

    - token_ids: a tensor of sequences of WPM ids starting with SOS.
      Sequences always end with EOS unless the sequence exceeds the maximum
      length. Always padded with EOS.
    - target_ids: a tensor of sequences of WPM ids not starting with SOS
      but ending with EOS. Always padded with EOS.
    - paddings: a tensor of floats indicating, at each position, whether
      the corresponding position is padded.
  """
  p = self.params
  if append_eos is None:
    append_eos = p.append_eos

  batch_size = py_utils.GetShape(strs)[0]
  token_ids_ta = tf.TensorArray(tf.int32, batch_size)
  target_ids_ta = tf.TensorArray(tf.int32, batch_size)
  paddings_ta = tf.TensorArray(tf.float32, batch_size)

  def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta,
                           paddings_ta):
    """Tokenizes a single sentence."""
    ids, _ = self._wpm_encoder.Encode(strs[i])

    if append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)

    # This truncates after the eos is added, so some sentences might
    # not have </s> at the end.
    token_ids_ta = token_ids_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.concat([[self.sos_id], ids], axis=0), [max_length],
            self.eos_id))
    target_ids_ta = target_ids_ta.write(
        i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
    paddings_ta = paddings_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

    return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta,
                 target_ids_ta, paddings_ta),
      parallel_iterations=30,
      back_prop=False)

  token_ids = token_ids_ta.stack()
  target_ids = target_ids_ta.stack()
  paddings = paddings_ta.stack()

  if not p.pad_to_max_length:
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids = token_ids[:, :maxlen]
    target_ids = target_ids[:, :maxlen]
    paddings = paddings[:, :maxlen]

  return token_ids, target_ids, paddings

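
# For reference, a hedged sketch of the PadOrTrimTo behavior relied on
# above (assuming py_utils.PadOrTrimTo(x, shape, pad_val) pads with
# `pad_val` up to `shape`, or truncates down to it):
#   py_utils.PadOrTrimTo(tf.constant([3, 4, 2]), [5], eos_id)
#     -> [3, 4, 2, eos_id, eos_id]
#   py_utils.PadOrTrimTo(tf.constant([3, 4, 2]), [2], eos_id)
#     -> [3, 4]   # truncated; no </s> at the end, as the comment warns
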
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in an `x`, which represents the groundtruth sequence (i.e.,
  English sequence). We return a sampled rollin (observed) canvas (i.e., a
  random subset of the English sequence), as well as the target (indices)
  for an insertion-based model (i.e., the targets given the random observed
  subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]`
      where 0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.

    - canvas: The canvas (based off of the `rollin_policy`) of shape
      [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need
      not be equal.
    - canvas_indices: The canvas indices (into `x`).
    - canvas_paddings: The paddings of `canvas_indices`.
    - target_indices: The target indices of shape [num_targets, 3].
      `num_targets` is the number of total targets in the entire batch.
      [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
      captures the token. Each row [batch, slot, vocab] represents the
      indices of the target -- i.e., the batch, slot and vocab combination
      of the target. Typical usage of these indices is to tf.gather_nd the
      log-probs (from the softmax layer).
    - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
        x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We
    # can accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample
  # in the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(
      c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid
  # tokens in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)

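
# A standalone sketch (hypothetical, not part of the original code) of the
# Gumbel-max trick used in FProp above: adding Gumbel noise to masked
# logits and taking top-k draws a uniformly random subset of the *valid*
# positions, without replacement.
def _ExampleGumbelSubset():
  x_len = tf.constant([3])                  # 3 valid positions out of 4
  time_dim = 4
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9                    # invalid positions get -1e9
  z = -tf.math.log(-tf.math.log(tf.random.uniform([1, time_dim])))
  _, indices = tf.nn.top_k(z_logits + z, k=2)
  return indices                            # 2 distinct indices in {0, 1, 2}
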
def _StringsToIdsImpl(self, strs, max_length, append_eos, languages):
  del languages
  p = self.params
  if append_eos is None:
    append_eos = p.append_eos

  batch_size = py_utils.GetShape(strs)[0]
  token_ids_ta = tf.TensorArray(tf.int32, batch_size)
  target_ids_ta = tf.TensorArray(tf.int32, batch_size)
  paddings_ta = tf.TensorArray(tf.float32, batch_size)

  def _TokenizeOneSentence(i, text, token_ids_ta, target_ids_ta,
                           paddings_ta):
    """Tokenizes a single sentence."""
    if tf.is_tensor(i):
      text_i = tf.gather(text, i)
    else:
      text_i = text[i]
    ids = self._tokenizer.tokenize(text_i).merge_dims(0, -1)
    ids.set_shape([None])

    if append_eos:
      ids = tf.concat([ids, [self.eos_id]], axis=0)
    sos_ids = tf.concat([[self.sos_id], ids], axis=0)
    if p.prepend_sos:
      ids = sos_ids

    # This truncates after the EOS is added, so some sentences might
    # not have EOS at the end.
    token_ids_ta = token_ids_ta.write(
        i, py_utils.PadOrTrimTo(sos_ids, [max_length], 0))
    target_ids_ta = target_ids_ta.write(
        i, py_utils.PadOrTrimTo(ids, [max_length], 0))
    paddings_ta = paddings_ta.write(
        i,
        py_utils.PadOrTrimTo(
            tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

    return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta

  _, _, token_ids_ta, target_ids_ta, paddings_ta = tf.while_loop(
      lambda i, *_: i < batch_size,
      _TokenizeOneSentence,
      loop_vars=(tf.constant(0, tf.int32), strs, token_ids_ta,
                 target_ids_ta, paddings_ta),
      parallel_iterations=30,
      back_prop=False)

  token_ids = token_ids_ta.stack()
  target_ids = target_ids_ta.stack()
  paddings = paddings_ta.stack()

  if not p.pad_to_max_length:
    maxlen = tf.cast(
        tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
        tf.int32)
    token_ids = token_ids[:, :maxlen]
    target_ids = target_ids[:, :maxlen]
    paddings = paddings[:, :maxlen]

  return token_ids, target_ids, paddings

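
# A minimal sketch (hypothetical values) of the trimming idiom that ends
# both _StringsToIdsImpl variants: columns that are padding in every row
# are sliced away, so the batch is only as wide as its longest sequence.
def _ExampleTrimToMaxLen():
  paddings = tf.constant([[0., 0., 1., 1.],
                          [0., 1., 1., 1.]])
  token_ids = tf.constant([[7, 8, 0, 0],
                           [9, 0, 0, 0]])
  maxlen = tf.cast(
      tf.round(tf.reduce_max(tf.reduce_sum(1.0 - paddings, axis=1))),
      tf.int32)                             # -> 2
  return token_ids[:, :maxlen]              # -> [[7, 8], [9, 0]]
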