def Proc(record):
  """Parses one serialized tf.Example into dense tensors plus a bucket key.

  Args:
    record: A serialized tf.Example.

  Returns:
    A tuple (values, bucket_key). `values` is a list holding the `.values` of
    each parsed variable-length feature, in the order source_id,
    source_padding, target_id, target_padding, target_label, target_weight.
    `bucket_key` is an int32 scalar: the larger of sum(1 - source_padding) and
    sum(1 - target_padding).
  """
  feature_spec = [
      ('source_id', tf.VarLenFeature(tf.int64)),
      ('source_padding', tf.VarLenFeature(tf.float32)),
      ('target_id', tf.VarLenFeature(tf.int64)),
      ('target_padding', tf.VarLenFeature(tf.float32)),
      ('target_label', tf.VarLenFeature(tf.int64)),
      ('target_weight', tf.VarLenFeature(tf.float32)),
  ]
  parsed = tf.parse_single_example(record, dict(feature_spec))
  # Keep only the values of each parsed sparse feature.
  dense = {name: sparse.values for name, sparse in six.iteritems(parsed)}

  source_len = tf.reduce_sum(1.0 - dense['source_padding'])
  target_len = tf.reduce_sum(1.0 - dense['target_padding'])
  bucket_key = tf.cast(tf.maximum(source_len, target_len), tf.int32)

  ordered_values = [dense[name] for name, _ in feature_spec]
  return ordered_values, bucket_key
def QuantizeWeight(self, w):
  """Fake-quantizes weight `w` over its own (epsilon-widened) min/max range.

  Args:
    w: The weight tensor.

  Returns:
    In eval mode, the result of `_MaybeFakeQuant`. In training mode, the same
    result except that elements which came out as NaN are replaced by the
    corresponding elements of `w`.
  """
  params = self.params
  epsilon = params.quantize_weight_epsilon
  # NOTE: We force a small, non-zero range because otherwise, zero weights
  # can cause downstream inference engines to blow up.
  range_min = tf.minimum(tf.reduce_min(w), -epsilon)
  range_max = tf.maximum(tf.reduce_max(w), epsilon)
  fake_quantized = self._MaybeFakeQuant(
      w, range_min, range_max, num_bits=params.bits)
  if self.do_eval:
    return fake_quantized
  # If quantizing during training, skip quantization if it produces NANs.
  # Sometimes early in the training process, things are unstable and ranges
  # can produce numerical instability that makes it impossible to perform a
  # fake_quant.
  return tf.where(tf.math.is_nan(fake_quantized), w, fake_quantized)
def SequenceTrimLastToken(x, x_paddings):
  """Drops the last valid token of each sequence in `x`, zeroing it out.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.

  Returns:
    A tuple.
    - The new sequence, Tensor of shape [batch_size, x_len_max].
    - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
  max_len = py_utils.GetShape(x)[1]
  valid_len = tf.reduce_sum(1 - x_paddings, 1)
  # One token shorter, but never below zero.
  kept_len = tf.maximum(valid_len - 1, 0)
  # 1 at positions that are kept, 0 elsewhere.
  keep_mask = tf.sequence_mask(kept_len, max_len, x_paddings.dtype)
  trimmed = x * tf.cast(keep_mask, x.dtype)
  trimmed_paddings = 1 - keep_mask
  return trimmed, trimmed_paddings
def TopKAccuracy(k, logits, labels, weights):
  """Computes the weighted top-k accuracy.

  Args:
    k: An int scalar. Top-k.
    logits: A [N, C] float tensor.
    labels: A [N] int vector.
    weights: A [N] float vector.

  Returns:
    A float scalar: the weighted count of examples whose label is within the
    top `k` logits, divided by the total weight (floored at 1e-8).
  """
  logits = py_utils.HasRank(logits, 2)
  num_examples = tf.unstack(tf.shape(logits), 2)[0]
  labels = py_utils.HasShape(labels, [num_examples])
  weights = py_utils.HasShape(weights, [num_examples])

  in_top_k = tf.nn.in_top_k(targets=labels, predictions=logits, k=k)
  weighted_hits = tf.reduce_sum(tf.cast(in_top_k, weights.dtype) * weights)
  # The floor keeps the division finite when all weights are zero.
  total_weight = tf.maximum(1e-8, tf.reduce_sum(weights))
  return weighted_hits / total_weight
def _FPropChunk(self, pcm_audio_chunk, pcm_audio_paddings):
  """Turns a chunk of PCM audio into normalized log-mel features.

  Steps: cast to float32 (optionally scaled by 1/32768), frame, pre-emphasize
  (or drop the last sample of each frame), optionally add Gaussian noise,
  optionally apply the window, compute the mel spectrogram, take the floored
  log, and normalize by the per-bin mean and stddev.

  Args:
    pcm_audio_chunk: The PCM audio samples.
    pcm_audio_paddings: The paddings for `pcm_audio_chunk`.

  Returns:
    A tuple (normalized log-mel spectrogram, result of `_GetMelPadding`).
  """
  params = self.params
  samples = tf.cast(pcm_audio_chunk, tf.float32)
  if params.use_divide_stream:
    samples = samples / 32768.0

  # shape: [batch, time, _frame_size]
  frames = tf.signal.frame(samples, self._frame_size, self._frame_step,
                           params.pad_end)

  # Pre-emphasis. Without it, only the last sample of each frame is dropped.
  if params.preemph == 0.0:
    emphasized = frames[..., :-1]
  else:
    emphasized = self._ApplyPreemphasis(frames)

  # Noise.
  noise = 0.0
  if params.noise_scale > 0.0:
    noise = tf.random.normal(
        tf.shape(emphasized),
        stddev=params.noise_scale,
        mean=0.0,
        seed=params.random_seed)
  signal = emphasized + noise

  # Apply window fn.
  if self._window_fn is not None:
    signal = signal * self._window_fn(self._frame_size - 1, frames.dtype)

  mel = self._MelSpectrogram(signal)
  log_mel = tf.math.log(tf.maximum(float(params.output_floor), mel))

  # Mean and stddev.
  per_bin_mean = tf.convert_to_tensor(params.per_bin_mean)
  per_bin_stddev = tf.convert_to_tensor(params.per_bin_stddev)
  normalized = (log_mel - per_bin_mean) / per_bin_stddev
  return normalized, self._GetMelPadding(pcm_audio_paddings)
def _PaddedMeanFn(inp):
  """Apply padded mean using reduce_sum and dividing by # real points.

  Args:
    inp: An object with `features` and `padding` attributes. `padding` is
      expanded with a trailing axis to mask `features`, and `features` is
      reduced over axis -2.

  Returns:
    The mean of the unpadded features over axis -2, passed through
    `py_utils.CheckNumerics`. Rows whose points are all padded yield zeros.
  """
  # Replace all padded features with 0 by masking the padded features out.
  mask = 1 - inp.padding
  features = inp.features * mask[..., tf.newaxis]
  features = tf.reduce_sum(features, axis=-2)
  num_real_points = tf.reduce_sum(mask, axis=-1, keepdims=True)

  # Detect fully padded rows BEFORE clamping the divisor. Testing after the
  # clamp could never be true and left the zeroing below as dead code.
  all_padded = tf.equal(num_real_points, 0.)

  # Prevent the divisor of our padded mean from ever being 0, so that
  # the gradient flowing back through this op doesn't give us NaNs.
  features = features / tf.maximum(num_real_points, 1)

  # Replace features of all padded points by zeros, so that we don't get any
  # downstream issue with NaNs. Note that inf * 0 = NaN, so masking alone does
  # not clean up non-finite values stored in padded slots.
  all_padded = tf.broadcast_to(all_padded, py_utils.GetShape(features))
  features = tf.where(all_padded, tf.zeros_like(features), features)
  return py_utils.CheckNumerics(features)
def _GetSequenceLength(self, example):
  """Returns sequence length for the example NestedMap from the dataset.

  This function is used by the TFDatasetBatchBySequenceLength DataSource to
  obtain the key used for bucketing. Bucketing separates examples into
  groups before batching, such that each batch contains only examples
  within a certain length.

  Args:
    example: A NestedMap containing an input example. Tensors in the example
      do not have a leading batch dimension.

  Returns:
    An int32 scalar: the rounded maximum of the number of unpadded source and
    unpadded target positions.
  """
  source_len = tf.reduce_sum(1.0 - example.src.paddings)
  target_len = tf.reduce_sum(1.0 - example.tgt.paddings)
  longest = tf.round(tf.maximum(source_len, target_len))
  return tf.cast(longest, tf.int32)
def _MaybePadSourceInputs(self, src_inputs, src_paddings):
  """Makes room for one extra (EOS) frame per sequence when configured.

  When `params.append_eos_frame` is false the inputs are returned unchanged.
  Otherwise every sequence length is increased by one, `src_inputs` is
  zero-padded along axis 1 if the longest new length exceeds the current time
  dimension, and the paddings are rebuilt from the new lengths.

  Args:
    src_inputs: The source inputs; axis 0 is batch and axis 1 is time.
    src_paddings: The paddings for `src_inputs`, summed over axis 1 to obtain
      per-sequence lengths.

  Returns:
    A tuple (src_inputs, src_paddings), possibly extended along the time axis.
  """
  if not self.params.append_eos_frame:
    return src_inputs, src_paddings

  # Every sequence grows by one frame.
  lengths = tf.reduce_sum(1 - src_paddings, 1) + 1
  longest = tf.reduce_max(lengths)

  shape = tf.shape(src_inputs)
  cur_time = shape[1]
  new_time = tf.maximum(cur_time, tf.cast(longest, tf.int32))
  extra_steps = new_time - cur_time

  # Zeros with the input's shape, except axis 1 which holds `extra_steps`.
  filler_shape = inplace_ops.inplace_update(shape, 1, extra_steps)
  filler = tf.zeros(filler_shape, src_inputs.dtype)
  src_inputs = tf.concat([src_inputs, filler], 1)

  valid_mask = tf.sequence_mask(
      tf.reshape(lengths, [shape[0]]), tf.reshape(new_time, []),
      src_paddings.dtype)
  src_paddings = 1 - valid_mask
  return src_inputs, src_paddings
def _setup_sparsity(self):
  """Builds the sparsity schedule tensor from the pruning spec.

  The schedule is
    target + (initial - target) * (1 - p) ** exponent
  where p is (global_step - begin_step) / (end_step - begin_step), clipped to
  [0, 1].

  Returns:
    The sparsity tensor, named 'sparsity' within the spec's name scope.
  """
  spec = self._spec
  begin_step = spec.sparsity_function_begin_step
  end_step = spec.sparsity_function_end_step
  initial_sparsity = spec.initial_sparsity
  target_sparsity = spec.target_sparsity
  exponent = spec.sparsity_function_exponent

  with tf.name_scope(spec.name):
    steps_elapsed = tf.cast(self._global_step - begin_step, tf.float32)
    progress = tf.div(steps_elapsed, end_step - begin_step)
    # Clip progress to [0, 1].
    progress = tf.minimum(1.0, tf.maximum(0.0, progress))
    remaining = tf.pow(1 - progress, exponent)
    sparsity = tf.add(
        tf.multiply(initial_sparsity - target_sparsity, remaining),
        target_sparsity,
        name='sparsity')
  return sparsity
def AddMultiCurveSubplot(fig, tensors, paddings, labels, xlabels=None,
                         **kwargs):
  """Adds a multi curve subplot to Matplotlib figure.

  Plots one line for each non-None entry in tensors and assigns a plot label
  legend.

  Args:
    fig: The Matplotlib figure.
    tensors: List of tensors of shape [batch, length]
    paddings: Paddings for 'tensors' with shape [batch, length] with 0. in
      valid positions and 1. in invalid. Or list of padding tensors of same
      length as tensors.
    labels: A list of tensor names (strings) of the same length as 'tensors'.
    xlabels: A string tensor of shape [batch] with an xlabel per batch.
    **kwargs: With optional, title, xlabel, ylabel, fontsize.
  """
  # A single padding tensor is shared by every curve.
  if isinstance(paddings, tf.Tensor):
    paddings = [paddings] * len(tensors)

  batch_size = py_utils.GetShape(paddings[0])[0]
  max_lengths = tf.zeros([batch_size], tf.int32)
  curves = []
  curve_labels = []
  for tensor, label, padding in zip(tensors, labels, paddings):
    # Lengths are tracked for every entry, including those without a tensor.
    max_lengths = tf.maximum(max_lengths,
                             py_utils.LengthsFromPaddings(padding))
    if tensor is None:
      continue
    curves.append(py_utils.ApplyPadding(padding, tensor))
    curve_labels.append(label)

  batch, length = py_utils.GetShape(curves[0], 2)
  stacked = tf.reshape(tf.concat(curves, -1), [batch, len(curves), length])

  plot_args = [stacked, max_lengths]
  if xlabels is not None:
    plot_args.append(xlabels)
  fig.AddSubplot(
      plot_args,
      plot_func=_AddMultiCurveRowPlots,
      row_labels=curve_labels,
      **kwargs)
def ComputeMomentsWithPadding(inputs,
                              padding,
                              reduce_over_dims,
                              enable_cross_replica_sum_on_tpu=False,
                              keepdims=False):
  """Computes mean and variance over the valid data points in inputs.

  Args:
    inputs: The input tensor.
    padding: The padding tensor; must have the same rank as `inputs` and
      satisfy 1 - padding >= 0.
    reduce_over_dims: The dimensions to reduce over.
    enable_cross_replica_sum_on_tpu: If true and running on TPU, sums and
      counts are aggregated with `tf.tpu.cross_replica_sum`.
    keepdims: Passed to the reductions.

  Returns:
    A tuple (mean, variance).
  """
  # NOTE(review): with keepdims=False, `mean` is subtracted from `inputs` in
  # its reduced shape; presumably callers only reduce over leading dims so
  # that it still broadcasts as intended -- confirm.
  mask = 1.0 - padding
  input_checks = [
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ]
  inputs = py_utils.with_dependencies(input_checks, inputs)

  masked_inputs = inputs * tf.cast(mask, inputs.dtype)
  sum_v = tf.reduce_sum(masked_inputs, reduce_over_dims, keepdims=keepdims)
  count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  input_size = tf.reduce_prod(tf.gather(tf.shape(inputs), reduce_over_dims))
  mask_size = tf.reduce_prod(tf.gather(tf.shape(mask), reduce_over_dims))
  broadcast_factor = tf.math.truediv(input_size, mask_size)
  count_v = count_v * tf.cast(broadcast_factor, count_v.dtype)

  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Avoid dividing by zero when nothing is valid.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v

  centered = inputs - mean
  sum_vv = tf.reduce_sum(
      centered * centered * mask, reduce_over_dims, keepdims=keepdims)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  non_negative_check = [
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ]
  variance = py_utils.with_dependencies(non_negative_check, sum_vv / count_v)
  return mean, variance
def SphericalCoordinatesTransform(points_xyz):
  """Converts points from xyz coordinates to spherical coordinates.

  https://en.wikipedia.org/wiki/Spherical_coordinate_system#Coordinate_system_conversions
  for definitions of the transformations.

  Args:
    points_xyz: A floating point tensor with shape [..., 3], where the inner 3
      dimensions correspond to xyz coordinates.

  Returns:
    A floating point tensor with the same shape [..., 3], where the inner
    dimensions correspond to (dist, theta, phi), where phi corresponds to
    azimuth/yaw (rotation around z), and theta corresponds to
    pitch/inclination (rotation around y).
  """
  x = points_xyz[..., 0]
  y = points_xyz[..., 1]
  z = points_xyz[..., 2]

  squared_norm = tf.reduce_sum(tf.square(points_xyz), axis=-1)
  dist = tf.sqrt(squared_norm)
  # The floor on `dist` keeps the ratio finite for points at the origin.
  theta = tf.acos(z / tf.maximum(dist, 1e-7))
  # Note: tf.atan2 takes in (y, x).
  phi = tf.atan2(y, x)
  return tf.stack([dist, theta, phi], axis=-1)
def GetState(self, theta):
  """Gets the state from theta.

  Args:
    theta: Provides `global_step`.

  Returns:
    For inference, a dummy float32 zeros tensor of shape [1]. Otherwise a
    stacked tensor [clip_ratio, fq_ratio].
  """
  params = self.params
  if params.is_inference:
    # State is not used for inference. Just return dummy.
    return tf.zeros([1], tf.float32)

  # Calculations/vars need to be float but these can be ints in the params.
  clip_end = tf.cast(params.clip_end_step, tf.float32)
  clip_start = tf.cast(params.clip_start_step, tf.float32)
  quant_start = tf.cast(params.quant_start_step, tf.float32)
  step = tf.cast(theta.global_step, tf.float32)

  clip_span = clip_end - clip_start
  steps_since_clip_start = step - clip_start
  # Will be negative if before clipping starts.
  clip_ratio = (
      tf.minimum(clip_span, steps_since_clip_start) /
      tf.maximum(1.0, clip_span))
  # Currently fq is either on (1.0) or off (-1.0). Progressive quantization
  # may later occupy 0..1.0.
  fq_ratio = tf.where(step < quant_start, -1.0, 1.0)
  return tf.stack([clip_ratio, fq_ratio])
def _RecordTensor(self, t_name):
  """Builds the ops that fold accumulated min/max into the range variables.

  Reads (count, min, max) from the tensor's accumulator, resets the
  accumulator, and moves the 'min' / 'max' state variables toward the observed
  values by a factor of (1 - ema_decay) when count > 0.

  Args:
    t_name: Name of the tracked tensor.

  Returns:
    An empty list in eval mode; otherwise a list with the two assign ops for
    the min and max variables.
  """
  params = self.params
  if self.do_eval:
    return []

  accumulator = self.accumulators[self._GetAccumulatorNameForTensor(t_name)]
  min_var = self._GetQStateVar(t_name, 'min')
  max_var = self._GetQStateVar(t_name, 'max')

  # Unpack state tensor.
  state = accumulator.GetValue()
  count, observed_min, observed_max = state[0], state[1], state[2]
  accumulator.Reset()

  def _Updated(variable, observed):
    """Returns `variable` after one EMA step (a no-op when count is zero)."""
    step = (1.0 - params.ema_decay) * (variable - observed)
    return variable - tf.where(count > 0., step, 0.)

  # Note that small floating point issues can cause ranges that naturally
  # begin or end at zero to move slightly past, causing hard failures
  # downstream (checks that all ranges straddle zero). We therefore repeat
  # the straddling constraint here.
  new_min = tf.minimum(0., _Updated(min_var, observed_min))
  new_max = tf.maximum(0., _Updated(max_var, observed_max))
  return [
      tf.assign(min_var, new_min),
      tf.assign(max_var, new_max),
  ]
def __init__(self,
             learning_rate,
             momentum=0.0,
             initial_accumulator_value=0.0,
             start_preconditioning_steps=1000,
             statistics_computation_frequency=1,
             matrix_epsilon=1e-6,
             synchronous_preconditioning=False,
             second_moment_averaging=1.0,
             fallback_to_diagonal_dim=4096,
             max_any_dim=6656,
             block_size=4096,
             block_partition_threshold_size=1000000,
             global_step=None,
             exponent_multiplier=1.0,
             name="DistributedShampoo"):
  """Construct a DistributedShampoo optimizer.

  Args:
    learning_rate: A `Tensor` or a floating point value. The learning rate.
    momentum: A `Tensor` or a floating point value. Momentum is not applied to
      sparse updates.
    initial_accumulator_value: A floating point value.
    start_preconditioning_steps: An int32 value which indicates when to start
      preconditioning.
    statistics_computation_frequency: An int32 step value which indicates how
      often to compute statistics for preconditioning.
    matrix_epsilon: An epsilon regularizer to make the matrices positive
      definite.
    synchronous_preconditioning: Whether to run preconditioning
      synchronously.
    second_moment_averaging: 1.0 means sum of gradients squares, while less
      than 1.0 switches to RMSProp style exponential moving averages of the
      second moments.
    fallback_to_diagonal_dim: Fallback to diagonal version of AFMA if any of
      the dimensions is larger than fallback_to_diagonal_dim.
    max_any_dim: If maximum value for any dimension is greater than this value
      we skip preconditioning and fall back to the diagonal.
    block_size: Dimension of the partitioned tensors.
    block_partition_threshold_size: Partitions dimensions beyond this size.
    global_step: Global step for training. When None, the default global step
      is fetched (or created).
    exponent_multiplier: A multiplier `e` for the exponent for the inverse
      calculation. e * -1/(2*rank). Only applies when calculating inverses
      through svd.
    name: Optional name prefix for the operations created when applying
      gradients.
  """
  super().__init__(False, name)
  self._learning_rate = learning_rate
  self._momentum = momentum
  self._initial_accumulator_value = initial_accumulator_value
  self._start_preconditioning_steps = start_preconditioning_steps
  self._matrix_epsilon = matrix_epsilon
  self._synchronous_preconditioning = synchronous_preconditioning
  self._second_moment_averaging = second_moment_averaging
  self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
  self._max_any_dim = max_any_dim
  self._block_size = block_size

  # NOTE: On XLA - int64 is not handled properly.
  if global_step is not None:
    self._global_step = tf.cast(tf.identity(global_step), tf.int32)
  else:
    self._global_step = tf.cast(
        tf.identity(tf.train.get_or_create_global_step()), tf.int32)

  self._run_nondiagonal_update = tf.greater_equal(
      self._global_step, self._start_preconditioning_steps)

  # Warm-up factor: 0 until preconditioning starts, then a linear ramp to 1
  # over another `start_preconditioning_steps` steps.
  start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
  global_step_f = tf.cast(self._global_step, tf.float32)
  # Clamp the denominator so that start_preconditioning_steps == 0 does not
  # evaluate 0 / 0 (NaN) at step 0. Any value >= 1 is unaffected.
  warmup_span_f = tf.maximum(start_steps_f, 1.0)
  self._run_nondiagonal_update_warmup = tf.minimum(
      1.0, tf.maximum((global_step_f - start_steps_f) / warmup_span_f, 0.0))

  # Computes statistics every K steps.
  self._statistics_computation_frequency = statistics_computation_frequency
  self._run_statistics_computation = tf.equal(
      tf.math.floormod(self._global_step,
                       self._statistics_computation_frequency), 0)

  # All vars that are preconditioned.
  self._all_vars_for_preconditioning = []
  self._exponent_multiplier = exponent_multiplier
  self._partition_info = PartitionConfig(block_partition_threshold_size,
                                         block_size)
  self._partitioner_metadata = {}
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first
  (when `params.apply_layer_norm` is set), and the returned 'dists' is
  computed using the normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L].
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
           Shape [B, L, N, K].
    k_means_loss: the average squared Euclidean distances to the closest
                  centroid, a scalar.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product to
  # approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)
  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids.
  # The denominator is clamped so that a fully padded batch gives a loss of 0
  # (diff is all zeros there) instead of 0 / 0.
  num_valid = tf.maximum(tf.math.reduce_sum(1.0 - paddings), 1.0)
  k_means_loss = tf.math.reduce_sum(diff) / num_valid
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  # Per-centroid L2 norm, shape [N, K]. Without `axis`, tf.norm returns one
  # scalar for the whole tensor, so the min and mean summaries below would
  # always report the same number.
  means_norm = tf.norm(theta.means, axis=-1)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #    loss = tf.math.reduce_mean(
  #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension will
  # be 0.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use exponential moving average. TODO(zhouwk): investigate smooth this
  # over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                              self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                   cached_var):
  """Computes mean and variance over the valid data points in inputs.

  Sums, counts and squared deviations are accumulated along the time axis
  (axis 1) with a cumulative sum and offset by the cached values.

  Args:
    inputs: [B, T, F, N, G] or [B, T, N, G]
    paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
    cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
    cached_count: same shape as cached_sum.
    cached_var: same shape as cached_sum.

  Returns:
    mean: [B, T, 1, N, 1] or [B, T, N, 1]
    variance: same shape as mean.
    new_cached_sum: same shape as cached_sum.
    new_cached_count: same shape as cached_count.
    new_cached_var: same shape as cached_var.
  """
  tf.logging.vlog(1, 'inputs: %r', inputs)
  tf.logging.vlog(1, 'paddings: %r', paddings)
  tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
  tf.logging.vlog(1, 'cached_count: %r', cached_count)

  inputs = py_utils.ApplyPadding(paddings, inputs, use_select=False)

  input_rank = py_utils.GetRank(inputs)
  assert input_rank is not None, (f'inputs rank must be staic for '
                                  f'{repr(inputs)}')
  reduce_over_dims = list(range(input_rank))
  # Skip B, T, and N. Reduce {F,G} or just G.
  reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
  tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

  # [B, T, 1, N, 1] or [B, T, N, 1]
  sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
  sum_v = tf.math.cumsum(sum_v, axis=1)
  sum_v += cached_sum

  # [B, T, 1, 1, 1] or [B, T, 1, 1]
  mask = tf.cast(1.0 - paddings, inputs.dtype)
  count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
  count_v = tf.math.cumsum(count_v, axis=1)
  input_shape = py_utils.GetShape(inputs)
  # Number of elements summed per (batch, time, group) position. It must match
  # `reduce_over_dims`: F and G for the rank-5 layout, only G for the rank-4
  # layout. In the rank-4 layout input_shape[-3] is T, not F.
  if input_rank == 5:
    # F * G
    multiplier = input_shape[-1] * input_shape[-3]
  else:
    # G
    multiplier = input_shape[-1]
  count_v *= multiplier
  count_v += cached_count

  tf.logging.vlog(1, 'sum_v: %r', sum_v)
  tf.logging.vlog(1, 'count_v: %r', count_v)

  # The count is clamped only for the division; the cached count stays raw.
  mean = sum_v / tf.maximum(count_v, 1.0)

  sum_vv = tf.reduce_sum(
      py_utils.ApplyPadding(
          paddings,
          tf.math.squared_difference(inputs, mean),
          use_select=False),
      reduce_over_dims,
      keepdims=True)
  sum_vv = tf.math.cumsum(sum_vv, axis=1)
  sum_vv += cached_var

  # The last time step carries the running totals for the next call.
  cached_sum = sum_v[:, -1:]
  cached_count = count_v[:, -1:]
  cached_var = sum_vv[:, -1:]

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
  ], sum_vv / tf.maximum(count_v, 1.0))
  return mean, variance, cached_sum, cached_count, cached_var
def Pos(x):
  """Returns `x` clamped from below at 1e-8 (in `x`'s dtype)."""
  floor = tf.constant(1e-8, x.dtype)
  return tf.maximum(floor, x)
def PadToTargetSeqLen(tensor, constant):
  """Right-pads axis 1 of `tensor` up to `p.beam_search.target_seq_len`.

  Args:
    tensor: A rank-2 tensor (the pad spec covers exactly two axes).
    constant: The value used to fill the padded positions.

  Returns:
    `tensor` padded at the end of axis 1; unchanged in length when it is
    already at least `target_seq_len` long.
  """
  current_len = tf.shape(tensor)[1]
  # Never negative: longer tensors are left as they are.
  missing = tf.maximum(0, p.beam_search.target_seq_len - current_len)
  pad_spec = [[0, 0], [0, missing]]
  return tf.pad(tensor, pad_spec, constant_values=constant)
def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                   cached_var):
  """Computes mean and variance over the valid data points in inputs.

  Sums, counts and squared deviations are accumulated along the time axis
  (axis 1) with a cumulative sum and offset by the cached values.

  Args:
    inputs: [B, T, F, N, G] or [B, T, N, G]
    paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
    cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
    cached_count: same shape as cached_sum.
    cached_var: same shape as cached_sum.

  Returns:
    mean: [B, T, 1, N, 1] or [B, T, N, 1]
    variance: same shape as mean.
    new_cached_sum: same shape as cached_sum.
    new_cached_count: same shape as cached_count.
    new_cached_var: same shape as cached_var.
  """
  tf.logging.vlog(1, 'inputs: %r', inputs)
  tf.logging.vlog(1, 'paddings: %r', paddings)
  tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
  tf.logging.vlog(1, 'cached_count: %r', cached_count)

  mask = tf.cast(1.0 - paddings, inputs.dtype)
  inputs *= mask

  input_rank = py_utils.GetRank(inputs)
  assert input_rank is not None, (f'inputs rank must be staic for '
                                  f'{repr(inputs)}')
  reduce_over_dims = list(range(input_rank))
  # Skip B, T, and N. Reduce {F,G} or just G.
  reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
  tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

  # [B, T, 1, N, 1] or [B, T, N, 1]
  sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
  sum_v = tf.math.cumsum(sum_v, axis=1)
  sum_v += cached_sum

  # [B, T, 1, 1, 1] or [B, T, 1, 1]
  count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
  count_v = tf.math.cumsum(count_v, axis=1)
  input_shape = py_utils.GetShape(inputs)
  # Number of elements summed per (batch, time, group) position. It must match
  # `reduce_over_dims`: F and G for the rank-5 layout, only G for the rank-4
  # layout. In the rank-4 layout input_shape[-3] is T, not F.
  if input_rank == 5:
    # F * G
    multiplier = input_shape[-1] * input_shape[-3]
  else:
    # G
    multiplier = input_shape[-1]
  count_v *= multiplier
  count_v += cached_count
  # Clamp only the divisor. Caching the clamped value would turn an all-padded
  # prefix (count 0) into a cached count of 1 and skew later chunks.
  safe_count_v = tf.maximum(count_v, 1.0)

  tf.logging.vlog(1, 'sum_v: %r', sum_v)
  tf.logging.vlog(1, 'count_v: %r', count_v)

  mean = sum_v / safe_count_v
  if py_utils.FLAGS.tflite_compatible:
    # TfLite doesn't support broadcasting with 5D tensors.
    inputs_shape = py_utils.GetShape(inputs)
    if len(inputs_shape) == 4:
      tiled_mean = tf.tile(mean, [1, 1, 1, inputs_shape[3]])
    else:
      tiled_mean = tf.tile(mean, [1, 1, inputs_shape[2], 1, inputs_shape[4]])
    sum_vv = tf.reduce_sum(
        tf.math.square(inputs - tiled_mean) * mask,
        reduce_over_dims,
        keepdims=True)
  else:
    sum_vv = tf.reduce_sum(
        (inputs - mean)**2 * mask, reduce_over_dims, keepdims=True)
  sum_vv = tf.math.cumsum(sum_vv, axis=1)
  sum_vv += cached_var

  # The last time step carries the running totals for the next call.
  cached_sum = sum_v[:, -1:]
  cached_count = count_v[:, -1:]
  cached_var = sum_vv[:, -1:]

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
  ], sum_vv / safe_count_v)
  return mean, variance, cached_sum, cached_count, cached_var
def ComputeLoss(self, theta, predictions, input_batch):
  """Computes loss and other metrics for the given predictions.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    predictions: The output of `ComputePredictions`. The fields read here are:
      residuals - [b, nx, ny, nz, na, 7], where na is the number of anchor
      boxes per cell and the last dim is (dx, dy, dz, dw, dl, dh, dt);
      classification_logits - [b, nx, ny, nz, na, num_classes];
      predicted_dir - only read when p.direction_classifier_weight > 0.
    input_batch: The input batch from which we accesses the groundtruth.

  Returns:
    Two dicts defined as BaseTask.ComputeLoss.
  """
  p = self.params
  predicted_residuals = py_utils.HasShape(
      predictions.residuals, [-1, -1, -1, -1, p.num_anchors, 7])
  predicted_class_logits = py_utils.HasShape(
      predictions.classification_logits,
      [-1, -1, -1, -1, p.num_anchors, p.num_classes])
  bs, nx, ny, nz, na, _ = py_utils.GetShape(predicted_class_logits, 6)

  # Compute class and regression weights.
  # class_weights stays rank 5 here; reg_weights is expanded to rank 6.
  class_weights = input_batch.assigned_cls_mask
  class_weights = py_utils.HasShape(class_weights, [bs, nx, ny, nz, na])
  reg_weights = input_batch.assigned_reg_mask
  reg_weights = py_utils.HasShape(reg_weights, [bs, nx, ny, nz, na])
  reg_weights = tf.expand_dims(reg_weights, -1)

  if p.loss_norm_type == LossNormType.NORM_BY_NUM_POSITIVES:
    # Compute number of positive anchors per example.
    foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                        [bs, nx, ny, nz, na])
    # Sum to get the number of foreground anchors for each example.
    loss_normalization = tf.reduce_sum(foreground_mask, axis=[1, 2, 3, 4])
    loss_normalization = tf.maximum(loss_normalization,
                                    tf.ones_like(loss_normalization))
    # Reshape for broadcasting. Each weight tensor needs a normalizer of its
    # own rank. Dividing the rank-5 class_weights by a rank-6 normalizer
    # broadcasts to [bs, bs, ...] and pairs every example's loss with every
    # other example's normalizer.
    class_weights /= tf.reshape(loss_normalization, [bs, 1, 1, 1, 1])
    reg_weights /= tf.reshape(loss_normalization, [bs, 1, 1, 1, 1, 1])

  # Classification loss.
  assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                         [bs, nx, ny, nz, na])
  class_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_class_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)
  class_loss *= class_weights[..., tf.newaxis]
  class_loss_sum = tf.reduce_sum(class_loss)

  # Regression loss.
  anchor_localization_residuals = py_utils.HasShape(
      input_batch.anchor_localization_residuals, [bs, nx, ny, nz, na, 7])

  # Location and dimensions loss.
  reg_loc_and_dims_loss = self._utils.ScaledHuberLoss(
      predictions=py_utils.HasShape(predicted_residuals[..., :6],
                                    [bs, nx, ny, nz, na, 6]),
      labels=anchor_localization_residuals[..., :6],
      delta=1 / (3.**2))

  # Rotation residual between prediction and target.
  rot_delta = (
      predicted_residuals[..., 6:] -
      input_batch.anchor_localization_residuals[..., 6:])
  if p.use_atan2_heading_loss:
    # Wrap the angle difference via atan2(sin, cos) before the Huber loss.
    atan2_of_delta = tf.atan2(tf.sin(rot_delta), tf.cos(rot_delta))
    reg_rot_loss = self._utils.ScaledHuberLoss(
        predictions=atan2_of_delta,
        labels=tf.zeros_like(atan2_of_delta),
        delta=1 / (3.**2))
  else:
    # Rotation loss with SmoothL1(sin(delta)).
    reg_rot_loss = self._utils.ScaledHuberLoss(
        predictions=tf.sin(rot_delta),
        labels=tf.zeros_like(rot_delta),
        delta=1 / (3.**2))

  # Direction loss
  if p.direction_classifier_weight > 0.0:
    # The target rotations are in the assigned_gt_bbox tensor,
    # which already has assigned a gt bounding box to every anchor.
    rot_target = input_batch.assigned_gt_bbox[..., 6]
    # If rotation is > 0, the class is 1, else it is 0.
    rot_dir = tf.cast(rot_target > 0., tf.int32)

    # Compute one-hot labels as a target.
    rot_dir_onehot = tf.one_hot(rot_dir, 2)

    # Manually handle loss reduction.
    dir_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=rot_dir_onehot,
        logits=predictions.predicted_dir,
        weights=tf.squeeze(reg_weights, axis=-1),
        reduction=tf.losses.Reduction.NONE)
    # Reduce across all dimensions (we'll divide by the batch size below).
    dir_loss_sum = tf.reduce_sum(dir_loss)
  else:
    dir_loss_sum = 0.0

  # Compute loss contribution from location and dimension separately.
  reg_loc_loss = reg_loc_and_dims_loss[..., :3] * reg_weights
  reg_loc_loss_sum = tf.reduce_sum(reg_loc_loss)

  reg_dim_loss = reg_loc_and_dims_loss[..., 3:6] * reg_weights
  reg_dim_loss_sum = tf.reduce_sum(reg_dim_loss)

  # Compute rotation loss contribution.
  reg_rot_loss *= reg_weights
  reg_rot_loss_sum = tf.reduce_sum(reg_rot_loss)

  # Num. predictions.
  # TODO(zhifengc): Consider other normalization factors. E.g., # of bboxes.
  preds = tf.cast(bs, class_loss_sum.dtype)

  # Normalize all of the components by batch size.
  reg_loc_loss = reg_loc_loss_sum / preds
  reg_dim_loss = reg_dim_loss_sum / preds
  reg_rot_loss = reg_rot_loss_sum / preds
  class_loss = class_loss_sum / preds
  dir_loss = dir_loss_sum / preds

  # Compute total localization regression loss.
  reg_loss = (
      p.location_loss_weight * reg_loc_loss +
      p.dimension_loss_weight * reg_dim_loss +
      p.rotation_loss_weight * reg_rot_loss)

  # Apply weights to normalized class losses.
  loss = (
      class_loss * p.classification_loss_weight +
      reg_loss * p.localization_loss_weight +
      dir_loss * p.direction_classifier_weight)

  metrics_dict = {
      'loss': (loss, preds),
      'loss/class': (class_loss, preds),
      'loss/reg': (reg_loss, preds),
      'loss/reg/rot': (reg_rot_loss, preds),
      'loss/reg/loc': (reg_loc_loss, preds),
      'loss/reg/dim': (reg_dim_loss, preds),
      'loss/dir': (dir_loss, preds),
  }

  # Calculate dimension errors
  min_angle_rad = -np.pi if p.use_atan2_heading_loss else 0
  gt_bboxes = self._utils_3d.ResidualsToBBoxes(
      input_batch.anchor_bboxes,
      anchor_localization_residuals,
      min_angle_rad=min_angle_rad,
      max_angle_rad=np.pi)
  predicted_bboxes = self._utils_3d.ResidualsToBBoxes(
      input_batch.anchor_bboxes,
      predicted_residuals,
      min_angle_rad=min_angle_rad,
      max_angle_rad=np.pi)
  dimension_errors_dict = self._BBoxDimensionErrors(gt_bboxes,
                                                    predicted_bboxes,
                                                    reg_weights)
  metrics_dict.update(dimension_errors_dict)

  per_example_dict = {
      'residuals': predicted_residuals,
      'classification_logits': predicted_class_logits,
  }

  return metrics_dict, per_example_dict
def FProp(self, theta, input_batch):
  # pyformat: disable
  """Compute features for the pillars and convert them back to a dense grid.

  Args:
    theta: A `.NestedMap` object containing variable values of this task.
    input_batch: A `.NestedMap` object containing input tensors. Following
      keys are required:

      - grid_num_points: Integer tensor with shape [batch size, nx, ny, nz],
        where nx, ny, nz corresponds to the grid sizes (i.e., number of voxels
        in each axis dimension).
      - pillar_points: Float tensor with shape [batch size, num_pillars,
        num_points_per_pillar, 3 + num_laser_features]
      - pillar_centers: Float tensor with shape [batch size, num_pillars, 1,
        3]. It is tiled along the points axis before concatenation.
      - pillar_locations: Integer tensor with shape [batch size, num_pillars,
        3]. It is used as the indices of a batched gather into
        grid_num_points.

  Returns:
    The dense features with shape [b, nx, ny, nz * fdims].
  """
  # pyformat: enable
  p = self.params
  bs, nx, ny, nz = py_utils.GetShape(input_batch.grid_num_points, 4)

  # Process points to concatenate a set of fixed features (e.g.,
  # add means, centers, normalize points to means).
  num_features = 3 + p.num_laser_features
  pillar_points = py_utils.HasShape(input_batch.pillar_points,
                                    [bs, -1, -1, num_features])
  _, npillars, npoints, _ = py_utils.GetShape(pillar_points, 4)
  pillar_xyz = pillar_points[..., :3]

  # Compute number of points per pillar and prepare for broadcasting.
  pillar_num_points = tf.gather_nd(
      input_batch.grid_num_points, input_batch.pillar_locations, batch_dims=1)
  pillar_num_points = pillar_num_points[..., tf.newaxis, tf.newaxis]

  # Compute mean by computing sum and dividing by number of points. Clip the
  # denominator by 1.0 to gracefully handle empty pillars.
  # `keepdims` (not the deprecated `keep_dims`) keeps the points axis as size
  # 1 so the mean broadcasts against pillar_xyz.
  pillar_sum = tf.reduce_sum(pillar_xyz, axis=2, keepdims=True)
  pillar_means = pillar_sum / tf.maximum(
      tf.cast(pillar_num_points, tf.float32), 1.0)

  pillar_feats = pillar_points[..., 3:]
  pillar_centers = py_utils.HasShape(input_batch.pillar_centers,
                                     [bs, -1, 1, 3])
  pillar_concat = tf.concat(
      axis=3,
      values=[
          pillar_xyz - pillar_means,
          pillar_feats,
          tf.tile(pillar_means, [1, 1, npoints, 1]),
          tf.tile(pillar_centers, [1, 1, npoints, 1]),
      ])

  # Featurize pillars.
  pillar_features = self.featurizer.FProp(theta.featurizer, pillar_concat)

  # Convert back to the dense grid.
  pillar_locations = py_utils.HasShape(input_batch.pillar_locations,
                                       [bs, npillars, 3])
  dense_features = SparseToDense(
      grid_shape=(nx, ny, nz),
      locations=pillar_locations,
      feats=pillar_features)
  return dense_features
def _GetWarpMatrix(self,
                   batch_size,
                   choose_range,
                   matrix_size,
                   global_seed,
                   max_warp_frames=None,
                   dtype=tf.float32,
                   max_ratio=1.0):
  """Returns warp matrices starting from random positions.

  In this function when max_warp_frames != None:
    1) Sample random integer warp displacements, drawn with
       minval=-max_warp_frames and maxval=max_warp_frames + 1, to yield a
       shift tensor with shape (batch_size,).
    2) Truncate lengths to a maximum magnitude of (choose_range * max_ratio),
       so that each shift is fully contained within the corresponding
       sequence.
    3) Random sample origin points of shape (batch_size,), computed as
       int(u * max(choose_range - 2, 0)) + 1 for a uniform sample u.
    4) Return a batch of 1-D linear maps that fix the boundary points and
       shift the origin point by the shift.

  When max_warp_frames == None:
    1) Sample random warp displacements with magnitudes less than
       (choose_range * max_ratio) to yield shift tensor with shape
       (batch_size,).
    2) Proceed through steps 3), 4).

  Args:
    batch_size: Batch size. Integer number.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,).
    matrix_size: Dimension of vector space warp matrix is applied to. Integer
      number.
    global_seed: an integer seed tensor for stateless random ops.
    max_warp_frames: Upper-bound on the warp distance. Integer or None.
    dtype: Data type.
    max_ratio: Maximum ratio between the shift distance and choose_range.
      Float number.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size).
  """
  p = self.params

  # Non-empty random seed values are only used for testing or when using
  # stateless random ops. seed_3, seed_4, and seed_5 are set separately to
  # avoid correlation of warp magnitude and origin position.
  if p.use_input_dependent_random_seed:
    seed_3 = global_seed + 3
    seed_4 = global_seed + 4
    seed_5 = global_seed + 5
  elif p.random_seed:
    seed_3 = p.random_seed - 1
    seed_4 = p.random_seed - 1
    seed_5 = 2 * p.random_seed + 1
  else:
    seed_3 = p.random_seed
    seed_4 = p.random_seed
    seed_5 = p.random_seed

  choose_range_dtype = tf.cast(choose_range, dtype=dtype)
  length_upper_bound = tf.cast(max_ratio * choose_range_dtype, dtype=tf.int32)

  # Set shift length.
  random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
  if max_warp_frames and max_warp_frames > 0:
    shift = random_uniform(
        shape=(batch_size,),
        minval=-1 * max_warp_frames,
        maxval=max_warp_frames + 1,
        dtype=tf.int32,
        seed=seed_3)
  else:
    random_ratio = random_uniform(
        shape=(batch_size,),
        minval=-1.0,
        maxval=1.0,
        dtype=dtype,
        seed=seed_4)
    shift = tf.cast(random_ratio * tf.cast(length_upper_bound, dtype=dtype),
                    tf.int32)

  # Make sure the sampled length was smaller than max_ratio * length_bound.
  # Note that sampling in this way is biased.
  # (Shorter sequence may over-masked.)
  final_shift = tf.maximum(-length_upper_bound,
                           tf.minimum(shift, length_upper_bound))

  # Choose origin anchor point.
  # The int32 cast used to be a dead store (it was overwritten by an
  # expression on the uncast `choose_range`); it now feeds the subtraction so
  # the number of candidate origins is an integer count.
  mid_range = tf.cast(choose_range, dtype=tf.int32)
  mid_range = tf.maximum(mid_range - 2, 0)
  # Sample in `dtype` so the product below does not mix float types when the
  # caller asks for something other than float32.
  random_origin = random_uniform(
      shape=(batch_size,), maxval=1.0, dtype=dtype, seed=seed_5)
  origin_with_in_valid_range = random_origin * tf.cast(mid_range, dtype=dtype)
  origin = tf.cast(origin_with_in_valid_range, tf.int32) + 1

  # Set destination point of the origin anchor point under the warp map.
  destination = origin + final_shift

  # Cast origin and destination.
  origin = tf.cast(origin, dtype=dtype)
  destination = tf.cast(destination, dtype=dtype)

  return self._ConstructWarpMatrix(
      batch_size=batch_size,
      matrix_size=matrix_size,
      origin=origin,
      destination=destination,
      choose_range=choose_range_dtype,
      dtype=dtype)
def FProp(self, theta, current_step):
  """Returns the inverse-square-root learning rate for `current_step`.

  The result is `multiplier * rsqrt(max(current_step, warmup_steps))`, so the
  rate stays flat until `current_step` exceeds `warmup_steps` and decays with
  the inverse square root of the step afterwards.

  Args:
    theta: Not read by this method.
    current_step: The step to evaluate the schedule at; cast to float32.

  Returns:
    The learning rate tensor.
  """
  params = self.params
  # Clamp from below so steps earlier than warmup_steps all map to the same
  # rate.
  clamped_step = tf.maximum(
      tf.cast(current_step, tf.float32), params.warmup_steps)
  return tf.math.rsqrt(clamped_step) * params.multiplier
def _internal_apply_dense(self, grad, var, magnitude_optimizer_apply_fn,
                          direction_optimizer_apply_fn):  # pylint: disable=g-doc-args
  """Main optimization logic of AdaGraft, which calls the child optimizers.

  Both child optimizers are run in-place on `var`, one after the other, with
  `var` restored from a scratch slot in between. The ordering is enforced
  purely through `tf.control_dependencies`, so the statement order below is
  significant.

  Args:
    grad: Tensor containing gradients.
    var: Tensor containing parameter values.
    magnitude_optimizer_apply_fn: Apply magnitude optimizer.
    direction_optimizer_apply_fn: Apply direction optimizer.

  Returns:
    The final update op. When `self.use_global_norm` is false this assigns
    `old_var + learning_rate * step` to `var`, where `step` is the direction
    optimizer's step rescaled to the magnitude optimizer's step norm. When
    `self.use_global_norm` is true, `var` is restored to its old value and
    the returned op instead stores the raw direction step in the
    "scratch_copy" slot.

  Pseudocode:
    - Copy weights into scratch space 'scratch_copy'.
    - Run magnitude_optimizer in-place.
    - Use scratch copy to figure out how far we moved ('magnitude_step').
    - Copy weights back.
    - Run direction_optimizer in-place.
    - Move weights along the line segment with scratch_copy.
  """
  if self.use_global_norm:
    # Remember the variable on the optimizer.
    # NOTE(review): presumably consumed by a later global-norm pass outside
    # this method -- confirm.
    self._variables.append(var)

  # Slot with current parameter values
  scratch_slot = self.get_slot(var, "scratch_copy")
  old_var = tf.assign(scratch_slot, var)

  # The snapshot must be taken before the magnitude optimizer mutates var.
  with tf.control_dependencies([old_var]):
    m_updated_var = magnitude_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

  # Run magnitude optimizer and compute the norm of the update.
  with tf.control_dependencies([m_updated_var]):
    m_step = var - old_var
    m_step_norm = tf.norm(m_step)
    if self.diagnostic or self.use_global_norm:
      # Persist the norm in a slot; the assign op's output replaces the raw
      # norm in the rest of the computation.
      m_step_norm = tf.assign(self.get_slot(var, "m_step_norm"), m_step_norm)

  # Run direction optimizer and compute its norm, and the direction.
  # var is reset to the snapshot only after the magnitude norm is computed.
  with tf.control_dependencies([m_step_norm]):
    flushed_var = tf.assign(var, old_var)
  with tf.control_dependencies([flushed_var]):
    d_updated_var = direction_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

  # Run an update of the direction optimizer with magnitude optimizer norm.
  with tf.control_dependencies([d_updated_var]):
    d_step = var - old_var
    d_step_norm = tf.norm(d_step)
    if self.diagnostic or self.use_global_norm:
      d_step_norm = tf.assign(self.get_slot(var, "d_step_norm"), d_step_norm)
    if self.use_global_norm:
      # Global-norm mode: undo the direction step on var and stash the step
      # itself in the scratch slot instead of applying a grafted update here.
      flushed_var = tf.assign(var, old_var)
      with tf.control_dependencies([d_step_norm, flushed_var]):
        return tf.assign(scratch_slot, d_step)
    # Rescale the direction step to the magnitude step's norm. A zero
    # direction norm yields a zero step; the 1e-30 floor keeps the division
    # in the unselected branch from dividing by zero.
    step = tf.where(
        tf.greater(d_step_norm, 0),
        (m_step_norm / tf.maximum(d_step_norm, 1e-30)) * d_step,
        tf.zeros_like(d_step))
    return tf.assign(var, old_var + self._learning_rate_tensor * step)
def ComputeLoss(self, theta, predictions, input_batch):
  """Compute loss for the sparse detector model v1.

  Args:
    theta: A `.NestedMap` object containing variable values of this task. Not
      read by this method.
    predictions: A `.NestedMap` object containing residuals and
      classification_logits.
    input_batch: A `.NestedMap`. The fields read here are cell_center_xyz,
      anchor_bboxes, anchor_localization_residuals, assigned_gt_labels,
      assigned_cls_mask and assigned_reg_mask.

  Returns:
    Two dicts:

    - A dict containing str keys and (metric, weight) pairs as values, where
      one of the keys is expected to be 'loss'.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  p = self.params
  batch_size, num_centers = py_utils.GetShape(input_batch.cell_center_xyz, 2)

  # Shapes that are asserted repeatedly below.
  anchor_shape = [batch_size, num_centers, p.num_anchor_bboxes_per_center]
  residual_shape = anchor_shape + [7]
  broadcast_shape = anchor_shape + [1]
  per_class_shape = anchor_shape + [p.num_classes]

  # Assert shapes of inputs.
  anchor_bboxes = py_utils.HasShape(input_batch.anchor_bboxes, residual_shape)
  anchor_localization_residuals = py_utils.HasShape(
      input_batch.anchor_localization_residuals, residual_shape)
  predicted_residuals = py_utils.HasShape(predictions.residuals,
                                          residual_shape)
  assigned_gt_labels = py_utils.HasShape(input_batch.assigned_gt_labels,
                                         anchor_shape)
  predicted_classification_logits = py_utils.HasShape(
      predictions.classification_logits, per_class_shape)

  # assigned_cls_mask weights the classification loss; a mask of 0 removes
  # an anchor from the classification loss.
  cls_mask = py_utils.HasShape(input_batch.assigned_cls_mask, anchor_shape)
  cls_mask = tf.reshape(cls_mask, broadcast_shape)

  # Every anchor has num_classes prediction heads; scale each head by its
  # per class loss weight via broadcasting.
  head_weights = tf.constant([[[p.per_class_loss_weight]]], dtype=tf.float32)
  head_weights = py_utils.HasShape(head_weights, [1, 1, 1, p.num_classes])
  class_weights = py_utils.HasShape(cls_mask * head_weights, per_class_shape)

  # assigned_reg_mask masks the regression loss.
  reg_weights = py_utils.HasShape(input_batch.assigned_reg_mask, anchor_shape)
  reg_weights = tf.reshape(reg_weights, broadcast_shape)

  if p.loss_norm_type == LossNormType.NORM_BY_NUM_POS_PER_CENTER:
    foreground_mask = py_utils.HasShape(input_batch.assigned_reg_mask,
                                        anchor_shape)
    # Count foreground anchors over the anchor axis, with a floor of one so
    # the division below is always defined.
    num_foreground = tf.reduce_sum(foreground_mask, axis=2)
    num_foreground = tf.maximum(num_foreground, tf.ones_like(num_foreground))
    # Reshape for broadcasting, then also divide out the number of centers so
    # the loss is independent of it.
    loss_normalization = tf.reshape(num_foreground,
                                    [batch_size, num_centers, 1, 1])
    loss_normalization = loss_normalization * num_centers
    class_weights = class_weights / loss_normalization
    reg_weights = reg_weights / loss_normalization

  # Classification: focal loss, masked, summed over centers, boxes and
  # classes, then averaged over the batch.
  # TODO(jngiam): Consider normalizing by num_foreground_anchors for each
  # example instead. This would match the 1/N_positive normalization in
  # point pillars.
  focal_loss = py_utils.SigmoidCrossEntropyFocalLoss(
      logits=predicted_classification_logits,
      labels=tf.one_hot(assigned_gt_labels, p.num_classes),
      alpha=p.focal_loss_alpha,
      gamma=p.focal_loss_gamma)
  per_example_cls_loss = tf.reduce_sum(
      focal_loss * class_weights, axis=[1, 2, 3])
  classification_loss = tf.reduce_mean(per_example_cls_loss)

  # Localization regression loss with Huber loss (SmoothL1) on the first six
  # residual components.
  loc_and_dims_loss = self._utils_3d.ScaledHuberLoss(
      labels=anchor_localization_residuals[..., :6],
      predictions=predicted_residuals[..., :6],
      delta=p.huber_loss_delta)

  # Rotation loss is computed on a transform of the rotation delta: angles
  # wrapped to [-pi, pi] for a direction aware loss, or sin for a loss that
  # is symmetric to rotating by pi.
  if p.direction_aware_rot_loss:
    rotation_delta_transform = functools.partial(
        geometry.WrapAngleRad, min_val=-np.pi, max_val=np.pi)
  else:
    rotation_delta_transform = tf.sin
  rotation_delta = (
      predicted_residuals[..., 6:] - anchor_localization_residuals[..., 6:])
  rotation_loss = self._utils_3d.ScaledHuberLoss(
      labels=tf.zeros_like(rotation_delta),
      predictions=rotation_delta_transform(rotation_delta),
      delta=p.huber_loss_delta)

  def _DecodeBBoxes(residuals):
    # Turns residuals relative to the anchors back into bounding boxes.
    return self._utils_3d.ResidualsToBBoxes(
        anchor_bboxes, residuals, min_angle_rad=-np.pi, max_angle_rad=np.pi)

  gt_bboxes = _DecodeBBoxes(anchor_localization_residuals)
  predicted_bboxes = _DecodeBBoxes(predicted_residuals)

  def _MaskedLossPerExample(unmasked_loss):
    # Applies the regression mask, sums over every axis (batch included) and
    # divides by the batch size.
    return tf.reduce_sum(unmasked_loss * reg_weights) / batch_size

  reg_rot_loss = _MaskedLossPerExample(rotation_loss)
  reg_loc_loss = _MaskedLossPerExample(loc_and_dims_loss[..., :3])
  reg_dim_loss = _MaskedLossPerExample(loc_and_dims_loss[..., 3:6])

  # Do not create corner loss graph if weight is 0.0
  # TODO(bcyang): Remove condition after fixing corner loss NaN issue
  reg_corner_loss = 0.0
  if p.corner_loss_weight != 0.0:
    corner_loss = self._utils_3d.CornerLoss(
        gt_bboxes=gt_bboxes, predicted_bboxes=predicted_bboxes)
    reg_corner_loss = _MaskedLossPerExample(
        tf.expand_dims(corner_loss, axis=-1))

  # Sum components of regression loss.
  regression_loss = (
      p.location_loss_weight * reg_loc_loss +
      p.dimension_loss_weight * reg_dim_loss +
      p.rotation_loss_weight * reg_rot_loss +
      p.corner_loss_weight * reg_corner_loss)

  # Compute total loss.
  total_loss = (
      p.loss_weight_localization * regression_loss +
      p.loss_weight_classification * classification_loss)

  metrics_dict = py_utils.NestedMap({
      'loss': (total_loss, batch_size),
      'loss/regression': (regression_loss, batch_size),
      'loss/regression/loc': (reg_loc_loss, batch_size),
      'loss/regression/dim': (reg_dim_loss, batch_size),
      'loss/regression/rot': (reg_rot_loss, batch_size),
      'loss/regression/corner': (reg_corner_loss, batch_size),
      'loss/classification': (classification_loss, batch_size),
  })

  # Calculate dimension errors
  metrics_dict.update(
      self._BBoxDimensionErrors(gt_bboxes, predicted_bboxes, reg_weights))

  per_example_dict = py_utils.NestedMap({
      'residuals': predicted_residuals,
      'classification_logits': predicted_classification_logits,
      'predicted_bboxes': predicted_bboxes,
      'gt_bboxes': gt_bboxes,
      'reg_weights': reg_weights,
  })

  return metrics_dict, per_example_dict
def _GetMask(self,
             batch_size,
             choose_range,
             mask_size,
             global_seed,
             max_length=None,
             masks_per_frame=0.0,
             multiplicity=1,
             dtype=tf.float32,
             max_ratio=1.0):
  """Returns fixed size multi-masks starting from random positions.

  A multi-mask is a mask obtained by applying multiple masks.

  This function when max_length is given:
    1) Sample random mask lengths less than max_length with shape
       (batch_size, multiplicity).
    2) Truncate lengths to a max of (choose_range * max_ratio), so that each
       mask is fully contained within the corresponding sequence.
    3) Random sample start points of shape (batch_size, multiplicity) with in
       (choose_range - lengths).
    4) For each batch, multiple masks (whose number is given by the
       multiplicity) are constructed.
    5) Return a mask of shape (batch_size, mask_size) where masks are
       obtained by composing the masks constructed in step 4). If
       masks_per_frame > 0, the number is given by
       min(masks_per_frame * choose_range, multiplicity). If not, all the
       masks are composed. The masked regions are set to zero.

  This function when max_length is not given:
    1) Sample random mask lengths less than (choose_range * max_ratio) with
       shape (batch_size, multiplicity).
    2) Proceed to steps 3), 4) and 5) of the above.

  Args:
    batch_size: Batch size. Integer number.
    choose_range: Range within which the masked entries must lie. Tensor of
      shape (batch_size,).
    mask_size: Size of the mask. Integer number.
    global_seed: an integer seed tensor for stateless random ops.
    max_length: Maximum number of allowed consecutive masked entries. Integer
      number or None.
    masks_per_frame: Number of masks per frame. Float number. If > 0, the
      multiplicity of the mask is set to be masks_per_frame * choose_range.
    multiplicity: Maximum number of total masks. Integer number.
    dtype: Data type. All intermediate floating point tensors, including the
      random samples, are created in this type.
    max_ratio: Maximum portion of the entire range allowed to be masked. Float
      number.

  Returns:
    mask: a fixed size multi-mask starting from a random position with shape
    (batch_size, mask_size).
  """
  p = self.params

  # Non-empty random seed values are only used for testing or when using
  # stateless random ops. seed_1 and seed_2 are set separately to avoid
  # correlation of mask size and mask position.
  if p.use_input_dependent_random_seed:
    seed_1 = global_seed + 1
    seed_2 = global_seed + 2
  elif p.random_seed:
    seed_1 = p.random_seed + 1
    seed_2 = 2 * p.random_seed
  else:
    seed_1 = p.random_seed
    seed_2 = p.random_seed

  # Sample lengths for multiple masks.
  if max_length and max_length > 0:
    max_length = tf.broadcast_to(tf.cast(max_length, dtype), (batch_size,))
  else:
    max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
  random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
  masked_portion = random_uniform(
      shape=(batch_size, multiplicity),
      minval=0.0,
      maxval=1.0,
      dtype=dtype,
      seed=seed_1)
  masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
  masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)

  # Make sure the sampled length was smaller than max_ratio * length_bound.
  # Note that sampling in this way was biased
  # (shorter sequence may over-masked.)
  choose_range = tf.expand_dims(choose_range, -1)
  choose_range = tf.tile(choose_range, [1, multiplicity])
  length_bound = tf.cast(choose_range, dtype=dtype)
  length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
  length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

  # Choose starting point.
  # Sampled in `dtype` so the product below does not mix float types when the
  # caller asks for something other than float32.
  random_start = random_uniform(
      shape=(batch_size, multiplicity), maxval=1.0, dtype=dtype, seed=seed_2)
  start_with_in_valid_range = random_start * tf.cast(
      (choose_range - length + 1), dtype=dtype)
  start = tf.cast(start_with_in_valid_range, tf.int32)
  end = start + length - 1

  # Shift starting and end point by small value. The constant is created in
  # `dtype` for the same reason as random_start above.
  delta = tf.constant(0.1, dtype=dtype)
  start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
  start = tf.tile(start, [1, 1, mask_size])
  end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
  end = tf.tile(end, [1, 1, mask_size])

  # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
  diagonal = tf.expand_dims(
      tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
  diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
  pre_mask = tf.cast(
      tf.math.logical_and(diagonal < end, diagonal > start), dtype=dtype)

  # Sum masks with appropriate multiplicity.
  if masks_per_frame > 0:
    multiplicity_weights = tf.tile(
        tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
        [batch_size, 1])
    multiplicity_tensor = masks_per_frame * tf.cast(choose_range, dtype=dtype)
    # Only the first masks_per_frame * choose_range masks of each example
    # get a non-zero weight.
    multiplicity_weights = tf.cast(
        multiplicity_weights < multiplicity_tensor, dtype=dtype)
    pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
  else:
    pre_mask = tf.reduce_sum(pre_mask, 1)
  # Positions covered by at least one mask become 0, everything else 1.
  mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

  if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
    mask = tf.cast(mask, p.fprop_dtype)

  return mask
def _resource_apply_dense(self, grad, var):
  """Builds the Adafactor update ops for a single dense variable.

  Updates the second-moment slots ('vr' and 'vc' when the variable shape is
  factored, 'v' otherwise), the optional first-moment slot 'm' (when
  `self._beta1` is truthy) and finally subtracts the scaled step from `var`.
  When `self._cond_is_finite` is set, each assignment is wrapped in a
  `tf.cond` that skips it if a non-finite value has been seen so far.

  Args:
    grad: Gradient for `var`, or None.
    var: The variable to update.

  Returns:
    A `tf.group` of all update ops, or an empty list if `grad` is None.
  """
  if grad is None:
    tf.logging.warning('Gradient is None for variable %s' % var.name)
    return []

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  # Running "everything so far is finite" flag; only meaningful when
  # self._cond_is_finite is set.
  cond = tf.constant(True)

  def _Upd(c, x):
    # Folds an all-finite check on x into c; a no-op when the finiteness
    # guard is disabled.
    if not self._cond_is_finite:
      return c
    c = tf.math.logical_and(c, tf.reduce_all(tf.math.is_finite(x)))
    c = tf.math.logical_and(
        c, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x))))
    return c

  def _Wrap(fn, x, y):
    # Applies fn(x, y) unconditionally, or only when `cond` holds if the
    # finiteness guard is enabled. `cond` is read from the enclosing scope at
    # the time _Wrap is called, so each wrapped update is gated on the checks
    # accumulated up to that call, not on later ones.
    if not self._cond_is_finite:
      return fn(x, y)
    return tf.cond(cond, lambda: fn(x, y), lambda: x)

  # var.name[:-2] drops the last two characters of the name (presumably the
  # ':0' output suffix -- confirm).
  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(
        self._epsilon1, grad_dtype)
    cond = _Upd(cond, grad_squared)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(
        var)  # TODO(lepikhin): introduce gradient dtype
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      update_scale = self._parameter_scale(old_val) * tf.cast(
          lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)
    updates = []
    if factored_dims:
      # Factored second moment: keep only the means of grad^2 reduced over
      # each of the two factored axes.
      d0, d1 = factored_dims
      vr_axis, vc_axis = d0, d1
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      cond = _Upd(cond, new_vr)
      cond = _Upd(cond, new_vc)
      vr_update = _Wrap(tf.assign, vr, new_vr)
      vc_update = _Wrap(tf.assign, vc, new_vc)
      updates.extend([vr_update, vc_update])
      # Rebuild the per-element rsqrt scaling from the two factors;
      # new_vr is normalized by its mean over the last axis.
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      x = grad * tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
    else:
      # Unfactored second moment: full-size moving average of grad^2.
      new_v = v * decay_rate + grad_squared * mixing_rate
      cond = _Upd(cond, new_v)
      v_update = _Wrap(tf.assign, v, new_v)
      updates.append(v_update)
      x = grad * tf.math.rsqrt(new_v)
    if self._clipping_threshold is not None:
      # Scale x down only when its RMS exceeds the clipping threshold.
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    subtrahend = x * update_scale
    if self._beta1:
      # First-moment smoothing of the step itself.
      new_m = (
          m * tf.constant(self._beta1, dtype=grad_dtype) +
          subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      cond = _Upd(cond, new_m)
      updates.append(_Wrap(tf.assign, m, new_m))
    # It is critical to use assign_sub instead of tf.assign(var - subtrahend)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    cond = _Upd(cond, subtrahend)
    var_update = _Wrap(tf.assign_sub, var, subtrahend)
    updates.append(var_update)
    return tf.group(*updates)
def _ConstructWarpMatrix(self, batch_size, matrix_size, origin, destination,
                         choose_range, dtype):
  """Returns warp matrices according to origin, destination and choose_range.

  Builds, per batch element, the matrix of a piecewise linear warp that keeps
  the points 0 and `choose_range` fixed and sends `origin` to `destination`.

  A non-singular warp requires 1 <= destination <= choose_range - 1, so
  `destination` is first clamped into that range.

  With the slopes

    slope_0 = origin / destination
    slope_1 = (choose_range - origin) / (choose_range - destination)
    slope_2 = 1.0

  the source coordinate orig_i of warped coordinate i is

    i < destination:                 slope_0 * i
    destination <= i < choose_range: slope_1 * i
                                       - (slope_1 - slope_0) * destination
    i >= choose_range:               i

  With n_i = ceil(orig_i), the matrix entries are

    warp[i][n_i]     = 1 - n_i + orig_i
    warp[i][n_i - 1] = n_i - orig_i
    warp[i][j]       = 0 otherwise

  so that warped_pixel[i] = sum_j warp[i][j] * pixel[j] linearly interpolates
  between pixel[n_i - 1] and pixel[n_i].

  Args:
    batch_size: Batch size. Integer number.
    matrix_size: Dimension of the vector space the warp matrix is applied to.
      Integer number.
    origin: Origin anchor point for warping. Tensor of shape (batch_size,)
      and data type dtype.
    destination: Destination of the origin anchor point upon warping. Tensor
      of shape (batch_size,) and data type dtype.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,) data type dtype.
    dtype: Data type of origin, destination, choose_range and the output warp
      matrix.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size).
  """
  params = self.params

  def _RepeatAlongColumns(per_example):
    # (batch_size,) -> (batch_size, matrix_size), same value in every column.
    return tf.transpose(
        tf.broadcast_to(per_example, (matrix_size, batch_size)))

  # Keep the destination strictly inside (0, choose_range) so neither slope
  # below divides by zero.
  clamped_destination = tf.minimum(
      tf.maximum(destination, 1.0), choose_range - 1.0)
  destination_grid = _RepeatAlongColumns(clamped_destination)
  choose_range_grid = _RepeatAlongColumns(choose_range)

  # Slopes of the three linear pieces.
  slope_before = origin / clamped_destination
  slope_after = (choose_range - origin) / (choose_range - clamped_destination)
  slope_outside = 1.0

  # Evaluate the piecewise linear map at every integer coordinate by summing
  # a linear term and two relu "kinks", one at the destination and one at
  # choose_range, each changing the slope by the difference of neighbours.
  coordinates = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype), (batch_size, matrix_size))
  linear_term = self.EinsumBBmBm(slope_before, coordinates)
  destination_kink = self.EinsumBBmBm(
      slope_after - slope_before, tf.nn.relu(coordinates - destination_grid))
  range_kink = self.EinsumBBmBm(
      slope_outside - slope_after,
      tf.nn.relu(coordinates - choose_range_grid))
  source_coordinates = linear_term + destination_kink + range_kink

  # source_grid[b][i][j] = orig_i for every j.
  source_grid = tf.broadcast_to(source_coordinates,
                                (matrix_size, batch_size, matrix_size))
  source_grid = tf.transpose(source_grid, perm=[1, 2, 0])

  # column_grid[b][i][j] = j.
  column_grid = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype),
      (batch_size, matrix_size, matrix_size))

  # The hat function applied element-wise to (orig_i - j) gives the
  # interpolation weights described in the docstring.
  warp_matrix = _hat(source_grid - column_grid)

  if params.fprop_dtype is not None and params.fprop_dtype != dtype:
    warp_matrix = tf.cast(warp_matrix, params.fprop_dtype)
  return warp_matrix
def try_apply_dense(self, grad, var):
  """Computes the Adafactor step for `var` without applying it.

  Follows the same arithmetic as the real dense update, but creates no
  assign ops: slots and `var` are only read. Every intermediate of interest
  is recorded together with an all-finite check, so callers can find out
  which quantity went non-finite.

  Args:
    grad: Gradient for `var`. Must not be None.
    var: The variable the update would be applied to.

  Returns:
    A tuple `(is_finite_checks, stats)`:

    - is_finite_checks: list of scalar bool tensors, one per recorded
      intermediate, in recording order.
    - stats: dict mapping the intermediate's name to its tensor. Names are
      'grad_squared', 'parameter_scale', then either 'new_vr', 'new_vc' and
      'mult' (factored) or 'new_v' (unfactored), then 'x', optionally
      'new_m', and 'subtrahend'.
  """
  assert grad is not None

  is_finite_checks = []
  stats = {}

  def _Record(name, tensor):
    # Stores the tensor under `name` and appends its all-finite check.
    stats[name] = tensor
    is_finite_checks.append(tf.reduce_all(tf.math.is_finite(tensor)))

  grad_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, grad_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    vr = self.get_slot(var, 'vr')
    vc = self.get_slot(var, 'vc')
  else:
    v = self.get_slot(var, 'v')
  if self._beta1:
    m = self.get_slot(var, 'm')

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_squared = tf.math.square(grad) + tf.cast(self._epsilon1, grad_dtype)
    _Record('grad_squared', grad_squared)  # 0 (factored)
    decay_rate = tf.cast(self._decay_rate, var.dtype)
    old_val = tf.identity(var)  # TODO(lepikhin): introduce gradient dtype
    assert self._multiply_by_parameter_scale
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      parameter_scale = self._parameter_scale(old_val)
      _Record('parameter_scale', parameter_scale)  # 1 (factored)
      # Reuse the scale computed above instead of evaluating it a second
      # time.
      update_scale = parameter_scale * tf.cast(lr, grad_dtype)
    else:
      update_scale = lr
    mixing_rate = tf.cast(1.0 - decay_rate, grad_dtype)
    update_scale = tf.cast(update_scale, grad_dtype)

    if factored_dims:
      vr_axis, vc_axis = factored_dims
      grad_squared_row_mean = tf.reduce_mean(grad_squared, axis=vr_axis)
      grad_squared_col_mean = tf.reduce_mean(grad_squared, axis=vc_axis)
      # new_vr = (decay_rate * vr + mixing_rate * grad_squared_row_mean)
      new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
      # new_vc = (decay_rate * vc + mixing_rate * grad_squared_col_mean)
      new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
      _Record('new_vr', new_vr)  # 2 (factored)
      _Record('new_vc', new_vc)  # 3 (factored)
      # The slots are intentionally not assigned here; only the would-be
      # values are inspected.
      long_term_mean = tf.reduce_mean(new_vr, -1, keepdims=True)
      r_factor = tf.math.rsqrt(new_vr / long_term_mean)
      c_factor = tf.math.rsqrt(new_vc)
      mult = tf.expand_dims(r_factor, vr_axis) * tf.expand_dims(
          c_factor, vc_axis)
      _Record('mult', mult)  # 4 (factored)
      x = grad * mult
    else:
      new_v = v * decay_rate + grad_squared * mixing_rate
      _Record('new_v', new_v)
      x = grad * tf.math.rsqrt(new_v)

    assert self._clipping_threshold is not None
    if self._clipping_threshold is not None:
      # Scale x down only when its RMS exceeds the clipping threshold.
      clipping_denom = tf.maximum(
          tf.constant(1.0, grad_dtype),
          py_utils.ReduceRms(x) /
          tf.constant(self._clipping_threshold, grad_dtype))
      x /= clipping_denom
    _Record('x', x)

    subtrahend = x * update_scale
    if self._beta1:
      new_m = (
          m * tf.constant(self._beta1, dtype=grad_dtype) +
          subtrahend * tf.constant(1.0 - self._beta1, dtype=grad_dtype))
      subtrahend = new_m
      _Record('new_m', new_m)
    # `subtrahend` is what would be subtracted from var; it is only recorded
    # here, never applied.
    _Record('subtrahend', subtrahend)  # 5 (factored)

    return is_finite_checks, stats