def _GetMask(self,
             batch_size,
             choose_range,
             mask_size,
             max_length=None,
             masks_per_frame=0.0,
             multiplicity=1,
             dtype=tf.float32,
             max_ratio=1.0):
  """Returns fixed size multi-masks starting from random positions.

  A multi-mask is a mask obtained by applying multiple masks.

  This function when max_length is given:
    1) Sample random mask lengths less than max_length with shape
       (batch_size, multiplicity).
    2) Truncate lengths to a max of (choose_range * max_ratio), so that each
       mask is fully contained within the corresponding sequence. The bound
       itself is floored at 1 before truncation.
    3) Random sample start points of shape (batch_size, multiplicity) within
       (choose_range - lengths).
    4) For each batch, multiple masks (whose number is given by the
       multiplicity) are constructed.
    5) Return a mask of shape (batch_size, mask_size) where masks are
       obtained by composing the masks constructed in step 4). If
       masks_per_frame > 0, the number is given by
       min(masks_per_frame * choose_range, multiplicity). If not, all the
       masks are composed. The masked regions are set to zero.

  This function when max_length is not given:
    1) Sample random mask lengths less than (choose_range * max_ratio) with
       shape (batch_size, multiplicity).
    2) Proceed to steps 3), 4) and 5) of the above.

  Args:
    batch_size: Batch size. Integer number.
    choose_range: Range within which the masked entries must lie. Tensor of
      shape (batch_size,).
    mask_size: Size of the mask. Integer number.
    max_length: Maximum number of allowed consecutive masked entries. Integer
      number or None.
    masks_per_frame: Number of masks per frame. Float number. If > 0, only
      the first masks_per_frame * choose_range masks (at most multiplicity)
      are applied to each example.
    multiplicity: Maximum number of total masks. Integer number.
    dtype: Data type used for all floating point intermediates, including
      the random samples.
    max_ratio: Maximum portion of the entire range allowed to be masked.
      Float number.

  Returns:
    mask: a fixed size multi-mask starting from a random position with shape
    (batch_size, mask_size). Entries are 0 inside masked regions and 1
    elsewhere.
  """
  p = self.params
  # Non-empty random seed values are only used for testing.
  # seed_1 and seed_2 are set separately to avoid correlation of
  # mask size and mask position.
  if p.random_seed:
    seed_1 = p.random_seed + 1
    seed_2 = 2 * p.random_seed
  else:
    seed_1 = p.random_seed
    seed_2 = p.random_seed

  # Sample lengths for multiple masks.
  if max_length and max_length > 0:
    max_length = tf.broadcast_to(tf.cast(max_length, dtype), (batch_size,))
  else:
    max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
  masked_portion = tf.random.uniform((batch_size, multiplicity),
                                     minval=0.0,
                                     maxval=1.0,
                                     dtype=dtype,
                                     seed=seed_1)
  masked_frame_size = tf.einsum('b,bm->bm', max_length, masked_portion)
  masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)

  # Make sure the sampled length is no larger than max_ratio * choose_range.
  # Note that sampling in this way is biased
  # (shorter sequences may be over-masked).
  choose_range = tf.expand_dims(choose_range, -1)
  choose_range = tf.tile(choose_range, [1, multiplicity])
  length_bound = tf.cast(choose_range, dtype=dtype)
  length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
  length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

  # Choose starting point. The sample must be drawn in `dtype`; a float32
  # sample multiplied by a `dtype` tensor fails for any non-float32 dtype.
  random_start = tf.random.uniform((batch_size, multiplicity),
                                   maxval=1.0,
                                   dtype=dtype,
                                   seed=seed_2)
  start_with_in_valid_range = random_start * tf.cast(
      (choose_range - length + 1), dtype=dtype)
  start = tf.cast(start_with_in_valid_range, tf.int32)
  end = start + length - 1

  # Shift starting and end point by a small value so that the strict
  # comparisons below include both integer end points. The constant is
  # created in `dtype` to match the tensors it is combined with.
  delta = tf.constant(0.1, dtype=dtype)
  start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
  start = tf.tile(start, [1, 1, mask_size])
  end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
  end = tf.tile(end, [1, 1, mask_size])

  # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
  diagonal = tf.expand_dims(
      tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
  diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
  pre_mask = tf.cast(
      tf.logical_and(diagonal < end, diagonal > start), dtype=dtype)

  # Sum masks with appropriate multiplicity.
  if masks_per_frame > 0:
    multiplicity_weights = tf.tile(
        tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
        [batch_size, 1])
    multiplicity_tensor = masks_per_frame * tf.cast(choose_range, dtype=dtype)
    # Weight is 1 for the masks that are kept for an example and 0 otherwise.
    multiplicity_weights = tf.cast(
        multiplicity_weights < multiplicity_tensor, dtype=dtype)
    pre_mask = tf.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
  else:
    pre_mask = tf.reduce_sum(pre_mask, 1)

  # Any position covered by at least one mask becomes 0, the rest 1.
  mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

  if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
    mask = tf.cast(mask, p.fprop_dtype)

  return mask
def __init__(self,
             learning_rate,
             momentum=0.0,
             initial_accumulator_value=0.0,
             start_preconditioning_steps=1000,
             statistics_computation_frequency=1,
             matrix_epsilon=1e-6,
             synchronous_preconditioning=False,
             second_moment_averaging=1.0,
             fallback_to_diagonal_dim=4096,
             max_any_dim=6656,
             block_size=4096,
             block_partition_threshold_size=1000000,
             global_step=None,
             exponent_multiplier=1.0,
             name="DistributedShampoo"):
  """Constructs a DistributedShampoo optimizer.

  Besides storing the hyper-parameters, the constructor builds three
  step-dependent tensors: whether the non-diagonal update runs, a warmup
  factor in [0, 1] for it, and whether statistics are computed at this step.

  Args:
    learning_rate: A `Tensor` or a floating point value. The learning rate.
    momentum: A `Tensor` or a floating point value. Momentum is not applied
      to sparse updates.
    initial_accumulator_value: A floating point value.
    start_preconditioning_steps: An int32 value giving the step at which
      preconditioning starts.
    statistics_computation_frequency: An int32 step value; statistics for
      preconditioning are computed when the global step is a multiple of it.
    matrix_epsilon: An epsilon regularizer to make the matrices positive
      definite.
    synchronous_preconditioning: Whether to run preconditioning
      synchronously.
    second_moment_averaging: 1.0 means sum of gradient squares, while less
      than 1.0 switches to RMSProp style exponential moving averages of the
      second moments.
    fallback_to_diagonal_dim: Fall back to the diagonal version if any of
      the dimensions is larger than this value.
    max_any_dim: If the maximum value for any dimension is greater than this
      value, preconditioning is skipped and the diagonal is used instead.
    block_size: Dimension of the partitioned tensors.
    block_partition_threshold_size: Partitions dimensions beyond this size.
    global_step: Global step for training. When None, the graph's global
      step is fetched or created.
    exponent_multiplier: A multiplier `e` for the exponent for the inverse
      calculation, e * -1/(2*rank). Only applies when calculating inverses
      through svd.
    name: Optional name prefix for the operations created when applying
      gradients.
  """
  super(DistributedShampoo, self).__init__(False, name)

  # Plain hyper-parameters.
  self._learning_rate = learning_rate
  self._momentum = momentum
  self._initial_accumulator_value = initial_accumulator_value
  self._start_preconditioning_steps = start_preconditioning_steps
  self._matrix_epsilon = matrix_epsilon
  self._synchronous_preconditioning = synchronous_preconditioning
  self._second_moment_averaging = second_moment_averaging
  self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
  self._max_any_dim = max_any_dim
  self._block_size = block_size

  # NOTE: On XLA - int64 is not handled properly, hence the int32 cast.
  step_source = (
      global_step
      if global_step is not None else tf.train.get_or_create_global_step())
  self._global_step = tf.cast(tf.identity(step_source), tf.int32)

  # True once the global step has reached the preconditioning start step.
  self._run_nondiagonal_update = tf.greater_equal(
      self._global_step, self._start_preconditioning_steps)

  # Linear warmup: 0 up to the start step, reaching 1 at twice the start
  # step.
  start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
  global_step_f = tf.cast(self._global_step, tf.float32)
  warmup_progress = (global_step_f - start_steps_f) / start_steps_f
  self._run_nondiagonal_update_warmup = tf.minimum(
      1.0, tf.maximum(warmup_progress, 0.0))

  # Computes statistics every K steps.
  self._statistics_computation_frequency = statistics_computation_frequency
  steps_since_last_computation = tf.math.floormod(
      self._global_step, self._statistics_computation_frequency)
  self._run_statistics_computation = tf.equal(steps_since_last_computation,
                                              0)

  # All vars that are preconditioned.
  self._all_vars_for_preconditioning = []
  self._exponent_multiplier = exponent_multiplier
  self._partition_info = PartitionConfig(block_partition_threshold_size,
                                         block_size)
  self._partitioner_metadata = {}
def _GetWarpMatrix(self,
                   batch_size,
                   choose_range,
                   matrix_size,
                   global_seed,
                   max_warp_frames=None,
                   dtype=tf.float32,
                   max_ratio=1.0):
  """Returns warp matrices starting from random positions.

  In this function when max_warp_frames != None:
    1) Sample random integer warp displacements from the closed interval
       [-max_warp_frames, max_warp_frames] to yield a shift tensor with
       shape (batch_size,).
    2) Truncate shifts to a maximum magnitude of (choose_range * max_ratio),
       so that each shift is fully contained within the corresponding
       sequence.
    3) Random sample origin points of shape (batch_size,) within
       [1, max(choose_range - 2, 0) + 1).
    4) Return a batch of 1-D linear maps that fix the boundary points and
       shift the origin point by the shift.

  When max_warp_frames == None:
    1) Sample random warp displacements with magnitudes less than
       (choose_range * max_ratio) to yield shift tensor with shape
       (batch_size,).
    2) Proceed through steps 3), 4).

  Args:
    batch_size: Batch size. Integer number.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,).
    matrix_size: Dimension of vector space warp matrix is applied to. Integer
      number.
    global_seed: an integer seed tensor for stateless random ops.
    max_warp_frames: Upper-bound on the warp distance. Integer or None.
    dtype: Data type used for all floating point intermediates, including
      the random samples.
    max_ratio: Maximum ratio between the shift distance and choose_range.
      Float number.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size).
  """
  p = self.params

  # Non-empty random seed values are only used for testing or when using
  # stateless random ops. seed_3/seed_4 drive the warp magnitude (only one of
  # them is used per call) and seed_5 drives the origin position, so that
  # magnitude and position are not correlated.
  if p.use_input_dependent_random_seed:
    seed_3 = global_seed + 3
    seed_4 = global_seed + 4
    seed_5 = global_seed + 5
  elif p.random_seed:
    seed_3 = p.random_seed - 1
    seed_4 = p.random_seed - 1
    seed_5 = 2 * p.random_seed + 1
  else:
    seed_3 = p.random_seed
    seed_4 = p.random_seed
    seed_5 = p.random_seed

  choose_range_dtype = tf.cast(choose_range, dtype=dtype)
  length_upper_bound = tf.cast(max_ratio * choose_range_dtype, dtype=tf.int32)

  # Set shift length.
  random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)

  if max_warp_frames and max_warp_frames > 0:
    # maxval is exclusive for integer sampling, hence the + 1.
    shift = random_uniform(
        shape=(batch_size,),
        minval=-1 * max_warp_frames,
        maxval=max_warp_frames + 1,
        dtype=tf.int32,
        seed=seed_3)
  else:
    random_ratio = random_uniform(
        shape=(batch_size,),
        minval=-1.0,
        maxval=1.0,
        dtype=dtype,
        seed=seed_4)
    shift = tf.cast(random_ratio * tf.cast(length_upper_bound, dtype=dtype),
                    tf.int32)

  # Make sure the sampled shift magnitude does not exceed
  # max_ratio * choose_range.
  # Note that sampling in this way is biased
  # (shorter sequences may be over-warped).
  final_shift = tf.maximum(-length_upper_bound,
                           tf.minimum(shift, length_upper_bound))

  # Choose origin anchor point. The int32 cast used to be discarded because
  # the next line recomputed from `choose_range`; it is now actually used.
  mid_range = tf.cast(choose_range, dtype=tf.int32)
  mid_range = tf.maximum(mid_range - 2, 0)
  # Sample in `dtype`; a float32 sample multiplied by a `dtype` tensor fails
  # for any non-float32 dtype.
  random_origin = random_uniform(
      shape=(batch_size,), maxval=1.0, dtype=dtype, seed=seed_5)
  origin_with_in_valid_range = random_origin * tf.cast(mid_range, dtype=dtype)
  origin = tf.cast(origin_with_in_valid_range, tf.int32) + 1

  # Set destination point of the origin anchor point under the warp map.
  destination = origin + final_shift

  # Cast origin and destination.
  origin = tf.cast(origin, dtype=dtype)
  destination = tf.cast(destination, dtype=dtype)

  return self._ConstructWarpMatrix(
      batch_size=batch_size,
      matrix_size=matrix_size,
      origin=origin,
      destination=destination,
      choose_range=choose_range_dtype,
      dtype=dtype)
def _ConstructWarpMatrix(self, batch_size, matrix_size, origin, destination,
                         choose_range, dtype):
  """Builds warp matrices from origin, destination and choose_range.

  Each matrix in the returned batch represents a piecewise linear map along
  the warping axis that keeps the points 0 and `choose_range` fixed and sends
  the anchor point `origin` to `destination`.

  For the warp matrix to be non-singular, destination must satisfy
  1 <= destination <= choose_range - 1; values outside this range are
  clamped into it before the matrix is built.

  The map is described through three slopes:
    1) slope_0 = origin / destination.
    2) slope_1 = (choose_range - origin) / (choose_range - destination).
    3) slope_2 = 1.0.

  The source coordinate orig_i of a warped coordinate i is then:
    1) i < destination: orig_i = slope_0 * i.
    2) destination <= i < choose_range:
         orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
    3) i >= choose_range: orig_i = i.

  With n_i = ceil(orig_i), the matrix element warp[i][j] is:
    1) j = n_i: 1 - n_i + orig_i.
    2) j = n_i - 1: n_i - orig_i.
    3) Otherwise: 0.

  So applying the matrix to pixels, warped_pixel[i] = sum_j warp[i][j] *
  pixel[j], gives the linear interpolation
  warped_pixel[i] = (n_i - orig_i) * pixel[n_i - 1]
                    + (1 - n_i + orig_i) * pixel[n_i].

  Args:
    batch_size: Batch size. Integer number.
    matrix_size: Dimension of the vector space the warp matrix is applied to.
      Integer number.
    origin: Origin anchor point for warping. Tensor of shape (batch_size,)
      and data type dtype.
    destination: Destination of the origin anchor point upon warping. Tensor
      of shape (batch_size,) and data type dtype.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,) and data type dtype.
    dtype: Data type of origin, destination, choose_range and the output
      warp matrix.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size).
  """
  p = self.params

  def _RepeatAlongColumns(per_example):
    """Expands a (batch_size,) tensor to (batch_size, matrix_size)."""
    tiled = tf.broadcast_to(per_example, (matrix_size, batch_size))
    return tf.transpose(tiled)

  # Clamp destination into [1, choose_range - 1] so the warp matrix has
  # non-singular values.
  destination = tf.minimum(tf.maximum(destination, 1.0), choose_range - 1.0)

  # Per-column copies of the two break points of the piecewise linear map.
  destination_bc = _RepeatAlongColumns(destination)
  choose_range_bc = _RepeatAlongColumns(choose_range)

  # Slopes of the three linear pieces.
  slope_0 = origin / destination
  slope_1 = (choose_range - origin) / (choose_range - destination)
  slope_2 = 1.0

  # `source` holds, for every example and every target coordinate i, the
  # source coordinate that i is read from:
  #   1) i < dest: slope_0 * i.
  #   2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
  #   3) i >= choose_range: i.
  # The relu terms switch the slope at each break point.
  coords = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype), (batch_size, matrix_size))
  first_piece = self.EinsumBBmBm(slope_0, coords)
  second_piece = self.EinsumBBmBm(slope_1 - slope_0,
                                  tf.nn.relu(coords - destination_bc))
  third_piece = self.EinsumBBmBm(slope_2 - slope_1,
                                 tf.nn.relu(coords - choose_range_bc))
  source = first_piece + second_piece + third_piece

  # Lay the source coordinates out so that entry [b, i, j] equals the source
  # coordinate of i for every column j.
  source = tf.broadcast_to(source, (matrix_size, batch_size, matrix_size))
  source = tf.transpose(source, perm=[1, 2, 0])

  # Coordinate matrices: entry [b, i, j] equals j.
  column_index = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype),
      (batch_size, matrix_size, matrix_size))

  # The hat function applied element-wise to (source - column) yields the
  # linear interpolation weights described in the docstring.
  warp_matrix = _hat(source - column_index)

  if p.fprop_dtype is not None and p.fprop_dtype != dtype:
    warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

  return warp_matrix
def FProp(self, theta, inputs, paddings):
  """Applies causal pooling to inputs.

  When `left_context` is -1 the pooling is cumulative over the whole past
  (only 'AVG' is supported). Otherwise each output frame pools over the
  current frame and the `left_context - 1` frames before it.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor. It is expected to be of shape [batch,
      time].

  Returns:
    outputs, out_paddings pair.
     - outputs: has the same shape as inputs.
     - out_paddings: has the same shape as paddings.

  Raises:
    NotImplementedError: If `left_context` is -1 and the pooling type is not
      'AVG'.
  """
  p = self.params
  is_average = p.pooling_type == 'AVG'

  if p.left_context == -1:
    if not is_average:
      raise NotImplementedError('Cumulative max pooling not implemented.')
    # Running mean over time: running sum divided by the 1-based frame index.
    running_sum = tf.math.cumsum(inputs, axis=1)
    num_frames = py_utils.GetShape(inputs)[1]
    running_count = 1.0 + tf.range(num_frames, dtype=p.dtype)
    running_mean = running_sum / running_count[tf.newaxis, :, tf.newaxis,
                                               tf.newaxis]
    # Zero out padded frames.
    running_mean *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
    return running_mean, paddings

  window_size = p.left_context
  left_pad_size = window_size - 1
  large_negative = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
  # For max pooling, use a large negative padding value such that the max
  # element is almost always from a non-padding position.
  if is_average:
    pad_value = 0
  else:
    pad_value = large_negative

  # Pad only on the left of the time axis so that pooling stays causal.
  time_padding = [[0, 0], [left_pad_size, 0], [0, 0], [0, 0]]
  padded_inputs = tf.pad(inputs, time_padding, constant_values=pad_value)
  pooled = tf.nn.pool(
      padded_inputs,
      window_shape=(window_size, 1),
      pooling_type=p.pooling_type,
      padding='VALID')

  if is_average:
    # Count the fraction of non-padding elements inside each pooling window:
    # frame t (1-based) sees min(t, window_size) real frames.
    max_seq_len = py_utils.GetShape(paddings)[1]
    window_size_f = tf.cast(window_size, p.dtype)
    real_frames_in_window = tf.minimum(
        tf.range(1, 1 + max_seq_len, dtype=p.dtype), window_size_f)
    real_fraction = real_frames_in_window / tf.cast(window_size, p.dtype)
    # Divide by non-padding ratios to eliminate the effect of padded zeros.
    pooled *= tf.math.reciprocal_no_nan(real_fraction[tf.newaxis, :,
                                                      tf.newaxis, tf.newaxis])

  # Zero out padded frames.
  pooled *= 1.0 - paddings[..., tf.newaxis, tf.newaxis]
  return pooled, paddings
def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
  """Runs one iteration of the farthest point sampling loop.

  Picks one new point per batch element, records it in `sampled_idx`, and
  refreshes the per-point distance to the selected set together with the
  index of the closest selected point.

  Args:
    curr_idx: Scalar loop counter; also the slot written in `sampled_idx`.
    distance_to_selected: Per-point distance to the closest point selected
      so far.
    sampled_idx: Array (supporting `.write`) collecting the selected point
      indices, one entry per iteration.
    closest_idx: Per-point iteration index of the closest selected point.

  Returns:
    The updated loop variables
    (curr_idx + 1, distance_to_selected, sampled_idx, closest_idx).
  """

  def _PickRandomRealPoint():
    """Picks the first point at random among the real (non padded) points.

    A random value is drawn per point and padded points are overwritten with
    a value above the sampling maxval, so the per-example argmin always lands
    on a real point.

    Returns:
      Tensor containing the index of a random point selected for each example
      in the batch.
    """
    scores = tf.random.uniform((batch_size, num_points),
                               minval=0,
                               maxval=1,
                               dtype=tf.float32,
                               seed=random_seed)
    is_real = tf.equal(padding, 0.0)
    scores = tf.where(is_real, scores, padding * 10)
    return tf.argmin(scores, axis=1, output_type=tf.int32)

  def _PickFurthestPoint():
    """Picks the point that is furthest from those already selected.

    Sampling is biased towards real points by giving padded points a negative
    distance until there are no real points left.

    Returns:
      Tensor containing the index of the next farthest point selected for
      each example in the batch.
    """
    # Padded points get a negative distance so they aren't selected.
    is_real = tf.equal(padding, 0.0)
    negative_distance = -1.0 * tf.ones(
        (batch_size, num_points), dtype=tf.float32)
    biased_distance = tf.where(is_real, distance_to_selected,
                               negative_distance)
    # But only do this while there are still valid points left.
    has_valid_left = tf.less(curr_idx, num_valid_points)
    biased_distance = tf.where(has_valid_left, biased_distance,
                               distance_to_selected)
    return tf.argmax(biased_distance, axis=-1, output_type=tf.int32)

  def _PickSeededPoint():
    """Picks the next seeded point.

    Seeded points are assumed to be at the beginning of the original points,
    so the point index is simply the loop counter.

    Returns:
      Tensor containing the index of the next seeded point to select for each
      example in the batch.
    """
    return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

  def _SelectWithSeeds():
    """Uses seeded points first, then farthest point selection."""
    return tf.cond(
        tf.less(curr_idx, num_seeded_points), _PickSeededPoint,
        _PickFurthestPoint)

  def _SelectWithoutSeeds():
    """Starts from a random real point, then farthest point selection."""
    return tf.cond(
        tf.equal(curr_idx, 0), _PickRandomRealPoint, _PickFurthestPoint)

  # Select indices for this loop iteration.
  new_selected = tf.cond(
      tf.greater(num_seeded_points, 0), _SelectWithSeeds, _SelectWithoutSeeds)
  updated_sampled_idx = sampled_idx.write(curr_idx, new_selected)

  # Extract the distance to the latest point selected to update
  # distance_to_selected.
  gather_coords = tf.stack([tf.range(batch_size), new_selected], axis=1)
  if precomputed_squared_distance is not None:
    new_distance = tf.gather_nd(precomputed_squared_distance, gather_coords)
  else:
    new_points = tf.reshape(
        tf.gather_nd(points, gather_coords), [batch_size, 1, dims])
    new_distance = tf.reshape(
        SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

  is_newly_closest = tf.less(new_distance, distance_to_selected)
  updated_distance = tf.minimum(distance_to_selected, new_distance)

  # Track the index to the closest selected point.
  curr_idx_tiled = tf.tile([[curr_idx]], [batch_size, num_points])

  def _FirstIterationClosest():
    # At the first loop iteration, the init points are the closest.
    return curr_idx_tiled

  def _LaterIterationClosest():
    # Otherwise, update with the new points based on the distances.
    return tf.where(is_newly_closest, curr_idx_tiled, closest_idx)

  updated_closest_idx = tf.cond(
      tf.equal(curr_idx, 0), _FirstIterationClosest, _LaterIterationClosest)

  return curr_idx + 1, updated_distance, updated_sampled_idx, updated_closest_idx
def Value(self):
  """Returns the element-wise minimum of the rampup and decay schedules."""
  rampup_value = self.rampup_schedule.Value()
  decay_value = self.decay_schedule.Value()
  return tf.minimum(rampup_value, decay_value)
def Value(self, step=None):
  """Returns the element-wise minimum of the rampup and decay schedules.

  Args:
    step: Forwarded unchanged to both underlying schedules; may be None.

  Returns:
    The smaller of the rampup schedule value and the decay schedule value.
  """
  rampup_value = self.rampup_schedule.Value(step)
  decay_value = self.decay_schedule.Value(step)
  return tf.minimum(rampup_value, decay_value)
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid. Defaults to all zeros (all valid).
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal. Padded canvas entries are zero.
      - canvas_indices: The canvas indices (into `x`), sorted ascending per
        example. Padded entries hold `time_dim - 1`.
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  # Number of valid tokens per example.
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired canvas length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    # At least one token (the last one) is always part of the canvas.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example. Padded positions get a
  # very negative logit so they sort after every valid position.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9

  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas. After sorting, the first `c_len` entries of each
  # row are the real canvas indices; the rest are `time_dim - 1` fillers.
  c_indices = tf.sort(c_indices)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max])
  # (batch, position) coordinates into `x` for every canvas entry, shape
  # [batch_size * c_len_max, 2].
  canvas_coords = tf.stack(
      [tf.reshape(batch_ids, [-1]),
       tf.reshape(c_indices, [-1])], 1)
  c = tf.gather_nd(x, canvas_coords)
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(
      c_len, c_len_max, dtype=x_paddings.dtype)
  canvas_is_valid = tf.cast(1 - c_paddings, tf.int32)
  c *= canvas_is_valid

  # Mark which tokens of `x` are part of the canvas. Only real canvas entries
  # contribute: scattering a 1 for the `time_dim - 1` fillers as well would
  # mark the last token of a full-length sequence as observed even when it
  # was not sampled (possible when `force_sample_last_token` is False).
  x_token_is_observed = tf.scatter_nd(canvas_coords,
                                      tf.reshape(canvas_is_valid, [-1]),
                                      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  # Whether the token to the left is observed; the first position is treated
  # as having an observed left neighbor.
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)

  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  All hypotheses of all decoding steps share one flat buffer of
  `beam_size * max_steps` slots. A hypothesis is represented by a mask over
  that buffer selecting the slots (tokens) it is made of.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above). It is invoked here as
      `dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t)`
      and must return `(logits, dec_state)`; `logits` is indexed as
      `logits[:, :, eos_id]`.
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]; prefix_max must
      be a multiple of beam_size and smaller than beam_size * max_steps.
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: extension buffer size; 0 disables the extension buffer
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    # Create prefix of start tokens.
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                    prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                          dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # Extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # Then top beam_size extensions are removed from the buffer and used
    # in the next decoder iteration. And top 'ext_size' remaining extensions
    # are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size], dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len

    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    # [ - - - - - - P P P P P P P* - - - ]   ^
    # [ - - P P P P P P P P P P P* - - - ]   | batch
    # [ - - - - - - - - - - - P P* - - - ]   V
    # |<---- prefix len ---->  |<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.
    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
    pfx_time = tf.range(pfx_max)
    pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
    pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0), tf.int32)
    # Exclude final pfx token.
    pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
    pfx_last = pfx[:, -1]

    buf_time = tf.range(buf_size)
    pfx_time_mask = tf.cast(
        tf.less_equal(tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time, 1)),
        fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    # Remove padding.
    assert buf_size > pfx_max
    pfx_pad_long = tf.pad(
        pfx_pad, [(0, 0), (0, buf_size - pfx_max)], constant_values=1)
    # Cast to fprop_dtype (not a hard-coded float32): pfx_mask is built in
    # fprop_dtype and TF does not implicitly promote mismatched dtypes.
    pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), fprop_dtype)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_indexes * pfx_pad

    if debug:
      tpu_summary.tensor('pfx_id', pfx_id)
      tpu_summary.tensor('pfx_len', pfx_len)
      tpu_summary.tensor('pfx_pos', pfx_pos)
      tpu_summary.tensor('pfx_last', pfx_last)

    # Now call decoder with prefix minus P*:
    # 'dec_state' now shall contain the key/value cache for prefix tokens
    # (for transformer models), and 'logits' we can either discard or
    # roll into the initial hyp_score. Discard is simpler.
    with tf.name_scope('prefix_fprop'):
      # TODO(krikun): remove extra type checks
      assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
      assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
      assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
      assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
      assert (t.dtype == tf.int32), (t.dtype)
      logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                       pfx_mask, dec_state, t)
      del logits

    # Now construct the initial state for the rest of the beam search loop.
    # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
    # 'tgt_pos' is different for each batch row and is equal to prefix_len
    # 'tgt_segment_id' always 1 (no packing)
    # 'hyp_score' is 0 for beam=0 and negative for beam>=1
    tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        pfx_last, 1)
    tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        (pfx_len - 1), 1)
    hyp_score = tf.zeros(
        [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
            tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

    # TODO(krikun) Here we make initial 't' constant and determined by the
    # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
    # as t ~ max(pfx_len) / beam_size and this will allow more steps for beam
    # search, however 'max' results in a very slow all-to-all for 'max' on
    # 16x16 and variable number of decoder steps may result in bad latency.
    t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

    # Initial tgt_mask is such that each token P* has attention on itself
    # (as usual) and on all prefix tokens before it, which are not padding.
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.cast(
        tf.expand_dims(
            tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
        fprop_dtype)
    tgt_mask += tf.one_hot(
        tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    if debug:
      tpu_summary.tensor('tgt_id', tgt_id)
      tpu_summary.tensor('tgt_pos', tgt_pos)
      tpu_summary.tensor('tgt_mask', tgt_mask)
      tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0]
    # and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

  hist = (h_tgt_id, h_tgt_pos)
  tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                     tgt_mask, dec_state, t)
    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)

    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    nbest_hyps = update_nbest(nbest_hyps,
                              (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)
      top_score += tf.expand_dims(hyp_score, -1)
      # Push EOS continuations far down; EOS is handled via the n-best list.
      top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)
      top_score = tf.reshape(top_score, [batch_size, beam_size * k])
      top_id = tf.reshape(top_id, [batch_size, beam_size * k])
      top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

    tgt_pos = tf.reduce_sum(tgt_mask, -1)
    tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      # hyp_score must come from loop_vars: picking it up from the enclosing
      # scope would compare against the *initial* scores on every iteration,
      # so the early-stopping test below would never trigger.
      (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])

  def unmask(h, m):
    with tf.name_scope('unmask'):
      # NOTE: these two summaries are emitted regardless of `debug`.
      tpu_summary.tensor('unmask_h', h)
      tpu_summary.tensor('unmask_m', m)
      # Output position of every selected buffer slot (-1 where unselected).
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)

  topk_score = nbest_score_norm
  nbest = (topk_ids, topk_len, topk_score)

  return loop_vars, dec_state, nbest
def CornerLoss(self, gt_bboxes, predicted_bboxes):
  """Computes a corner-distance regression loss between two sets of boxes.

  An alternative to regressing box residuals directly, as used in the
  Frustum-PointNets paper [1]. Both the ground truth and the predicted boxes
  are expanded to their corners (all 8), the per-corner Euclidean distances
  are summed per box, and the sum is fed through a SmoothedL1 (Huber) loss.
  The distance is also measured against the ground truth with pi added to
  its heading and the smaller of the two sums is used, so a prediction that
  is off by half a turn is not penalized for it. Minimizing this loss
  encourages a high IoU between predictions and ground truth.

  [1] Frustum PointNets for 3D Object Detection from RGB-D Data
      https://arxiv.org/pdf/1711.08488.pdf

  TODO(bcyang): support arbitrary input shapes [..., 7].

  Args:
    gt_bboxes: tf.float32 of shape [batch_size, num_centers,
      num_anchor_bboxes_per_center, 7] holding the ground truth bbox
      parameters (x, y, z, dx, dy, dz, phi).
    predicted_bboxes: tf.float32 with the same shape as gt_bboxes holding the
      predicted bbox parameters.

  Returns:
    tf.float32 Tensor of shape [batch_size, num_centers,
    num_anchor_bboxes_per_center] with the corner loss of each bbox.
  """
  dims = py_utils.GetShape(gt_bboxes, 3)
  batch_size, num_centers, num_anchor_bboxes_per_center = dims
  num_boxes = num_centers * num_anchor_bboxes_per_center

  # Validate both inputs against the expected rank-4 shape, then collapse the
  # (center, anchor) axes into a single box axis.
  expected_shape = [batch_size, num_centers, num_anchor_bboxes_per_center, 7]
  gt_bboxes = py_utils.HasShape(gt_bboxes, list(expected_shape))
  predicted_bboxes = py_utils.HasShape(predicted_bboxes, list(expected_shape))
  gt_flat = tf.reshape(gt_bboxes, [batch_size, num_boxes, 7])
  pred_flat = tf.reshape(predicted_bboxes, [batch_size, num_boxes, 7])

  # Ground truth with the heading (last parameter) shifted by pi.
  half_turn = tf.constant([[[0., 0., 0., 0., 0., 0., np.pi]]],
                          dtype=tf.float32)
  gt_flipped = gt_flat + half_turn

  gt_corners = geometry.BBoxCorners(gt_flat)
  flipped_corners = geometry.BBoxCorners(gt_flipped)
  pred_corners = geometry.BBoxCorners(pred_flat)

  def _SummedCornerDistance(target_corners):
    """Sums, per box, the distances from predicted to target corners."""
    per_corner = tf.norm(pred_corners - target_corners, axis=-1)
    return tf.reduce_sum(per_corner, axis=-1)

  direct_dist = _SummedCornerDistance(gt_corners)
  flipped_dist = _SummedCornerDistance(flipped_corners)
  best_dist = tf.minimum(direct_dist, flipped_dist)

  loss = self.ScaledHuberLoss(
      labels=tf.zeros_like(direct_dist), predictions=best_dist)
  return tf.reshape(
      loss, [batch_size, num_centers, num_anchor_bboxes_per_center])