def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes the mean and variance of `inputs` weighted by `mask`.

  Every dimension except the last one is reduced.

  Args:
    inputs: Tensor holding the data points. Must have the same rank as `mask`.
    mask: Non-negative tensor that multiplies `inputs`. The count correction
      below relies on each leading dimension of `inputs` being a multiple of
      the matching dimension of `mask`.
    enable_cross_replica_sum_on_tpu: When True and running on TPU, the partial
      sums and counts are added up across replicas.

  Returns:
    A (mean, variance) tuple.
  """
  preconditions = [
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ]
  inputs = py_utils.with_dependencies(preconditions, inputs)

  # All axes but the innermost one are folded into the statistics.
  leading_axes = tf.range(0, tf.rank(mask) - 1)
  weighted_total = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                                 leading_axes)
  num_valid = tf.reduce_sum(mask, leading_axes)

  # `inputs * mask` broadcast successfully above, so the input shape is a
  # multiple of the mask shape; scale the count by that broadcast factor.
  broadcast_factor = tf.reduce_prod(
      tf.shape(inputs)[:-1] // tf.shape(mask)[:-1])
  num_valid = num_valid * tf.cast(broadcast_factor, num_valid.dtype)

  sum_across_replicas = (
      py_utils.use_tpu() and enable_cross_replica_sum_on_tpu)
  if sum_across_replicas:
    weighted_total = tf.tpu.cross_replica_sum(weighted_total)
    num_valid = tf.tpu.cross_replica_sum(num_valid)

  # Guard against division by zero when nothing is valid.
  num_valid = tf.maximum(num_valid, 1.0)
  mean = weighted_total / num_valid

  centered = inputs - mean
  squared_total = tf.reduce_sum(centered * centered * mask, leading_axes)
  if sum_across_replicas:
    squared_total = tf.tpu.cross_replica_sum(squared_total)

  non_negative = py_utils.assert_greater_equal(
      squared_total, tf.zeros_like(squared_total))
  variance = py_utils.with_dependencies([non_negative],
                                        squared_total / num_valid)
  return mean, variance
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
  """Computes the mean and variance of the non-padded entries of `inputs`.

  Args:
    inputs: Tensor holding the data points. Must have the same rank as
      `padding`.
    padding: Tensor whose complement `1.0 - padding` is used as a non-negative
      weight multiplied into `inputs`.
    reduce_over_dims: Dimensions to reduce over.
    cumulative_axis: If not None, the sums and counts are accumulated with a
      cumulative sum along this axis after the reduction.
    enable_cross_replica_sum_on_tpu: When True and running on TPU, the partial
      sums and counts are added up across replicas.
    keepdims: Forwarded to the reductions; keeps reduced dims with size 1.

  Returns:
    A (mean, variance) tuple.
  """
  mask = 1.0 - padding
  preconditions = [
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ]
  inputs = py_utils.with_dependencies(preconditions, inputs)

  accumulate = cumulative_axis is not None
  sum_across_replicas = (
      py_utils.use_tpu() and enable_cross_replica_sum_on_tpu)

  def _SizeOnReducedDims(tensor):
    """Number of elements of `tensor` along the reduced dimensions."""
    return tf.reduce_prod(tf.gather(tf.shape(tensor), reduce_over_dims))

  weighted_total = tf.reduce_sum(
      inputs * tf.cast(mask, inputs.dtype),
      reduce_over_dims,
      keepdims=keepdims)
  num_valid = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)
  if accumulate:
    weighted_total = tf.math.cumsum(weighted_total, axis=cumulative_axis)
    num_valid = tf.math.cumsum(num_valid, axis=cumulative_axis)

  # `inputs * mask` broadcast successfully above, so the input shape is a
  # multiple of the mask shape; scale the count by that broadcast factor.
  broadcast_factor = tf.math.truediv(
      _SizeOnReducedDims(inputs), _SizeOnReducedDims(mask))
  num_valid = num_valid * tf.cast(broadcast_factor, num_valid.dtype)

  if sum_across_replicas:
    weighted_total = tf.tpu.cross_replica_sum(weighted_total)
    num_valid = tf.tpu.cross_replica_sum(num_valid)

  # Guard against division by zero when nothing is valid.
  num_valid = tf.maximum(num_valid, 1.0)
  mean = weighted_total / num_valid

  centered = inputs - mean
  squared_total = tf.reduce_sum(
      centered * centered * mask, reduce_over_dims, keepdims=keepdims)
  if accumulate:
    squared_total = tf.math.cumsum(squared_total, axis=cumulative_axis)
  if sum_across_replicas:
    squared_total = tf.tpu.cross_replica_sum(squared_total)

  non_negative = py_utils.assert_greater_equal(
      squared_total, tf.zeros_like(squared_total))
  variance = py_utils.with_dependencies([non_negative],
                                        squared_total / num_valid)
  return mean, variance
def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean, norm_variance):
  """Applies batch normalization to `inputs` using the given moments.

  Args:
    inputs: The tensor to normalize; its last dimension must match the shape
      of `norm_mean` and `norm_variance`.
    paddings: Paddings applied to the output when
      `params.set_padded_output_to_zero` is set.
    gamma: Scale passed to the batch-norm op.
    beta: Offset passed to the batch-norm op.
    norm_mean: Mean used for normalization.
    norm_variance: Variance used for normalization; asserted non-negative.

  Returns:
    The normalized tensor.
  """
  p = self.params
  moment_checks = [
      py_utils.assert_greater_equal(norm_variance,
                                    tf.zeros_like(norm_variance)),
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(norm_mean)),
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(norm_variance)),
  ]
  with tf.control_dependencies(moment_checks):
    # The fused kernel is only used when statistics are not being updated.
    stats_are_frozen = self.do_eval or p.freeze_bn_stats
    if p.use_fused_batch_norm_for_eval and stats_are_frozen:
      fused_outputs = nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          norm_mean,
          norm_variance,
          self._epsilon,
          is_training=False)
      normalized, _, _ = fused_outputs
    else:
      normalized = tf.nn.batch_normalization(inputs, norm_mean, norm_variance,
                                             beta, gamma, self._epsilon)
    if p.set_padded_output_to_zero:
      normalized = py_utils.ApplyPadding(paddings, normalized)
  return normalized
def FProp(self, theta, inputs, paddings, class_emb):
  """Applies class-conditional batch normalization.

  Outside of TPU, `class_emb` is additionally asserted to be one-hot: after an
  int32 cast every entry lies in [0, 1] and every row sums to one.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [batch, ..., dim].
    paddings: The paddings tensor. Shaped [batch, ..., 1], with the same rank
      as the input tensor.
    class_emb: The conditioning inputs. Shaped [batch, emb_dim].

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  batch = py_utils.GetShape(inputs)[0]
  class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])

  if not py_utils.use_tpu():
    one_hot_checks = [
        py_utils.assert_less_equal(
            tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
        py_utils.assert_greater_equal(
            tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
        py_utils.assert_equal(
            tf.ones([batch], tf.int32),
            tf.cast(tf.reduce_sum(class_emb, -1), tf.int32),
            name='one_hot_assert3'),
    ]
    class_emb = py_utils.with_dependencies(one_hot_checks, class_emb)

  with tf.name_scope(p.name):
    moments_and_affine = self.ComputeAndUpdateMoments(
        theta, inputs, paddings=paddings, class_emb=class_emb)
    norm_mean, norm_variance, beta, gamma = moments_and_affine
    return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                           norm_variance)
def FProp(self, theta, inputs, paddings=None):
  """Applies batch normalization and zeroes out padded positions.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor. When None, `self._GetDefaultPaddings(inputs)` is used.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)

  with tf.name_scope(p.name):
    moments_and_affine = self.ComputeAndUpdateMoments(theta, inputs, paddings)
    norm_mean, norm_variance, beta, gamma = moments_and_affine

    moment_checks = [
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]
    with tf.control_dependencies(moment_checks):
      normalized = tf.nn.batch_normalization(inputs, norm_mean, norm_variance,
                                             beta, gamma, self._epsilon)
      # Multiplying by the complement of the paddings clears padded outputs.
      normalized = normalized * (1.0 - paddings)
    return normalized
def _Normalize(self, theta, grouped_inputs, group_mean, group_variance):
  """Normalizes grouped inputs with per-group moments and applies the affine.

  Args:
    theta: Weights of this layer; `theta.gamma` and `theta.beta` are used.
    grouped_inputs: Inputs whose last two dimensions are merged again after
      normalization.
    group_mean: Mean to subtract. Must have the shape of `grouped_inputs` with
      the reduced dimensions set to 1.
    group_variance: Variance to divide by; same shape requirement as
      `group_mean`, asserted non-negative.

  Returns:
    The normalized outputs with the last two dimensions of `grouped_inputs`
    merged into one.
  """
  p = self.params
  group_mean = py_utils.CheckNumerics(
      group_mean, f'mean of {p.name} failed numeric check.')
  group_variance = py_utils.CheckNumerics(
      group_variance, f'variance of {p.name} failed numeric check.')

  input_shape = py_utils.GetShape(grouped_inputs)

  # Expected shape of the moments: the last dim is always reduced, dim 2 is
  # reduced as well for rank-4 layer inputs, and the seqlen dim is reduced
  # unless the statistics are cumulative.
  expected_moment_shape = list(input_shape)
  expected_moment_shape[-1] = 1
  if p.input_rank == 4:
    expected_moment_shape[2] = 1
  if not p.cumulative:
    expected_moment_shape[1] = 1

  group_mean = py_utils.HasShape(group_mean, expected_moment_shape)
  group_variance = py_utils.HasShape(group_variance, expected_moment_shape)
  non_negative = py_utils.assert_greater_equal(
      group_variance, tf.cast(0, group_variance.dtype))
  group_variance = py_utils.with_dependencies([non_negative], group_variance)

  inv_stddev = tf.math.rsqrt(group_variance + self._epsilon)
  normalized = (grouped_inputs - group_mean) * inv_stddev
  # Merges the last two dims.
  normalized = tf.reshape(normalized, input_shape[:-2] + [-1])

  # Note: the effective scale is 1 + gamma.
  return normalized * (theta.gamma + 1) + theta.beta
def FProp(self, theta, inputs, paddings=None):
  """Applies group normalization to a rank-4 input.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor with shape [batch_size, height, width, channel].
    paddings: The paddings tensor with shape [batch_size, height]. Intended to
      be used for sequence processing where `height` is `time`.

  Returns:
    A single tensor as the output after applying group normalization, with the
    same shape as 'inputs'. Or an (output, output_paddings) pair if input
    paddings is not None.
  """
  p = self.params
  n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4)

  # Grow the group size up to the (dim-capped) minimum when it is too small.
  min_group_size = min(p.min_group_size, p.dim)
  group_size, num_groups = p.dim // p.num_groups, p.num_groups
  if group_size <= min_group_size:
    group_size = min_group_size
    num_groups = p.dim // group_size

  # Statistics are taken over height, width and the channels within a group.
  stat_axes = [1, 2, 4]
  with tf.name_scope(p.name):
    x = tf.reshape(inputs, [n, h, w, num_groups, group_size])
    if paddings is not None:
      expanded_paddings = tf.reshape(paddings, [n, h, 1, 1, 1])
      norm_mean, norm_variance = ComputeMomentsWithPadding(
          x, expanded_paddings, stat_axes, keepdims=True)
    else:
      counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
          x, axes=stat_axes, keepdims=True)
      norm_mean, norm_variance = tf.nn.normalize_moments(
          counts, mean_ss, variance_ss, None)

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    moment_checks = [
        py_utils.assert_greater_equal(norm_variance,
                                      tf.cast(0., norm_variance.dtype)),
        py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([n, 1, 1, num_groups, 1],
                                    tf.shape(norm_variance)),
    ]
    with tf.control_dependencies(moment_checks):
      x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
      x = tf.reshape(x, [n, h, w, c])
      gn_output = x * theta.gamma + theta.beta
      gn_output = tf.reshape(gn_output, [n, h, w, c])

    if paddings is not None:
      return gn_output, paddings
    return gn_output
def SplitTensors(xs, num_splits):
  """Splits each tensor in `xs` into num_splits pieces along dimension 0.

  The piece sizes come from `ComputeSplits` applied to the first dimension of
  `xs[0]`. Graph assertions check that all tensors share that first dimension
  and that it is at least `num_splits`.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A list with one entry per tensor in xs. Each entry is the result of
    `tf.split` on that tensor, so the i-th element of every entry corresponds
    to the i-th split along the first dimension.
  """
  # Collect the leading dimension of every tensor so they can be compared.
  leading_dims = tf.stack([tf.shape(x)[0] for x in xs])
  reference_dim = tf.shape(xs[0])[0]
  dim_checks = [
      py_utils.assert_equal(
          leading_dims,
          reference_dim,
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          reference_dim,
          num_splits,
          message='first dim of tensors in xs must be greater than num_splits'),
  ]
  leading_dims = py_utils.with_dependencies(dim_checks, leading_dims)

  split_sizes = ComputeSplits(reference_dim, num_splits)
  # Tie the assertions above into the graph that consumes the split sizes.
  split_sizes = py_utils.with_dependencies([leading_dims], split_sizes)

  pieces_per_tensor = []
  for x in xs:
    pieces_per_tensor.append(
        tf.split(axis=0, num_or_size_splits=split_sizes, value=x))
  return pieces_per_tensor
def MakeCausalPadding(seq_len,
                      block_size,
                      left_context,
                      right_context,
                      dtype=tf.float32):
  """Builds the block-wise causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length; asserted to be >= 1.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    dtype: tf.dtype, default is tf.float32.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in
    {0, 1}, where context_size = block_size + (left_context - 1) +
    right_context. Element b, i, j is zero if in the b-th block, the i-th
    frame can access the j-th frame in the context.
  """
  seq_len = py_utils.with_dependencies([
      py_utils.assert_greater_equal(
          seq_len, 1, message='seq_len must be at least 1')
  ], seq_len)

  num_blocks = (seq_len + block_size - 1) // block_size
  context_size = block_size + (left_context - 1) + right_context
  padded_len = num_blocks * block_size

  # [num_blocks, block_size]: query (source) positions in the sequence.
  src_positions = tf.reshape(tf.range(padded_len), [num_blocks, block_size])
  # [num_blocks]: position of the first frame of each block.
  block_starts = tf.range(0, padded_len, block_size)
  # [context_size]: context offsets relative to the block start.
  context_offsets = tf.range(context_size) - (left_context - 1)
  # [num_blocks, context_size]: key (target) positions in the sequence.
  tgt_positions = (
      tf.expand_dims(block_starts, 1) + tf.expand_dims(context_offsets, 0))

  # [num_blocks, block_size, context_size]: source minus target position.
  position_diff = (
      tf.expand_dims(src_positions, 2) - tf.expand_dims(tgt_positions, 1))
  # A pair may attend when the target lies inside the context window.
  in_window = tf.math.logical_and(
      tf.greater_equal(position_diff, -right_context),
      tf.less(position_diff, left_context))

  # [num_blocks, block_size]: source position is inside the real sequence.
  valid_src = tf.less(src_positions, seq_len)
  # [num_blocks, context_size]: target position is inside the real sequence.
  valid_tgt = tf.math.logical_and(
      tf.greater_equal(tgt_positions, 0), tf.less(tgt_positions, seq_len))
  both_valid = tf.math.logical_and(
      tf.expand_dims(valid_src, 2), tf.expand_dims(valid_tgt, 1))

  valid_atten = tf.math.logical_and(in_window, both_valid)
  return 1.0 - tf.cast(valid_atten, dtype=dtype)
def ApplyBias():
  """Biases the predicted log probs towards the labels of the current step.

  This is a closure: it reads `labels`, `weights`, `time_step`, `states`,
  `step_ids`, `bs_results`, `num_hyps_per_beam` and `p` from the enclosing
  scope.

  Returns:
    A (log_probs, consistent) tuple. `log_probs` is the log of the
    interpolation between the predicted probabilities and the (smoothed)
    one-hot label probabilities; `consistent` marks hyps whose `step_ids`
    matched the previous-step labels and that were consistent before.
  """

  def TileForBeamAndFlatten(tensor):
    """Repeats a per-source value for every hyp and flattens to tgt_batch."""
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    # [num_hyps_per_beam, src_batch]
    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  # tf.math.* is used instead of the TF1-only top-level aliases.
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # get label, weight slices corresponding to current time_step
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # convert from dense label to sparse label probs
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # avoid 0 probs which may cause issues with log
  uncertainty = tf.constant(1e-10, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # interpolate predicted probs and label probs
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.math.log(probs), consistent
def _MaybeExpandPaddings(self, inputs, paddings):
  """Appends a trailing size-1 dim to `paddings` if its rank is one short.

  Args:
    inputs: Tensor whose rank is the target rank.
    paddings: Tensor whose rank must equal the rank of `inputs` or be exactly
      one less; this is asserted in the graph.

  Returns:
    `paddings` reshaped to the rank of `inputs`.
  """
  # rank difference is at most one.
  missing_dims = tf.rank(inputs) - tf.rank(paddings)
  rank_checks = [
      py_utils.assert_less_equal(missing_dims, 1),
      py_utils.assert_greater_equal(missing_dims, 0),
  ]
  paddings = py_utils.with_dependencies(rank_checks, paddings)

  # Pads [1] to the end of the paddings shape, once per missing dim.
  trailing_ones = tf.tile([1], [missing_dims])
  expanded_shape = tf.concat([tf.shape(paddings), trailing_ones], axis=0)
  return tf.reshape(paddings, expanded_shape)
def ApplyBias():
  """Biases the predicted log probs towards the labels of the current step.

  This is a closure: it reads `labels`, `weights`, `time_step`, `states`,
  `step_ids`, `bs_results`, `p` and `TileForBeamAndFlatten` from the enclosing
  scope.

  Returns:
    A (log_probs, consistent) tuple.
  """
  fprop_dtype = py_utils.FPropDtype(p)

  def _StepSlice(per_source, step):
    """Gathers column `step` and tiles it across the hyps of each beam."""
    return TileForBeamAndFlatten(tf.gather(per_source, step, axis=1))

  # A hyp stays consistent if its step_ids equal the labels of the previous
  # step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # prev_label is meaningless at step 0; is_step0 overrides it below.
  prev_label = _StepSlice(labels, tf.maximum(time_step - 1, 0))
  is_step0 = tf.equal(time_step, 0)
  matches_prev = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  consistent = tf.math.logical_and(states.consistent,
                                   tf.math.logical_or(is_step0, matches_prev))

  # Label and weight slices for the current time step.
  label = _StepSlice(labels, time_step)
  weight = _StepSlice(weights, time_step)
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, fprop_dtype)

  # Dense labels -> one-hot probabilities of shape [tgt_batch, vocab_size].
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(label, vocab_size, dtype=fprop_dtype)
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate between predicted and label probabilities.
  weight = tf.expand_dims(weight, 1)
  weight_in_unit_range = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.),
  ]
  mixed = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(weight_in_unit_range, mixed)

  # Ensure that tf.math.log is applied to positive values.
  floor = tf.constant(1e-12, dtype=probs.dtype)
  return tf.math.log(tf.maximum(probs, floor)), consistent
def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                              num_hyps_per_beam, *args, **kwargs):
  """Wrapper for adding bias to _PreBeamSearchStepCallback.

  Biases results.log_probs towards provided encoder_outputs.targets.

  This is a closure: `self` comes from the enclosing scope.

  Args:
    theta: a NestedMap of parameters.
    encoder_outputs: a NestedMap computed by encoder. Must carry
      `targets.labels` and `targets.weights`.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps.
    num_hyps_per_beam: Beam size.
    *args: additional arguments to _PreBeamSearchStepCallback.
    **kwargs: additional arguments to _PreBeamSearchStepCallback.

  Returns:
    A tuple (results, out_states).
    results: A `.NestedMap` of beam search results.
      atten_probs:
        The updated attention probs, of shape [tgt_batch, src_len].
      log_probs:
        Log prob for each of the tokens in the target vocab. This is of shape
        [tgt_batch, vocab_size].
    out_states: a `.NestedMap` The updated states. The states relevant here
      are:
      time_step: A scalar indicating current step of decoder. Must be provided
        and maintained by subclass.
      consistent: A boolean vector of shape [tgt_batch, ] which tracks whether
        each hypothesis has exactly matched encoder_outputs.targets so far.
  """
  p = self.params
  time_step = states.time_step
  bs_results, out_states = self._PreBeamSearchStepCallback(
      theta, encoder_outputs, step_ids, states, num_hyps_per_beam, *args,
      **kwargs)
  labels = encoder_outputs.targets.labels
  weights = encoder_outputs.targets.weights

  def TileForBeamAndFlatten(tensor):
    """Repeats a per-source value for every hyp and flattens to tgt_batch."""
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    # [num_hyps_per_beam, src_batch]
    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  # tf.math.* is used instead of the TF1-only top-level aliases.
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  out_states.consistent = tf.math.logical_and(states.consistent,
                                              local_consistence)

  # get label, weight slices corresponding to current time_step
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(out_states.consistent, p.dtype)

  # convert from dense label to sparse label probs
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # avoid 0 probs which may cause issues with log
  uncertainty = tf.constant(1e-10, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # interpolate predicted probs and label probs
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)

  bs_results.log_probs = tf.math.log(probs)
  return bs_results, out_states
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding to a selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add furthest point to selected.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample. Asserted to be at
      most P1.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      distances between each point. if None, distances will be computed on the
      fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. Values index into the P1 points, i.e. they lie
    in [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part of
    the pooling operation: each point is assigned to the closest sampled point
    and a max is taken over them. Values are positions within the sampled set,
    i.e. they lie in [0, num_sampled_points).
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or points so all
  # points are unique. This will also ensure true repeated points
  # like padded points are only selected after all valid points are selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    # The noise has a trailing dim of 1, so it broadcasts along the last axis.
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get number of valid points (1 is padded, so num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler.

    Selects one new point per batch element, records it in `sampled_idx`, and
    refreshes `distance_to_selected` and `closest_idx` for all points.
    """

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random values per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      # Padded entries become padding * 10, which exceeds the uniform maxval
      # of 1 when padding is 1, so argmin skips them.
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
      return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      # Seeded mode: take the seeds in order first, then the furthest points.
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      # Unseeded mode: random real point first, then the furthest points.
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack(
        [tf.range(batch_size), new_selected], axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # num_sampled_points x n
  sampled_idx = tf.transpose(sampled_idx, [1, 0])

  # Restore the static shape when both sizes are known Python ints.
  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
def FProp(self, theta, inputs, paddings=None):
  """Applies group normalization, optionally with cumulative statistics.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor with shape [batch_size, height, width, channel]
      (or one dimension fewer when `params.input_rank` is not 4).
    paddings: The paddings tensor with shape [batch_size, height]. Intended to
      be used for sequence processing where `height` is `time`.

  Returns:
    A single tensor as the output after applying group normalization, with the
    same shape as 'inputs'. Or an (output, output_paddings) pair if input
    paddings is not None.
  """
  p = self.params
  rank_check = py_utils.assert_greater_equal(
      py_utils.GetRank(inputs), p.input_rank)
  inputs = py_utils.with_dependencies([rank_check], inputs)

  min_group_size = min(p.min_group_size, p.dim)
  group_size = max(p.dim // p.num_groups, min_group_size)
  num_groups = p.dim // group_size

  input_shape = py_utils.GetShape(inputs)
  with tf.name_scope(p.name):
    # Split the channel dim into [num_groups, group_size].
    x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
    expanded_rank = p.input_rank + 1
    all_dims = list(range(expanded_rank))
    # Every dim except batch (d0) and the group dim (d[-2]).
    non_batch_non_group_dims = all_dims[1:-2] + all_dims[-1:]

    if paddings is None:
      counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
          x, axes=non_batch_non_group_dims, keepdims=True)
      norm_mean, norm_variance = tf.nn.normalize_moments(
          counts, mean_ss, variance_ss, None)
    else:
      expanded_paddings = tf.reshape(
          paddings, input_shape[:2] + [1] * (expanded_rank - 2))
      if not p.cumulative:
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x, expanded_paddings, non_batch_non_group_dims, keepdims=True)
      else:
        # Cumulative statistics also keep d1 and accumulate along it.
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x,
            expanded_paddings,
            reduce_over_dims=all_dims[2:-2] + all_dims[-1:],
            cumulative_axis=1,
            keepdims=True)

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    n = input_shape[0]
    t = input_shape[1] if p.cumulative else 1
    if p.input_rank == 4:
      norm_shape = [n, t, 1, num_groups, 1]
    else:
      norm_shape = [n, t, num_groups, 1]

    moment_checks = [
        py_utils.assert_greater_equal(norm_variance,
                                      tf.cast(0., norm_variance.dtype)),
        py_utils.assert_shape_match(norm_shape, tf.shape(norm_mean)),
        py_utils.assert_shape_match(norm_shape, tf.shape(norm_variance)),
    ]
    with tf.control_dependencies(moment_checks):
      x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
      x = tf.reshape(x, input_shape)
      gn_output = x * theta.gamma + theta.beta
      gn_output = tf.reshape(gn_output, input_shape)

    if paddings is not None:
      return gn_output, paddings
    return gn_output
def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                   cached_var):
  """Computes cumulative mean and variance over the valid points in inputs.

  The statistics of the current chunk are accumulated along the time axis and
  added to the cached statistics of the previous chunks.

  Args:
    inputs: [B, T, F, N, G] or [B, T, N, G]
    paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
    cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
    cached_count: same shape as cached_sum.
    cached_var: same shape as cached_sum.

  Returns:
    mean: [B, T, 1, N, 1] or [B, T, N, 1]
    variance: same shape as mean.
    new_cached_sum: same shape as cached_sum.
    new_cached_count: same shape as cached_count.
    new_cached_var: same shape as cached_var.

  Raises:
    AssertionError: if the rank of `inputs` is not statically known.
  """
  tf.logging.vlog(1, 'inputs: %r', inputs)
  tf.logging.vlog(1, 'paddings: %r', paddings)
  tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
  tf.logging.vlog(1, 'cached_count: %r', cached_count)

  inputs = py_utils.ApplyPadding(paddings, inputs, use_select=False)

  input_rank = py_utils.GetRank(inputs)
  assert input_rank is not None, (f'inputs rank must be staic for '
                                  f'{repr(inputs)}')
  reduce_over_dims = list(range(input_rank))
  # Skip B, T, and N. Reduce {F,G} or just G.
  reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
  tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

  # [B, T, 1, N, 1] or [B, T, N, 1]
  sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
  sum_v = tf.math.cumsum(sum_v, axis=1)
  sum_v += cached_sum

  # [B, T, 1, 1, 1] or [B, T, 1, 1]
  mask = tf.cast(1.0 - paddings, inputs.dtype)
  count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
  count_v = tf.math.cumsum(count_v, axis=1)

  # The mask has size 1 on the reduced dims, so each valid frame was counted
  # once above. Scale by the number of input elements summed per frame, which
  # is the product of the input sizes over exactly the reduced dims (F * G for
  # rank-5 inputs, G for rank-4 inputs). Deriving it from reduce_over_dims
  # keeps the count consistent with sum_v for either rank.
  input_shape = py_utils.GetShape(inputs)
  multiplier = 1
  for dim in reduce_over_dims:
    multiplier *= input_shape[dim]
  # The shape entries may be dynamic int tensors; cast before multiplying.
  count_v *= tf.cast(multiplier, count_v.dtype)
  count_v += cached_count

  tf.logging.vlog(1, 'sum_v: %r', sum_v)
  tf.logging.vlog(1, 'count_v: %r', count_v)

  # Guard against division by zero when no valid frame has been seen yet.
  mean = sum_v / tf.maximum(count_v, 1.0)

  sum_vv = tf.reduce_sum(
      py_utils.ApplyPadding(
          paddings, tf.math.squared_difference(inputs, mean),
          use_select=False),
      reduce_over_dims,
      keepdims=True)
  sum_vv = tf.math.cumsum(sum_vv, axis=1)
  sum_vv += cached_var

  # The statistics at the last time step become the cache for the next chunk.
  cached_sum = sum_v[:, -1:]
  cached_count = count_v[:, -1:]
  cached_var = sum_vv[:, -1:]

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
  ], sum_vv / tf.maximum(count_v, 1.0))
  return mean, variance, cached_sum, cached_count, cached_var
def MakeLocalMask(seq_len,
                  block_size,
                  left_context,
                  right_context,
                  query_stride=1,
                  dtype=tf.float32):
    """Builds the block-local attention mask for a whole sequence.

    The mask encodes the given context sizes: the query at position i may
    attend to keys in the range [i - (left_context-1), i + right_context].

    For example, given seq_len=4, block_size=2, left_context=3,
    right_context=0, the result mask is
    [[[0., 0., 1., 0.],   1st query in 1st block attends 1st key.
      [0., 0., 1., 1.]],  2nd query in 1st block attends 2nd and left keys
     [[1., 1., 1., 0.],   1st query in 2nd block attends 1st and left keys
      [0., 1., 1., 1.]]]  2nd query in 2nd block attends 2nd and left keys

    Queries can additionally be pooled in groups of `query_stride`. A pooled
    query may attend to a key if any query in its group may. Given the same
    params and stride=2, the result mask is
    [[[0., 0., 1., 1.]],  The pooled query in 1st block attends 1st and 2nd keys
     [[1., 1., 1., 1.]]]  The pooled query in 2nd block attends 1st, 2nd and left

    Args:
      seq_len: int or scalar int tensor. Sequence length.
      block_size: int. Number of time frames in a block.
      left_context: int. Left context size.
      right_context: int. Right context size.
      query_stride: int. Query stride for funnel pool.
      dtype: tf.dtype, default is tf.float32.

    Returns:
      A tensor of [num_blocks, block_size//stride, context_size] taking values
      in {0, 1}, where context_size = block_size + (left_context - 1) +
      right_context. Element b, i, j is 1 if in the b-th block, the i-th frame
      can access the j-th frame in the context.
    """
    assert block_size % query_stride == 0, (
        f'block_size({block_size}) must be a multiple of '
        f'query_stride({query_stride}).')
    seq_len = py_utils.with_dependencies([
        py_utils.assert_greater_equal(
            seq_len, 1, message='seq_len must be at least 1')
    ], seq_len)

    # Round up so that a trailing partial block is still covered.
    num_blocks = (seq_len + block_size - 1) // block_size
    past_frames = left_context - 1
    context_size = block_size + past_frames + right_context

    # [num_blocks, block_size]: absolute position of every query frame.
    query_pos = tf.reshape(
        tf.range(num_blocks * block_size), [num_blocks, block_size])
    # [num_blocks]: absolute position of the first frame of each block.
    block_starts = tf.range(0, num_blocks * block_size, block_size)
    # [context_size]: key offsets measured from the start of a block.
    context_offsets = tf.range(context_size) - past_frames
    # [num_blocks, context_size]: absolute position of every key frame.
    key_pos = (
        block_starts[:, tf.newaxis] + context_offsets[tf.newaxis, :])

    # [num_blocks, block_size, context_size]: query minus key position.
    distance = query_pos[:, :, tf.newaxis] - key_pos[:, tf.newaxis, :]
    # A key is inside the window when it is at most right_context ahead of
    # the query and fewer than left_context frames behind it.
    in_window = tf.math.logical_and(distance >= -right_context,
                                    distance < left_context)

    # Positions outside [0, seq_len) are padding on either side.
    # [num_blocks, block_size]
    query_in_range = query_pos < seq_len
    # [num_blocks, context_size]
    key_in_range = tf.math.logical_and(key_pos >= 0, key_pos < seq_len)
    both_in_range = tf.math.logical_and(query_in_range[:, :, tf.newaxis],
                                        key_in_range[:, tf.newaxis, :])

    allowed = tf.math.logical_and(in_window, both_in_range)
    mask = tf.cast(allowed, dtype=dtype)

    if query_stride:
        # Pool each group of query_stride consecutive queries with a max.
        pooled_shape = [
            num_blocks, block_size // query_stride, query_stride, context_size
        ]
        mask = tf.reduce_max(tf.reshape(mask, pooled_shape), axis=-2)
    return mask
def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                   cached_var):
    """Computes streaming mean and variance over the valid points in inputs.

    Per-frame statistics are reduced over the non-batch, non-time, non-group
    axes, accumulated along the time axis (axis 1) with a cumulative sum, and
    offset by the cached statistics passed in by the caller.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
    tf.logging.vlog(1, 'inputs: %r', inputs)
    tf.logging.vlog(1, 'paddings: %r', paddings)
    tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
    tf.logging.vlog(1, 'cached_count: %r', cached_count)

    # Zero out padded frames so they do not contribute to the sums.
    mask = tf.cast(1.0 - paddings, inputs.dtype)
    inputs *= mask

    input_rank = py_utils.GetRank(inputs)
    assert input_rank is not None, (f'inputs rank must be staic for '
                                    f'{repr(inputs)}')
    # Skip B, T, and N. Reduce {F,G} or just G.
    all_dims = list(range(input_rank))
    reduce_over_dims = all_dims[2:-2] + all_dims[-1:]
    tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

    # [B, T, 1, N, 1] or [B, T, N, 1]
    sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
    sum_v = tf.math.cumsum(sum_v, axis=1)
    sum_v += cached_sum

    # [B, T, 1, 1, 1] or [B, T, 1, 1]
    count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
    count_v = tf.math.cumsum(count_v, axis=1)

    # The mask has size 1 on every reduced axis, so each valid frame must be
    # scaled by the number of input elements that were summed for it: F * G
    # for rank-5 inputs, G for rank-4 inputs. Deriving the factor from
    # reduce_over_dims keeps it in sync with the reductions above.
    input_shape = py_utils.GetShape(inputs)
    multiplier = 1
    for dim in reduce_over_dims:
        multiplier *= input_shape[dim]
    # Cast so that a dynamic (integer tensor) dimension can scale the
    # floating-point count.
    count_v *= tf.cast(multiplier, count_v.dtype)
    count_v += cached_count

    tf.logging.vlog(1, 'sum_v: %r', sum_v)
    tf.logging.vlog(1, 'count_v: %r', count_v)

    # Clamp only the divisor. The raw count is what gets cached: caching the
    # clamped value would turn a count of zero (no valid frame seen so far)
    # into one and bias every later step.
    safe_count = tf.maximum(count_v, 1.0)
    mean = sum_v / safe_count

    if py_utils.FLAGS.tflite_compatible:
        # TfLite doesn't support broadcasting with 5D tensors, so expand the
        # mean to the full input shape explicitly.
        inputs_shape = py_utils.GetShape(inputs)
        if len(inputs_shape) == 4:
            tiled_mean = tf.tile(mean, [1, 1, 1, inputs_shape[3]])
        else:
            tiled_mean = tf.tile(
                mean, [1, 1, inputs_shape[2], 1, inputs_shape[4]])
        sum_vv = tf.reduce_sum(
            tf.math.square(inputs - tiled_mean) * mask,
            reduce_over_dims,
            keepdims=True)
    else:
        sum_vv = tf.reduce_sum(
            (inputs - mean)**2 * mask, reduce_over_dims, keepdims=True)
    sum_vv = tf.math.cumsum(sum_vv, axis=1)
    sum_vv += cached_var

    # The statistics at the last time step seed the next call.
    cached_sum = sum_v[:, -1:]
    cached_count = count_v[:, -1:]
    cached_var = sum_vv[:, -1:]

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
    ], sum_vv / safe_count)
    return mean, variance, cached_sum, cached_count, cached_var