def _GetBetaGamma(self, theta, inputs, **kwargs): p = self.params assert 'class_emb' in kwargs class_emb = kwargs['class_emb'] # class_emb is a one-hot vector of shape [batch, class_emb_dim=num_classes]. class_ids = tf.math.argmax(class_emb, axis=-1, output_type=tf.int32) # [batch, dim] # Not using matmul/einsum to avoid potential precision problem on TPU with # sparse inputs. beta = tf.gather(theta.beta, class_ids) gamma = tf.gather(theta.gamma, class_ids) if not p.gamma_zero_init and not p.gamma_one_init: # Note, The real gamma to use is 1 + gamma. gamma = 1.0 + gamma # Extend to [batch, 1, ... 1, dim] batch = py_utils.GetShape(inputs)[0] to_shape = tf.concat( [[batch], tf.ones([py_utils.GetRank(inputs) - 2], tf.int32), [self.params.dim]], axis=0) beta = tf.reshape(beta, to_shape) gamma = tf.reshape(gamma, to_shape) return beta, gamma
def _InputBatch(self):
  """Builds one input batch by sampling from data restored from a checkpoint.

  Returns:
    A `.NestedMap` with fields `raw`, `data` (the preprocessed `raw`),
    `label` and `weight`. When not running on TPU, `sample_ids` is included
    as well.
  """
  p = self.params

  @tf.function
  def ReadData():
    # Restore both tensors, then always convert to float32.
    features, targets = io_ops.restore_v2(
        p.ckpt, [p.data, p.label], ['', ''], [p.data_dtype, p.label_dtype])
    return tf.cast(features, tf.float32), tf.cast(targets, tf.float32)

  # Loads data and label into memory and keeps them around.
  all_data, all_labels = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  batch_size = self.InfeedBatchSize()
  example_shape = list(p.data_shape)
  all_data = tf.reshape(all_data, [-1] + example_shape)
  all_labels = tf.reshape(all_labels, [-1])
  # There must be exactly one label per example.
  all_labels = py_utils.HasShape(all_labels, [tf.shape(all_data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  num_sampled = tf.shape(sample_ids)[0]

  # Every field is padded or trimmed to exactly `batch_size` rows; `weight` is
  # built from `num_sampled` ones, i.e. one per row that was really sampled.
  raw = py_utils.PadOrTrimTo(
      tf.gather(all_data, sample_ids), [batch_size] + example_shape)
  processed = self._Preprocess(raw)
  labels = py_utils.PadOrTrimTo(
      tf.gather(all_labels, sample_ids), [batch_size])
  weights = py_utils.PadOrTrimTo(
      tf.ones([num_sampled], dtype=tf.float32), [batch_size])

  batch = py_utils.NestedMap(
      raw=raw, data=processed, label=labels, weight=weights)
  if not py_utils.use_tpu():
    batch['sample_ids'] = sample_ids
  return batch
def testForwardPassWithDoubleBatch(self):
  """Checks encoder FProp when given a second, interpolation batch.

  The second batch is the first batch with its two rows swapped. The encoder
  output summed over axis 0 is compared against golden values.
  """
  with self.session(use_gpu=False) as sess:
    p = self._EncoderParams()
    bs = 2
    seq_len = 16
    tf.random.set_seed(8372749040)
    mt_enc = p.Instantiate()
    batch = py_utils.NestedMap()
    batch.ids = tf.constant(
        np.random.randint(low=0, high=63, size=[bs, seq_len], dtype=np.int32))
    # NOTE(review): `paddings` is built here but never read afterwards;
    # `batch.paddings` below is all zeros. The loop still draws from
    # np.random, so removing it would change the later `lambdas` draw.
    paddings = []
    for _ in range(bs):
      zeros_len = np.random.randint(1, seq_len + 1)
      paddings.append([
          0.,
      ] * zeros_len + [1.] * (seq_len - zeros_len))
    batch.paddings = tf.zeros([bs, seq_len])
    # The interpolation batch is the same data with the two rows swapped.
    other_batch = py_utils.NestedMap()
    other_batch.ids = tf.gather(batch.ids, [1, 0])
    other_batch.paddings = tf.gather(batch.paddings, [1, 0])
    # Per-position weights; the two entries passed to FProp sum to 1.
    lambdas = np.random.random((bs, seq_len))
    lambdas = tf.constant(lambdas, tf.float32)
    out = mt_enc.FProp(mt_enc.theta, batch, interpolation_batch=other_batch,
                       lambdas=[lambdas, 1 - lambdas])
    enc_out_sum = tf.reduce_sum(out.encoded, 0)

    tf.global_variables_initializer().run()
    actual_enc_out, actual_enc_out_sum = sess.run(
        [out.encoded, enc_out_sum])
    # Golden values for the encoder output summed over axis 0.
    expected_enc_out_sum = [[
        -38.089085, -22.181915, 3.3765068, -45.2483, -58.186905,
        -3.4464571, 24.461462, 12.014615, 33.08178, 34.02244, 23.391253,
        -15.515911, 0.72847706, 50.45283, -26.36325, 21.799355
    ], [
        -37.716507, -12.993027, 7.148979, -39.70747, -57.864025,
        2.2049172, 29.571432, 18.955816, 30.406136, 33.270325, 21.685469,
        -17.21592, 1.3697424, 49.33187, -30.023928, 22.915518
    ]]
    # pyformat: disable
    self.assertAllEqual([seq_len, bs, p.model_dim], actual_enc_out.shape)
    self.assertAllClose(expected_enc_out_sum, actual_enc_out_sum,
                        rtol=1e-05, atol=1e-05)
def _GetLangIds(self, source_id): """Look up the correct lang_id from the source_id tensor.""" task_id = self._GetTaskIds(source_id) src_lang_id = task_id tgt_lang_id = task_id if self.params.task_to_src_lang_map: src_langs = tf.constant(self.params.task_to_src_lang_map, dtype=tf.int32) src_lang_id = tf.gather(src_langs, task_id) if self.params.task_to_tgt_lang_map: tgt_langs = tf.constant(self.params.task_to_tgt_lang_map, dtype=tf.int32) tgt_lang_id = tf.gather(tgt_langs, task_id) return src_lang_id, tgt_lang_id
def ReOrderHyps(x_in):
  """Reorders `x_in` according to the previous hyp ids.

  Anything that is not a `tf.Tensor` of known, non-zero rank is returned
  untouched. Tensors of rank > 2 are gathered along axis 1 unless
  `p.batch_major_state` is set; all other tensors are gathered along the
  leading axis. The static shape of the input is kept on the result.
  """
  if not isinstance(x_in, tf.Tensor):
    return x_in
  rank = x_in.shape.ndims
  if not (rank and rank > 0):
    # Unknown rank or scalar: nothing to reorder.
    return x_in

  if rank > 2 and not p.batch_major_state:
    reordered = tf.gather(x_in, old_hyp_ids, axis=1)
  else:
    reordered = tf.gather(x_in, old_hyp_ids)
  reordered.set_shape(x_in.get_shape())
  return reordered
def ApplyBias():
  """Biases the predicted probabilities towards the labels.

  Returns:
    A (log_probs, consistent) tuple: the log of the interpolation between the
    predicted probabilities and the smoothed one-hot label probabilities, and
    the updated consistency flag per hypothesis.
  """

  def TileForBeamAndFlatten(tensor):
    # [src_batch] -> [1, src_batch] -> [num_hyps_per_beam, src_batch]
    # -> [tgt_batch], where tgt_batch = num_hyps_per_beam * src_batch.
    row = tf.reshape(tensor, [1, -1])
    tiled = tf.tile(row, [num_hyps_per_beam, 1])
    tgt_batch = tf.shape(step_ids)[0]
    return tf.reshape(tiled, [tgt_batch])

  # A hypothesis stays consistent if step_ids equal the labels of the previous
  # step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # The label gathered for step 0 is not meaningful (index clamped to 0), but
  # `is_step0` overrides the comparison in that case.
  prev_step_labels = tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1)
  prev_label = TileForBeamAndFlatten(prev_step_labels)
  is_step0 = tf.equal(time_step, 0)
  matches_prev_label = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  local_consistence = tf.logical_or(is_step0, matches_prev_label)
  consistent = tf.logical_and(states.consistent, local_consistence)

  # Label and weight slices for the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight *= tf.cast(consistent, p.dtype)

  # Turn the labels into (slightly smoothed) one-hot probabilities.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # Avoid exact 0 probs, which may cause issues with log.
  uncertainty = tf.constant(1e-10, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs, checking weight is in [0, 1].
  weight = tf.expand_dims(weight, 1)
  weight_checks = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ]
  mixture = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(weight_checks, mixture)
  return tf.log(probs), consistent
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
  """Computes mean and variance over the valid data points in inputs.

  Args:
    inputs: The tensor to compute moments of.
    padding: A tensor of the same rank as `inputs`; the mask applied to the
      data is `1.0 - padding` and is asserted to be non-negative.
    reduce_over_dims: The dimensions to reduce over.
    cumulative_axis: If not None, sums and counts are accumulated with a
      cumulative sum along this axis after the reduction.
    enable_cross_replica_sum_on_tpu: If True and running on TPU, sums and
      counts are additionally summed across replicas.
    keepdims: Passed through to the reductions.

  Returns:
    A (mean, variance) tuple.
  """
  mask = 1.0 - padding
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)

  def _Reduce(x):
    return tf.reduce_sum(x, reduce_over_dims, keepdims=keepdims)

  def _UseCrossReplicaSum():
    return py_utils.use_tpu() and enable_cross_replica_sum_on_tpu

  accumulate = cumulative_axis is not None

  sum_v = _Reduce(inputs * tf.cast(mask, inputs.dtype))
  count_v = _Reduce(mask)
  if accumulate:
    sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
    count_v = tf.math.cumsum(count_v, axis=cumulative_axis)

  # The mask may be smaller than inputs on the reduced dims (it was broadcast
  # in the product above, which guarantees the input size is a multiple of the
  # mask size). Scale the count by that ratio.
  input_size_on_reduced_dims = tf.reduce_prod(
      tf.gather(tf.shape(inputs), reduce_over_dims))
  mask_size_on_reduced_dims = tf.reduce_prod(
      tf.gather(tf.shape(mask), reduce_over_dims))
  mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                    mask_size_on_reduced_dims)
  count_v = count_v * tf.cast(mask_multiplier, count_v.dtype)
  if _UseCrossReplicaSum():
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Guard against division by zero when there are no valid points.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v

  centered = inputs - mean
  sum_vv = _Reduce(centered * centered * mask)
  if accumulate:
    sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)
  if _UseCrossReplicaSum():
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
def _TokenizeOneSentence(i, text, token_ids_ta, target_ids_ta, paddings_ta):
  """Tokenizes sentence `i` of `text` and writes it into the TensorArrays.

  The signature and return value mirror each other so the function can be used
  as a loop body: every argument is returned (with `i` advanced by one).

  Args:
    i: Index of the sentence to tokenize; a Python int or a tensor.
    text: The collection of sentences, indexed by `i`.
    token_ids_ta: TensorArray receiving the SOS-prefixed ids at position `i`.
    target_ids_ta: TensorArray receiving the target ids at position `i`.
    paddings_ta: TensorArray receiving the paddings at position `i`.

  Returns:
    The tuple (i + 1, text, token_ids_ta, target_ids_ta, paddings_ta).
  """
  if tf.is_tensor(i):
    text_i = tf.gather(text, i)
  else:
    text_i = text[i]
  ids = self._tokenizer.tokenize(text_i).merge_dims(0, -1)
  ids.set_shape([None])

  if append_eos:
    ids = tf.concat([ids, [self.eos_id]], axis=0)
  sos_ids = tf.concat([[self.sos_id], ids], axis=0)
  if p.prepend_sos:
    ids = sos_ids

  # This truncates after the EOS is added, so some sentences might
  # not have EOS at the end.
  token_ids_ta = token_ids_ta.write(
      i, py_utils.PadOrTrimTo(sos_ids, [max_length], 0))
  target_ids_ta = target_ids_ta.write(
      i, py_utils.PadOrTrimTo(ids, [max_length], 0))
  # Real positions get padding 0; positions added by PadOrTrimTo get 1.
  paddings_ta = paddings_ta.write(
      i,
      py_utils.PadOrTrimTo(
          tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))

  # Return the `text` argument itself. The previous code returned the free
  # name `strs`, which is not a parameter of this function and only resolved
  # if an enclosing scope happened to define it.
  return i + 1, text, token_ids_ta, target_ids_ta, paddings_ta
def CreateDenseCoordinates(self, ranges):
  """Creates a matrix of coordinate locations corresponding to a dense grid.

  Example: To create (x, y) coordinates over a 10x10 grid with step size 1,
  call ``CreateDenseCoordinates([(1, 10, 10), (1, 10, 10)])``.

  Args:
    ranges: A list of 3-tuples, each expected to contain
      (min, max, num_steps). Each list element corresponds to one dimension;
      its values are generated with `tf.lin_space(min, max, num_steps)`.

  Returns:
    tf.float32 tensor of shape [total_points, len(ranges)], where
    total_points = product of all num_steps. The first dimension in `ranges`
    varies slowest across rows and the last one varies fastest.
  """
  num_points = int(np.prod([steps for _, _, steps in ranges]))

  # `period` is the number of consecutive rows over which the current
  # dimension keeps the same value; it shrinks by a factor of `steps` with
  # every dimension processed.
  period = num_points
  columns = []
  for lo, hi, steps in ranges:
    axis_values = tf.lin_space(
        tf.cast(lo, tf.float32), tf.cast(hi, tf.float32),
        tf.cast(steps, tf.int32))
    period = period // steps
    value_index_per_row = (tf.range(num_points) // period) % steps
    columns.append(tf.gather(axis_values, value_index_per_row))

  return tf.stack(columns, axis=1)
def _GetTaskIds(self, source_id): """Look up the correct task_id from the source_id tensor.""" if self.params.file_pattern_task_ids: file_task_ids = tf.constant(self.params.file_pattern_task_ids, dtype=tf.int32) return tf.gather(file_task_ids, source_id) return source_id
def ReOrderHyps(x_in):
  """Reorders `x_in` according to the previous hyp ids.

  Anything that is not a `tf.Tensor` of known, non-zero rank is returned
  untouched. The static shape of the input is kept on the result.
  """
  if not isinstance(x_in, tf.Tensor):
    return x_in
  rank = x_in.shape.ndims
  if not (rank and rank > 0):
    # Unknown rank or scalar: nothing to reorder.
    return x_in

  if rank > 2 and not p.batch_major_state:
    # Use corrected indices only here for batch major compute, as key/value
    # caches are the states being affected.
    if p.batch_major_compute:
      hyp_ids = old_hyp_ids_in_cache_order
    else:
      hyp_ids = old_hyp_ids
    reordered = tf.gather(x_in, hyp_ids, axis=1)
  else:
    reordered = tf.gather(x_in, old_hyp_ids)
  reordered.set_shape(x_in.get_shape())
  return reordered
def _GetTaskIds(self, source_id): """Look up the correct task_id from the source_id tensor.""" if self.params.file_pattern_task_ids: file_task_ids = tf.constant( self.params.file_pattern_task_ids, dtype=tf.int32) source_id = tf.gather(file_task_ids, source_id) src_task_id = source_id tgt_task_id = source_id if self.params.task_to_src_lang_map: src_lang_ids = tf.constant( self.params.task_to_src_lang_map, dtype=tf.int32) src_task_id = tf.gather(src_lang_ids, src_task_id) if self.params.task_to_tgt_lang_map: tgt_lang_ids = tf.constant( self.params.task_to_tgt_lang_map, dtype=tf.int32) tgt_task_id = tf.gather(tgt_lang_ids, tgt_task_id) return src_task_id, tgt_task_id
def ApplyBias():
  """Biases the predicted probabilities towards the labels.

  Returns:
    A (log_probs, consistent) tuple: the log of the interpolation between the
    predicted probabilities and the one-hot label probabilities, and the
    updated consistency flag per hypothesis.
  """
  # A hypothesis stays consistent if step_ids equal the labels of the previous
  # step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # The label gathered for step 0 is not meaningful (index clamped to 0), but
  # `is_step0` overrides the comparison in that case.
  prev_step_labels = tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1)
  prev_label = TileForBeamAndFlatten(prev_step_labels)
  is_step0 = tf.equal(time_step, 0)
  matches_prev_label = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  local_consistence = tf.math.logical_or(is_step0, matches_prev_label)
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Label and weight slices for the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight *= tf.cast(consistent, py_utils.FPropDtype(p))

  # Turn the labels into one-hot probabilities, [tgt_batch, vocab_size].
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(label, vocab_size, dtype=py_utils.FPropDtype(p))
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs, checking weight is in [0, 1].
  weight = tf.expand_dims(weight, 1)
  weight_checks = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ]
  mixture = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(weight_checks, mixture)

  # Ensure that tf.math.log is applied to positive values.
  floor = tf.constant(1e-12, dtype=probs.dtype)
  probs = tf.maximum(probs, floor)
  return tf.math.log(probs), consistent
def _Padding():
  """Pads every tensor in `tensor_list` with randomly re-sampled rows.

  `num_points_out - actual_num` row indices are drawn uniformly from
  [0, actual_num). The same indices are used for every tensor, and the
  gathered rows are appended along axis 0.

  Returns:
    A list with one padded tensor per entry of `tensor_list`.
  """
  num_extra = num_points_out - actual_num
  resampled_rows = tf.random.uniform([num_extra],
                                     minval=0,
                                     maxval=actual_num,
                                     dtype=tf.int32,
                                     seed=seed)
  return [
      tf.concat([t, tf.gather(t, resampled_rows, axis=0)], axis=0)
      for t in tensor_list
  ]
def _ApplyMass(source_id):
  """Returns whether the task of `source_id` is one of `params.mass_task_ids`.

  The task id is `source_id` looked up in `params.file_pattern_task_ids` when
  that param is set, and `source_id` itself otherwise.
  """
  pattern_task_ids = self.params.file_pattern_task_ids
  if pattern_task_ids:
    task_table = tf.constant(pattern_task_ids, dtype=tf.int32)
    task_id = tf.gather(task_table, source_id)
  else:
    task_id = source_id
  mass_ids = tf.constant(self.params.mass_task_ids, dtype=tf.int32)
  # True if the task id equals any of the configured ids.
  is_mass_task = tf.equal(task_id, mass_ids)
  return tf.reduce_any(is_mass_task)
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step.

  Args:
    recurrent_theta: Carries `encoder_outputs` and `random_seed`.
    state0: The state before this step (`ids`, `bs_state`, `timestep`).
    inputs: Unused when `p.use_recurrent` is set; otherwise holds the `ids`
      and `logits` TensorArrays that this step writes into.

  Returns:
    (state1, empty NestedMap) when `p.use_recurrent` is set; otherwise
    (recurrent_theta, state1, inputs).
  """
  if p.use_recurrent:
    del inputs

  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    prev_ids = tf.expand_dims(state0.ids, 1)  # [batch, 1].
    bs_result, bs_state1 = pre_step_callback(
        decoder_theta,
        recurrent_theta.encoder_outputs,
        prev_ids,
        state0.bs_state,
        num_hyps_per_beam=p.num_hyps_per_beam)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs

    use_top_k = p.top_k > 0
    if use_top_k:
      # Restrict sampling to the k best candidates.
      topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k)
      if p.top_k_renormalize:
        sample_logits = tf.nn.log_softmax(topk_logits)
      else:
        sample_logits = topk_logits
    else:
      sample_logits = state1.logits

    # Sample ids from logits. [batch].
    scaled_logits = sample_logits / p.temperature
    step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
    sampled = tf.random.stateless_categorical(
        scaled_logits,
        num_samples=1,
        seed=step_seed,
        dtype=state0.ids.dtype,
        name='sample_next_id')
    ids = tf.reshape(sampled, [batch])
    if use_top_k:
      # `ids` index into the top-k candidates; map them back to vocab ids.
      state1.ids = tf.gather(topk_ids, ids, axis=1, batch_dims=1)
    else:
      state1.ids = ids

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      # On the last chunk, a sampled end-of-chunk id is replaced by EOS.
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(decoder_theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)

  if p.use_recurrent:
    return state1, py_utils.NestedMap()

  inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
  inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
  return (recurrent_theta, state1, inputs)
def ReOrderHyps(key, x_in):
  """Reorders `x_in` according to the previous hyp ids.

  Args:
    key: The name of the state entry `x_in` belongs to.
    x_in: The state value.

  Returns:
    The reordered tensor, or `x_in` itself when it is not a `tf.Tensor` of
    known, non-zero rank. Random-seed entries are only shape-checked.
  """
  if random_seed_regex.match(key):
    # For keys like rnn_states[0].r, it is a shape [2] random seeds tensor
    # used for deterministic behavior and should not be reordered.
    return py_utils.HasShape(x_in, [2])

  if p.batch_major_compute:
    correct_old_hyp_ids = old_hyp_ids_in_cache_order
  else:
    correct_old_hyp_ids = old_hyp_ids

  if not (isinstance(x_in, tf.Tensor) and x_in.shape.ndims):
    return x_in

  if x_in.shape.ndims > 2 and not p.batch_major_state:
    # Use corrected indices only here for batch major compute as key/value
    # caches are the states being affected.
    reordered = tf.gather(x_in, correct_old_hyp_ids, axis=1)
  elif key in POSSIBLY_TIME_MAJOR_STATE_KEYS:
    reordered = tf.gather(x_in, old_hyp_ids, axis=-1)
  else:
    reordered = tf.gather(x_in, correct_old_hyp_ids)
  reordered.set_shape(x_in.get_shape())
  return reordered
def _ReshapeGather(tensor):
  """Reshapes `tensor` per bbox and gathers the rows picked by the nms indices.

  The tensor is viewed as [batch_size, num_bboxes, -1] and gathered with
  `per_cls_idxs` using one batch dimension.
  """
  per_bbox = tf.reshape(tensor, [batch_size, num_bboxes, -1])
  gathered = tf.gather(per_bbox, per_cls_idxs, batch_dims=1)
  if p.use_oriented_per_class_nms:
    return gathered

  # When *not* using oriented NMS the indices have no num_classes dimension,
  # so the gathered data lacks it too. Insert the dimension and tile so the
  # data fits the expected per class shape [batch_size, num_classes, ...].
  with_class_dim = gathered[:, tf.newaxis, :, :]
  return tf.tile(with_class_dim, [1, p.num_classes, 1, 1])
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients.

  Args:
    vs_gs: A structure whose `FlattenItems()` yields (name, (var, grad)).
  """

  def _Summarize(name, var, grad):
    if isinstance(grad, tf.IndexedSlices):
      # Sparse gradient: only look at the slices of the variable it touches.
      var, grad = tf.gather(var, grad.indices), grad.values
    if var.dtype.is_complex:
      # Histogram the magnitudes of complex values.
      var, grad = tf.abs(var), tf.abs(grad)
    histogram('var_hist/' + name, var)
    histogram('grad_hist/' + name, grad)

  for name, (var, grad) in vs_gs.FlattenItems():
    # Place the summary ops on the variable's device.
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      _Summarize(name, var, grad)
def _update_mask(self, weights, threshold): """Updates the mask for a given weight tensor. This functions first computes the cdf of the weight tensor, and estimates the threshold value such that 'desired_sparsity' fraction of weights have magnitude less than the threshold. Args: weights: The weight tensor that needs to be masked. threshold: The current threshold value. The function will compute a new threshold and return the exponential moving average using the current value of threshold Returns: new_threshold: The new value of the threshold based on weights, and sparsity at the current global_step new_mask: A numpy array of the same size and shape as weights containing 0 or 1 to indicate which of the values in weights falls below the threshold Raises: ValueError: if sparsity is not defined """ if self._sparsity is None: raise ValueError('Sparsity variable undefined') sparsity = self._get_sparsity(weights.op.name) with tf.name_scope(weights.op.name + '_pruning_ops'): abs_weights = tf.abs(weights) k = tf.cast( tf.round( tf.cast(tf.size(abs_weights), tf.float32) * (1 - sparsity)), tf.int32) # Sort the entire array values, _ = tf.nn.top_k(tf.reshape(abs_weights, [-1]), k=tf.size(abs_weights)) # Grab the (k-1) th value current_threshold = tf.gather(values, k - 1) smoothed_threshold = tf.add_n([ tf.multiply(current_threshold, 1 - self._spec.threshold_decay), tf.multiply(threshold, self._spec.threshold_decay) ]) new_mask = tf.cast( tf.greater_equal(abs_weights, smoothed_threshold), tf.float32) return smoothed_threshold, new_mask
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients.

  Args:
    vs_gs: A structure whose `FlattenItems()` yields (name, (var, grad)).
  """
  for raw_name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(raw_name)
    # Place the summary ops on the variable's device.
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        # Sparse gradient: only look at the slices of the variable it touches.
        var, grad = tf.gather(var, grad.indices), grad.values
      if var.dtype.is_complex:
        # Histogram the magnitudes of complex values.
        var, grad = tf.abs(var), tf.abs(grad)
      # Pick the summary function that matches the execution mode.
      emit = histogram_v2 if py_utils.IsEagerMode() else histogram
      emit(f'var_hist/{name}', var)
      emit(f'grad_hist/{name}', grad)
def GetSentenceEmbeddings(inputs, segment_id):
  """Returns the average sentence embedding to gate by.

  Every position receives the mean of all input rows that share its segment
  id. Positions with segment id 0 are treated as padding and receive zeros.

  Example::

    inputs: a float tensor of shape (10, 3)
    segment_id: shape (10,), values [1, 1, 2, 0, 0, 3, 3, 3, 3, 0]

    result: shape (10, 3). Rows 0-1 are identical (mean of rows 0-1), row 2
    equals inputs row 2, rows 5-8 are identical (mean of rows 5-8), and rows
    3, 4 and 9 are all zeros.

  Args:
    inputs: G`SM Tensor.
    segment_id: G`S Tensor.

  Returns:
    sentence_embeddings: GSM Tensor that is an average of the input embeddings
    per segment.
  """
  model_dim = inputs.shape[-1]
  flat_inputs = tf.reshape(inputs, [-1, model_dim])
  # NOTE(review): the mean is taken over the flattened inputs, so equal
  # segment ids in different groups are averaged together -- confirm that
  # segment ids are unique across groups.

  # Use the number of rows as num_segments, a value large enough for every id
  # and one that keeps the shape known at compile time.
  max_segments = py_utils.GetShape(flat_inputs)[0]

  # Shift every id down by one so that segment k lands in slot k - 1, and send
  # the padding id (0) to the last slot, max_segments - 1.
  is_padding = tf.cast(
      tf.equal(segment_id, 0), dtype=tf.dtypes.as_dtype(segment_id.dtype))
  modified_segment_id = tf.cast(
      segment_id + max_segments * is_padding - 1, dtype=tf.int32)
  flat_segment_id = tf.reshape(modified_segment_id, [-1])

  # Per-segment means; the last slot (the padding one) is dropped and replaced
  # by a row of zeros.
  segment_means = tf.math.unsorted_segment_mean(flat_inputs, flat_segment_id,
                                                max_segments)
  zero_row = tf.zeros([1, flat_inputs.shape[-1]], dtype=flat_inputs.dtype)
  params = tf.concat([segment_means[:-1], zero_row], axis=0)

  # Look up the mean of each position's segment and restore the input shape.
  raw_sentence_embeddings = tf.gather(params, modified_segment_id)
  return tf.reshape(raw_sentence_embeddings, inputs.shape)
def _Slicing():
  """Keeps the same random subset of rows from every tensor in `tensor_list`.

  The row indices [0, actual_num) are shuffled and the first `num_points_out`
  of them are gathered, along axis 0, from each tensor.

  Returns:
    A list with one sliced tensor per entry of `tensor_list`.
  """
  shuffled_rows = tf.random_shuffle(tf.range(actual_num), seed=seed)
  kept_rows = shuffled_rows[:num_points_out]
  sliced = []
  for t in tensor_list:
    sliced.append(tf.gather(t, kept_rows, axis=0))
  return sliced
def _AddNoise(self, batch):
  """Adds noise to the source (see https://arxiv.org/pdf/1711.00043).

  This function implements 3 types of noise (hyperparameters defined in
  self.params.denoise):
    1) slightly shuffle the sentence following p.shuffle_tok_range
    2) randomly drop tokens with probability p.drop_tok_prob
    3) randomly mask tokens with probability p.blank_tok_prob
  The noises are added to the input with probability p.noise_sent_prob, and
  only to examples whose task id is in p.task_ids.

  Updates batch.src.ids, batch.src.paddings, batch.src.ids_indicator and
  batch.src.weights in place; nothing is returned.

  Args:
    batch: a `.NestedMap` of the input batch.
  """

  def IsSpecialExample(task_ids, special_task_ids):
    """A utility function indicating whether inputs belong to specific tasks.

    Args:
      task_ids: Task ids for the input batch. Tensor of shape [batch].
      special_task_ids: A list of specified task ids.

    Returns:
      A tensor indicating whether each sample in the batch belongs to one of
      the specified tasks. A tensor of size [batch].
    """
    batch_size = py_utils.GetShape(task_ids)[0]
    # Compare every task id against every special id and reduce over the
    # special-id axis.
    return tf.reduce_any(
        tf.equal(
            tf.expand_dims(task_ids, -1),
            tf.cast(
                tf.broadcast_to(
                    special_task_ids,
                    [batch_size, len(special_task_ids)]), tf.int32)), -1)

  p = self.params.denoise
  batch_size = tf.shape(batch.src.ids)[0]
  source_max_len = tf.shape(batch.src.ids)[1]

  # Shuffle tokens according to p.shuffle_tok_range
  noise = tf.random.uniform([batch_size, source_max_len], 0,
                            p.shuffle_tok_range + 1)

  # Don't shuffle eos or padding: positions whose right neighbour is padding
  # (the last position counts as such) get the constant p.shuffle_tok_range
  # instead of random noise.
  shuffle_tok_range = tf.fill([batch_size, source_max_len],
                              float(p.shuffle_tok_range))
  shifted_paddings = tf.pad(
      batch.src.paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  noise = tf.where(tf.equal(shifted_paddings, 0), noise, shuffle_tok_range)
  indices = tf.broadcast_to(
      tf.range(source_max_len, dtype=tf.int32), [batch_size, source_max_len])
  # Sorting position + noise yields the permutation to apply to each row.
  noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
  permutations = tf.argsort(noisy_indices)
  # Pair each row of ids with its permutation so map_fn can gather row-wise.
  stacked = tf.stack([batch.src.ids, permutations], axis=1)
  denoise_src_ids = tf.stack(
      tf.map_fn(lambda x: tf.gather(x[0], x[1]), stacked), axis=0)

  # Select tokens to drop with probability=p.drop_tok_prob
  random_drop_tok = tf.random.uniform([batch_size, source_max_len])
  # Don't drop eos token
  is_keep_tok = tf.math.logical_or(
      tf.greater(random_drop_tok, p.drop_tok_prob),
      tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
  # Dropping tokens shortens rows; refill to the original shape with id 0 and
  # padding 1.
  denoise_src_ids = tf.ragged.boolean_mask(
      denoise_src_ids,
      is_keep_tok).to_tensor(default_value=0, shape=tf.shape(batch.src.ids))
  denoise_src_paddings = tf.ragged.boolean_mask(
      batch.src.paddings,
      is_keep_tok).to_tensor(default_value=1, shape=tf.shape(batch.src.ids))

  # Select tokens to blank with probability=p.blank_tok_prob
  # Don't blank eos token
  random_blank_tok = tf.random.uniform([batch_size, source_max_len])
  shifted_paddings = tf.pad(
      denoise_src_paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  is_blank_tok = tf.math.logical_and(
      tf.less(random_blank_tok, p.blank_tok_prob),
      tf.equal(shifted_paddings, 0))
  blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
  denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

  # Select denoising task examples with probability=p.noise_sent_prob
  random_uniform_sent = tf.random.uniform([batch_size])
  is_denoise_sent = tf.math.logical_and(
      tf.less(random_uniform_sent, p.noise_sent_prob),
      IsSpecialExample(self._GetTaskIds(batch.src.source_ids[:, 0]),
                       p.task_ids))
  # Only the selected examples receive the noised ids/paddings.
  batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids, batch.src.ids)
  batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                batch.src.paddings)
  batch.src.ids_indicator = 1 - batch.src.paddings
  batch.src.weights = batch.src.ids_indicator
def _ApplyPacking(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields. When `params.packing_factor` is not
  set, the fields are filled in without any packing.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """

  def _SeqLen(ids_indicator):
    return tf.math.reduce_sum(tf.cast(ids_indicator, tf.int32), axis=1)

  src_actual_seq_len = _SeqLen(batch.src.ids_indicator)
  tgt_actual_seq_len = _SeqLen(batch.tgt.ids_indicator)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    for side in (batch.src, batch.tgt):
      side.segment_ids = side.ids_indicator
      side.segment_pos = _GetSegmentPos(side.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  def _UniqueFlat(indices_in_input):
    return tf.unique(tf.reshape(indices_in_input, [-1])).y

  # Histogram the lengths of the input sequences that were actually used.
  uniq_src_indices_in_input = _UniqueFlat(src_indices_in_input)
  uniq_tgt_indices_in_input = _UniqueFlat(tgt_indices_in_input)
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  def _PackedTokenRatio(actual_seq_len, segment_ids):
    orig_tokens_count = tf.cast(tf.reduce_sum(actual_seq_len), tf.float32)
    packed_tokens_count = tf.reduce_sum(
        tf.cast(segment_ids > 0, tf.float32))
    return packed_tokens_count / orig_tokens_count

  # Ratio of number of non-padded tokens. If < 1.0, we are dropping
  # input data due to p.packing_factor too high.
  summary_utils.scalar('examples/src_packed_token_ratio',
                       _PackedTokenRatio(src_actual_seq_len, src_segment_ids))
  summary_utils.scalar('examples/tgt_packed_token_ratio',
                       _PackedTokenRatio(tgt_actual_seq_len, tgt_segment_ids))

  def _MakePacker(segment_ids, indices_in_input):
    """Returns a per-field packing function for one side of the batch."""

    def _PackField(x):
      # String fields are padded with a tab, all other fields with 0.
      pad_value = '\t' if x.dtype == tf.string else 0
      return ops.apply_packing(x, pad_value, segment_ids, indices_in_input)

    return _PackField

  # .paddings is packed separately, with pad value 1, before the generic
  # transform runs, and then overwrites the transformed field.
  src_paddings = ops.apply_packing(batch.src.paddings, 1, src_segment_ids,
                                   src_indices_in_input)
  batch.src = batch.src.Transform(
      _MakePacker(src_segment_ids, src_indices_in_input))
  batch.src.paddings = src_paddings
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1, tgt_segment_ids,
                                   tgt_indices_in_input)
  batch.tgt = batch.tgt.Transform(
      _MakePacker(tgt_segment_ids, tgt_indices_in_input))
  batch.tgt.paddings = tgt_paddings
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos

  # The number of examples is indicated by the segment_ids of the target.
  # Note that this is per infeed value when p.use_per_host_infeed = True.
  segments_per_row = tf.math.reduce_max(batch.tgt.segment_ids, axis=1)
  summary_utils.scalar('examples/num_packed_examples',
                       tf.reduce_sum(segments_per_row))
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step.

  Args:
    recurrent_theta: Carries `encoder_outputs` and `random_seed`.
    state0: The state before this step (`ids`, `bs_state`, `timestep`).
    inputs: Unused when `p.use_recurrent` is set; otherwise holds the `ids`
      and `logits` TensorArrays that this step writes into.

  Returns:
    (state1, empty NestedMap) when `p.use_recurrent` is set; otherwise
    (recurrent_theta, state1, inputs).
  """
  if p.use_recurrent:
    del inputs

  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    prev_ids = tf.expand_dims(state0.ids, 1)  # [batch, 1].
    bs_result, bs_state1 = pre_step_callback(
        decoder_theta,
        recurrent_theta.encoder_outputs,
        prev_ids,
        state0.bs_state,
        p.num_hyps_per_beam,
        0)  # cur_step
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs

    sample_logits = state1.logits
    # Perform Nucleus Sampling. Assumes logits are in (-1e10, 1e3).
    if p.nucleus_p < 1.0:
      max_logit = 1e3
      min_logit = -1e10
      sorted_logits = tf.sort(sample_logits, direction='DESCENDING', axis=-1)
      sorted_probs = tf.nn.softmax(sorted_logits)
      # Exclusive cumsum: probability mass strictly before each candidate.
      cumsum_probs = tf.math.cumsum(sorted_probs, axis=-1, exclusive=True)
      # Candidates outside the nucleus are lifted to max_logit so the minimum
      # below is the smallest logit still inside the nucleus.
      outside_nucleus_fill = tf.ones_like(sorted_logits) * max_logit
      masked_logits = tf.where(cumsum_probs < p.nucleus_p, sorted_logits,
                               outside_nucleus_fill)
      threshold = tf.math.reduce_min(masked_logits, axis=-1, keepdims=True)
      # Everything below the threshold is pushed down to min_logit.
      sample_logits = tf.where(sample_logits < threshold,
                               tf.ones_like(sorted_logits) * min_logit,
                               sample_logits)

    # Note that here, we retain the possibility of applying both top_k
    # and nucleus filtering.
    use_top_k = p.top_k > 0
    if use_top_k:
      topk_logits, topk_ids = tf.math.top_k(sample_logits, k=p.top_k)
      if p.top_k_renormalize:
        sample_logits = tf.nn.log_softmax(topk_logits)
      else:
        sample_logits = topk_logits

    # Sample ids from logits. [batch].
    scaled_logits = sample_logits / p.temperature
    step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
    sampled = tf.random.stateless_categorical(
        scaled_logits,
        num_samples=1,
        seed=step_seed,
        dtype=state0.ids.dtype,
        name='sample_next_id')
    ids = tf.reshape(sampled, [batch])
    if use_top_k:
      # `ids` index into the top-k candidates; map them back to vocab ids.
      state1.ids = tf.gather(topk_ids, ids, axis=1, batch_dims=1)
    else:
      state1.ids = ids

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      # On the last chunk, a sampled end-of-chunk id is replaced by EOS.
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(decoder_theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)

  if p.use_recurrent:
    return state1, py_utils.NestedMap()

  inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
  inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
  return (recurrent_theta, state1, inputs)
def _Pack(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields. When `params.packing_factor` is not
  set, the fields are filled in without any packing.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """

  def _SeqLen(ids_indicator):
    return tf.math.reduce_sum(tf.cast(ids_indicator, tf.int32), axis=1)

  src_actual_seq_len = _SeqLen(batch.src.ids_indicator)
  tgt_actual_seq_len = _SeqLen(batch.tgt.ids_indicator)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    for side in (batch.src, batch.tgt):
      side.segment_ids = side.ids_indicator
      side.segment_pos = _GetSegmentPos(side.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  def _UniqueFlat(indices_in_input):
    return tf.unique(tf.reshape(indices_in_input, [-1])).y

  # Histogram the lengths of the input sequences that were actually used.
  uniq_src_indices_in_input = _UniqueFlat(src_indices_in_input)
  uniq_tgt_indices_in_input = _UniqueFlat(tgt_indices_in_input)
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  def _MakePacker(segment_ids, indices_in_input):
    """Returns a per-field packing function for one side of the batch."""

    def _PackField(x):
      # String fields are padded with a tab, all other fields with 0.
      pad_value = '\t' if x.dtype == tf.string else 0
      return ops.apply_packing(x, pad_value, segment_ids, indices_in_input)

    return _PackField

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for all
  # fields.
  batch.src = batch.src.Transform(
      _MakePacker(src_segment_ids, src_indices_in_input))
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  batch.tgt = batch.tgt.Transform(
      _MakePacker(tgt_segment_ids, tgt_indices_in_input))
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos
def ExtractBlockContextV2(x,
                          block_size,
                          left_context,
                          right_context,
                          padding_val=0.0,
                          paddings=None):
  """Extracts temporal context for every block (without restrictions).

  This is a generalized implementation of ExtractBlockContext, where
  block_size, left_context, and right_context are 3 free parameters and we
  don't have constraints (other than l>=1, r>=0, block_size>0).

  Args:
    x: a tensor of [batch, time, dim].
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size. Note that the actual left context is
      `left_context - 1` (this is to be compatible with ExtractBlockContext
      implementation).
    right_context: int. Right context size.
    padding_val: float. value on the padded frames.
    paddings: optional. If specified, it must be a tensor of [batch, time], and
      we will return a padding tensor indicating padding info for the returned
      tensor.

  Returns:
    (x_patches, x_paddings) where
    - x_patches: A tensor of [batch, num_blocks, context_size + block_size,
      dim] with necessary paddings, where
      context_size = (left_context - 1) + right_context, and output[:, i, ...]
      are x[:, start-left_context+1:end+right_context, ...], where
      start = i * block_size, end = (i + 1) * block_size.
    - x_paddings: None if paddings = None; else a
      [batch, num_blocks, context_size + block_size] tensor, indicating the
      padding info for the corresponding position in x_patches.

  Raises:
    ValueError: if block_size <= 0, left_context < 1 or right_context < 0.

  Let's define some variables here:
    B: batch size
    T: input tensor length in time axis
    D: input tensor dimension in the last axis
    W: block size
    U: ceil(T/W)
    L: left context size
    R: right context size
    C: L-1+W+R, full block length

  Given a [B, T, D] tensor, the return is a [B, U, C, D] tensor where
  ret[b, u, :] is a 2D tensor of shape (L - 1 + W + R, D), which is the u-th
  block of the input tensor with (L - 1) left context frames and R right
  context frames.

  Implementation note:
    We use the following procedure to get the return tensor
    - pad the time axis at the beginning and at the end:
      [B, T, D] -> [B, L-1 + U*W + L-1 + R, D]
    - build a [U, C] matrix of time indices whose u-th row is
      [u*W, ..., u*W + C-1]
    - gather along the time axis with that matrix:
      [B, L-1 + U*W + L-1 + R, D] -> [B, U, C, D]

  TODO(yqw): after verifying correctness and benchmark, consider replace v1
  implementation?
  """
  # The only constraints are block_size > 0, l >= 1, r >= 0. They are checked
  # before anything else because block_size is used as a divisor below: a zero
  # value has to surface as the ValueError documented here rather than as a
  # division error.
  if block_size <= 0:
    raise ValueError(f'block size ({block_size}) must be greater than 0')
  if left_context < 1:
    raise ValueError(f'Left context ({left_context}) must be >= 1.')
  if right_context < 0:
    raise ValueError(f'Right context ({right_context}) must be >= 0')

  # 0. basic shapes
  b, t, d = py_utils.GetShape(x, 3)
  w = block_size
  u = (t + w - 1) // w  # equivalent to math.ceil(t/w)
  l = left_context
  r = right_context
  c = l - 1 + r + w

  if paddings is not None:
    paddings = py_utils.HasShape(paddings, [b, t])

  # 1. do front and rear padding
  left_pad = l - 1
  # we need to make sure all u * w elements have enough long context
  # NOTE(review): the gather below reads at most u*w + l - 1 + r frames, so
  # (assuming _DoPadding adds exactly left_pad + right_pad frames) this looks
  # like l - 1 frames more than strictly required. Harmless; kept unchanged.
  right_pad = (u * w - t + l - 1 + r)
  x_padded = _DoPadding(x, b, left_pad, right_pad, d, padding_val=padding_val)
  if paddings is not None:
    # Frames introduced by the padding above are marked as padded (1.0).
    paddings = _DoPadding(
        paddings, b, left_pad, right_pad, d=None, padding_val=1.0)

  # 2. generate gather indices
  # gather_indices is a [u, c] matrix like
  # [ 0, ........., c-1]
  # [ w, ........., w + (c-1)]
  # [2w, .........., 2w + (c-1)]
  # [(u-1)*w, ...., (u-1)*w + (c-1)]
  # Built by broadcasting a [u, 1] column of block starts against a [1, c] row
  # of in-block offsets.
  block_starts = tf.expand_dims(tf.range(0, u * w, w), axis=1)
  block_offsets = tf.expand_dims(tf.range(0, c), axis=0)
  gather_indices = block_starts + block_offsets

  # 3. generate x_patches, shape [b, u, c, d]
  x_patches = tf.gather(x_padded, gather_indices, axis=1)
  if paddings is not None:
    # After the gather, paddings is a [b, u, c] tensor.
    paddings = tf.gather(paddings, gather_indices, axis=1)
  return x_patches, paddings
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wraps `_PreBeamSearchStepCallback` and biases its log-probs.

  Runs `self._PreBeamSearchStepCallback` and then interpolates the predicted
  token distribution with a (nearly) one-hot distribution built from
  `encoder_outputs.targets.labels`, using `encoder_outputs.targets.weights`
  as the interpolation weight at the current time step.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder. Must contain
      `targets.labels` and `targets.weights`.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps. Must contain
      `time_step` and `consistent`.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).
    results: A `.NestedMap` of beam search results.
      atten_probs: The updated attention probs, of shape
        [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This is
        of shape [tgt_batch, vocab_size].
    out_states: a `.NestedMap` The updated states. The states relevant here
      are:
      time_step: A scalar indicating current step of decoder. Must be
        provided and maintained by subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks
        whether each hypothesis has exactly matched encoder_outputs.targets
        so far.
  """
  params = self.params
  cur_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  targets = encoder_outputs.targets
  labels, weights = targets.labels, targets.weights

  def _StepSlicePerHyp(values, step):
    """Takes column `step` of `values` and repeats it for every hyp."""
    column = tf.gather(values, step, axis=1)
    row = tf.reshape(column, [1, -1])  # [1, src_batch]
    tiled = tf.tile(row, [num_hyps_per_beam, 1])  # [hyps_per_beam, src_batch]
    flat_size = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tiled, [flat_size])

  # A hyp stays consistent if the id it was just fed (step_ids) equals the
  # label of the previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 (the index is clamped to 0),
  # but the is_first_step term below overrides the comparison in that case.
  prev_label = _StepSlicePerHyp(labels, tf.maximum(cur_step - 1, 0))
  is_first_step = tf.equal(cur_step, 0)
  matches_prev = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  out_states.consistent = tf.logical_and(
      states.consistent, tf.logical_or(is_first_step, matches_prev))

  # Label and weight slices corresponding to the current time step.
  label = _StepSlicePerHyp(labels, cur_step)
  weight = _StepSlicePerHyp(weights, cur_step)
  if params.bias_only_if_consistent:
    # Zero out the bias for hyps that have already diverged from the targets.
    weight *= tf.cast(out_states.consistent, params.dtype)

  # Convert the dense label into a smoothed one-hot distribution.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # Tiny mass on the off-label entries avoids exact 0 probs, which may cause
  # issues with the log below.
  uncertainty = tf.constant(1e-10, params.dtype)
  off_label_prob = uncertainty / tf.cast(vocab_size - 1, params.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=off_label_prob,
      dtype=params.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs; weight must lie in [0, 1].
  weight = tf.expand_dims(weight, 1)
  range_checks = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ]
  mixed_probs = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(range_checks, mixed_probs)
  bs_results.log_probs = tf.log(probs)
  return bs_results, out_states