def BatchedOrientedNMSIndices(self, bboxes, scores, nms_iou_threshold,
                              score_threshold, max_boxes_per_class):
  """Applies per-class oriented (7-DOF) 3D NMS to every example of a batch.

  Every returned tensor is padded or trimmed to
  [batch_size, num_classes, max_boxes_per_class].

  Args:
    bboxes: Floating point Tensor of shape [batch_size, num_boxes, 7] holding
      boxes as [x, y, z, dx, dy, dz, phi].
    scores: Floating point Tensor of shape [batch_size, num_boxes,
      num_classes] with the score of each box for each class.
    nms_iou_threshold: A float, or a list of num_classes floats, giving the
      IoU above which two boxes are considered overlapping for suppression.
    score_threshold: A float, or a list of num_classes floats, giving the
      score threshold that lets NMS quickly ignore boxes.
    max_boxes_per_class: Integer scalar; the maximum number of boxes emitted
      per class for each example.

  Returns:
    A (bbox_indices, bbox_scores, valid_mask) tuple:

    - bbox_indices: int32 Tensor with the indices of the chosen boxes. Values
      are in sort order until the class_idx switches.
    - bbox_scores: float32 Tensor with the score of each chosen box.
    - valid_mask: float32 Tensor of 1/0 values; 1 marks a valid entry.
  """
  bboxes = py_utils.HasShape(bboxes, [-1, -1, 7])
  batch_size, num_boxes = py_utils.GetShape(bboxes, 2)
  scores = py_utils.HasShape(scores, [batch_size, num_boxes, -1])
  _, _, num_classes = py_utils.GetShape(scores)

  def _PerClass(threshold):
    # A scalar threshold is expanded so that every class has its own entry.
    return tf.broadcast_to(tf.convert_to_tensor(threshold), [num_classes])

  nms_iou_threshold = _PerClass(nms_iou_threshold)
  score_threshold = _PerClass(score_threshold)

  def _SingleExampleNMS(example):
    example_bboxes, example_scores = example
    kept_indices, kept_scores, kept_mask = ops.non_max_suppression_3d(
        example_bboxes,
        example_scores,
        nms_iou_threshold=nms_iou_threshold,
        score_threshold=score_threshold,
        max_boxes_per_class=max_boxes_per_class)
    # Return a plain tuple so the structure matches `dtype` in tf.map_fn.
    return kept_indices, kept_scores, kept_mask

  per_example_results = tf.map_fn(
      fn=_SingleExampleNMS,
      elems=(bboxes, scores),
      dtype=(tf.int32, tf.float32, tf.float32),
      back_prop=False)

  output_shape = [batch_size, num_classes, max_boxes_per_class]
  bbox_indices, bbox_scores, valid_mask = (
      py_utils.PadOrTrimTo(result, output_shape)
      for result in per_example_results)
  return bbox_indices, bbox_scores, valid_mask
def testSpectrumAugmenterWarpMatrixConstructor(self):
  """Host and on-device layers must build equivalent warp matrices."""
  with self.session(use_gpu=False, graph=tf.Graph()) as sess:
    ramp = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (4, 10))
    origin = tf.cast([2, 4, 4, 5], dtype=tf.float32)
    destination = tf.cast([3, 2, 6, 8], dtype=tf.float32)
    choose_range = tf.cast([4, 8, 8, 10], dtype=tf.float32)

    layer_params = [
        spectrum_augmenter.SpectrumAugmenter.Params(),
        spectrum_augmenter_on_device.SpectrumAugmenterOnDevice.Params(),
    ]

    def _WarpRamp(params):
      # Builds the layer's warp matrix and applies it to the ramp input.
      params.name = 'specAug_layers'
      layer = params.Instantiate()
      warp = layer._ConstructWarpMatrix(
          batch_size=4,
          matrix_size=10,
          origin=origin,
          destination=destination,
          choose_range=choose_range,
          dtype=tf.float32)
      return tf.einsum('bij,bj->bi', warp, ramp)

    warped = [_WarpRamp(params) for params in layer_params]
    layer_output, layer_output_on_device = sess.run(warped)
    self.assertAllClose(layer_output, layer_output_on_device)
def testSpectrumAugmenterWithDynamicTimeWarping(self):
  """Host and on-device layers must agree under dynamic time warping."""
  with self.session(use_gpu=False, graph=tf.Graph()) as sess:
    tf.random.set_seed(1234)
    ramp = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (3, 10))
    inputs = tf.expand_dims(tf.expand_dims(ramp, -1), -1)
    # Example i has 2 * i + 5 real frames followed by 5 - 2 * i padded ones.
    paddings = tf.concat([
        tf.concat([tf.zeros([1, 2 * i + 5]),
                   tf.ones([1, 5 - 2 * i])], axis=1) for i in range(3)
    ], axis=0)

    layer_params = [
        spectrum_augmenter.SpectrumAugmenter.Params(),
        spectrum_augmenter_on_device.SpectrumAugmenterOnDevice.Params(),
    ]
    # Only time warping is left enabled; both layers share the same seed.
    overrides = (
        ('name', 'specAug_layers'),
        ('freq_mask_max_bins', 0),
        ('time_mask_max_frames', 0),
        ('time_warp_max_ratio', 0.5),
        ('time_warp_bound', 'dynamic'),
        ('random_seed', 34567),
    )
    augmented = []
    for params in layer_params:
      for key, value in overrides:
        setattr(params, key, value)
      layer = params.Instantiate()
      features, _ = layer.FPropDefaultTheta(inputs, paddings)
      augmented.append(features)

    layer_output, layer_output_on_device = sess.run(augmented)
    self.assertAllClose(layer_output, layer_output_on_device)
def testSpectrumAugmenterWithFreqWarping(self):
  """Checks frequency warping output against recorded golden values."""
  # pyformat: disable
  # pylint: disable=bad-whitespace,bad-continuation
  golden_rows = [
      [0.0, 4.0, 4.5714283, 5.142857, 5.714286, 6.285714, 6.8571434, 3.999998],
      [0.0, 0.8, 1.6, 2.4, 3.2, 4.0, 5.3333335, 6.6666665],
      [0.0, 0.6666667, 1.3333334, 2.0, 3.2, 4.4, 5.6000004, 6.8],
      [0.0, 1.3333334, 2.6666667, 4.0, 4.8, 5.6000004, 6.3999996, 5.5999947],
      [0.0, 2.0, 2.857143, 3.7142859, 4.571429, 5.4285717, 6.2857146, 5.999997],
  ]
  # pylint: enable=bad-whitespace,bad-continuation
  # pyformat: enable
  # Shape [5, 1, 8]: one time step per example.
  expected_output = np.array([[row] for row in golden_rows])

  with self.session(use_gpu=False, graph=tf.Graph()):
    tf.random.set_seed(1234)
    ramp = tf.broadcast_to(tf.cast(tf.range(8), dtype=tf.float32), (5, 1, 8))
    inputs = tf.expand_dims(ramp, -1)
    # NOTE(review): this paddings shape does not match the [5, 1] batch/time
    # shape of `inputs`; presumably paddings are unused when time masking and
    # time warping are disabled -- confirm.
    paddings = tf.zeros([3, 2])

    params = spectrum_augmenter.SpectrumAugmenter.Params()
    overrides = (
        ('name', 'specAug_layers'),
        ('freq_mask_max_bins', 0),
        ('time_mask_max_frames', 0),
        ('freq_warp_max_bins', 4),
        ('time_warp_max_frames', 0),
        ('random_seed', 345678),
    )
    for key, value in overrides:
      setattr(params, key, value)
    layer = params.Instantiate()

    features, _ = layer.FPropDefaultTheta(inputs, paddings)
    actual_layer_output = self.evaluate(tf.squeeze(features, -1))
    print(np.array_repr(actual_layer_output))
    self.assertAllClose(actual_layer_output, expected_output)
def testSpectrumAugmenterWarpMatrixConstructor(self):
  """Checks the warp matrix applied to a ramp against golden values."""
  # pyformat: disable
  # pylint: disable=bad-whitespace,bad-continuation
  expected_output = np.array([
      [0.0000000, 0.6666667, 1.3333333, 2.0000000, 4.0000000,
       5.0000000, 6.0000000, 7.0000000, 8.0000000, 9.0000000],
      [0.0000000, 2.0000000, 4.0000000, 4.6666667, 5.3333333,
       6.0000000, 6.6666667, 7.3333333, 8.0000000, 9.0000000],
      [0.0000000, 0.6666667, 1.3333333, 2.0000000, 2.6666667,
       3.3333333, 4.0000000, 6.0000000, 8.0000000, 9.0000000],
      [0.0000000, 0.6250000, 1.2500000, 1.8750000, 2.5000000,
       3.1250000, 3.7500000, 4.3750000, 5.0000000, 7.5000000],
  ])
  # pylint: enable=bad-whitespace,bad-continuation
  # pyformat: enable

  with self.session(use_gpu=False, graph=tf.Graph()):
    ramp = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (4, 10))
    anchors = dict(
        origin=tf.cast([2, 4, 4, 5], dtype=tf.float32),
        destination=tf.cast([3, 2, 6, 8], dtype=tf.float32),
        choose_range=tf.cast([4, 8, 8, 10], dtype=tf.float32))

    params = spectrum_augmenter.SpectrumAugmenter.Params()
    params.name = 'specAug_layers'
    layer = params.Instantiate()

    warp = layer._ConstructWarpMatrix(
        batch_size=4, matrix_size=10, dtype=tf.float32, **anchors)
    warped_ramp = tf.einsum('bij,bj->bi', warp, ramp)
    actual_layer_output = self.evaluate(warped_ramp)
    print(np.array_repr(actual_layer_output))
    self.assertAllClose(actual_layer_output, expected_output)
def _BroadcastExamplePairLabelsToAllItemPairs(example_pair_labels: tf.Tensor,
                                              queries_shape: tf.TensorShape,
                                              results_shape: tf.TensorShape):
  """Copies every example-pair label onto all item pairs of that example pair.

  Args:
    example_pair_labels: Rank-2 labels tensor for example pairs, shape
      [query_batch_size, result_batch_size].
    queries_shape: Fully defined batch shape of the query examples. Must start
      with `query_batch_size`.
    results_shape: Fully defined batch shape of the result examples. Must
      start with `result_batch_size`.

  Returns:
    A labels tensor with shape `queries_shape + results_shape`.
  """
  example_pair_labels.shape.assert_has_rank(2)
  queries_shape.assert_is_fully_defined()
  results_shape.assert_is_fully_defined()

  # Index with [:, None, ..., :, None, ...] so that [q, r] becomes
  # [q, 1, ..., r, 1, ...], with one new axis per trailing item dimension.
  keep_axis = slice(None)
  query_index = (keep_axis,) + (None,) * (queries_shape.rank - 1)
  result_index = (keep_axis,) + (None,) * (results_shape.rank - 1)
  expanded_labels = example_pair_labels[query_index + result_index]
  return tf.broadcast_to(expanded_labels, queries_shape + results_shape)
def MultiLabelContrastiveLoss(labels, logits, axis: int = -1):
  """Multi-label generalization of softmax cross entropy.

  How this relates to softmax cross entropy:

  - With one-hot `labels` (over `axis`) the loss equals softmax cross entropy,
    i.e. the per-example loss is -log(p(positive_class)), where p() is a
    distribution over C classes: 1 positive and C-1 negatives.
  - With N-hot `labels` the loss is
    `sum_i{ -log(p_i(positive_class_i)) } / N`
    where p_i() is a distribution over the i'th positive class and the C-N
    negative classes.

  Unlike `tf.nn.softmax_cross_entropy_with_logits()`, "soft" labels are not
  supported. Positives must be exactly 1 and negatives exactly 0. Any other
  label value removes that example-class pair from the loss; this is a
  deliberate feature that gives callers fine-grained control over which pairs
  contribute.

  NOTE(review): for an example without any label equal to 1 the final division
  is by zero, so the result is not a finite loss -- confirm callers never pass
  such rows.

  Args:
    labels: Tensor of labels. Must have the same shape as `logits`.
    logits: Tensor of logits (scores). Must have the same shape as `labels`.
    axis: The class dimension, i.e. the one over which probability
      distributions are normalized.

  Returns:
    A Tensor of per-example losses with the dtype of `logits` and the shape of
    `logits` minus `axis`. For the typical [batch_size, num_classes] inputs the
    result is [batch_size].
  """
  labels.shape.assert_is_compatible_with(logits.shape)

  # Keep only the logits of negative pairs; everything else becomes -inf and
  # therefore contributes nothing to the max / logsumexp below.
  minus_inf = tf.broadcast_to(float('-inf'), logits.shape)
  negative_pair_logits = tf.where(tf.equal(labels, 0), logits, minus_inf)

  # Binary logits are log(p / (1-p)). Both terms are shifted by the largest
  # negative-pair score for numerical precision, because
  #   tf.reduce_logsumexp(x) == max(x) + log(sum_i(exp(x[i] - max(x))))
  # and with a sufficiently large max the second term is lost to round-off.
  adjustment = tf.reduce_max(negative_pair_logits, axis=axis, keepdims=True)
  shifted_logits = logits - adjustment
  negatives_logsumexp = tf.reduce_logsumexp(
      negative_pair_logits - adjustment, axis=axis, keepdims=True)
  binary_logits = shifted_logits - negatives_logsumexp

  # Each positive is scored against all negatives. -log_sigmoid is
  # sigmoid_cross_entropy_with_logits for the special case of all-1 labels.
  positive_weights = tf.cast(tf.equal(labels, 1), binary_logits.dtype)
  per_pair_losses = positive_weights * -tf.math.log_sigmoid(binary_logits)
  num_positives = tf.reduce_sum(positive_weights, axis=axis)
  return tf.reduce_sum(per_pair_losses, axis=axis) / num_positives
def _BBox2DImage(self, bbox_corners_image, input_images):
  """Returns [xmin, ymin, xmax, ymax] 2D boxes enclosing projected corners.

  Args:
    bbox_corners_image: Corner coordinates in the image; the last dimension
      holds x at index 0 and y at index 1, and axis 2 enumerates the corners
      of a box.
    input_images: Object exposing `width` and `height`, used as the upper
      clipping bounds for x and y respectively.

  Returns:
    The concatenation, along axis 2, of the per-box minimum and maximum of the
    clipped corner coordinates.
  """

  def _ClipToImage(coords, extent):
    # Corners falling outside the image are clamped to [0, extent].
    extent = tf.broadcast_to(extent[..., tf.newaxis, tf.newaxis],
                             tf.shape(coords))
    return tf.clip_by_value(coords, 0.0, tf.cast(extent, tf.float32))

  clipped_x = _ClipToImage(bbox_corners_image[..., 0:1], input_images.width)
  clipped_y = _ClipToImage(bbox_corners_image[..., 1:2], input_images.height)
  clipped_corners = tf.concat([clipped_x, clipped_y], axis=-1)

  # Reduce over the corners axis of [batch, num_boxes, 8, 2] to get the
  # extrema that define the axis-aligned 2D box.
  lower = tf.math.reduce_min(clipped_corners, axis=2)
  upper = tf.math.reduce_max(clipped_corners, axis=2)
  return tf.concat([lower, upper], axis=2)
def _Merge(*xs):
  """Broadcasts the leading dims of all inputs, then concats on the last dim.

  Args:
    *xs: Tensors whose leading (all but last) dimensions can be broadcast to a
      common shape.

  Returns:
    One tensor with the common leading shape; its last dimension is the
    concatenation of the inputs' last dimensions.
  """
  # The common leading shape is the element-wise max over all leading shapes.
  stacked_leading = tf.stack([tf.shape(x)[:-1] for x in xs])
  leading_shape = tf.reduce_max(stacked_leading, axis=0)

  def _Expand(x):
    # The last dimension of each input is kept untouched.
    target_shape = tf.concat([leading_shape, tf.shape(x)[-1:]], axis=0)
    return tf.broadcast_to(x, target_shape)

  return tf.concat([_Expand(x) for x in xs], axis=-1)
def _PaddedMaxFn(inp):
  """Max over the points axis (-2), ignoring padded points.

  Args:
    inp: Object with `features` and `padding`; `padding` is positive at padded
      points and gains a trailing axis to line up with `features`.

  Returns:
    The numerics-checked max-reduced features; groups where every point is
    padded yield zeros.
  """
  padding = inp.padding
  # Padded points get -inf added so that they can never win the max.
  neginf_offset = tf.where(padding > 0, -np.inf * padding, padding)
  pooled = tf.reduce_max(
      inp.features + neginf_offset[..., tf.newaxis], axis=-2)

  # If every point of a group is padded, reduce_min over its padding is 1 and
  # the pooled value would be -inf. Use zeros instead so nothing downstream
  # turns it into NaN (inf * 0 = NaN).
  fully_padded = tf.cast(tf.reduce_min(padding, axis=-1), tf.bool)
  fully_padded = tf.broadcast_to(fully_padded[..., tf.newaxis],
                                 py_utils.GetShape(pooled))
  pooled = tf.where(fully_padded, tf.zeros_like(pooled), pooled)
  return py_utils.CheckNumerics(pooled)
def _FrequencyMask(self,
                   inputs,
                   global_seed,
                   dtype=tf.float32,
                   domain_id_index=0):
  """Randomly masks frequency bins of `inputs`.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    global_seed: an integer seed tensor for stateless random ops.
    dtype: Data type.
    domain_id_index: domain id index, used to pick the per-domain mask
      parameters.

  Returns:
    `inputs` with random frequency masking applied, or `inputs` itself when
    masking is disabled for this domain.
  """
  params = self.params
  max_bins = params.freq_mask_max_bins[domain_id_index]
  mask_count = params.freq_mask_count[domain_id_index]

  # A zero mask length or a zero mask count means there is nothing to mask.
  masking_disabled = max_bins == 0 or mask_count == 0
  if masking_disabled:
    return inputs

  batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
  # Every example may be masked anywhere within the full frequency range.
  full_range = tf.cast(
      tf.broadcast_to(num_freq, (batch_size,)), dtype=tf.int32)

  freq_mask = self._GetMask(
      tf.shape(inputs)[0],
      choose_range=full_range,
      mask_size=num_freq,
      global_seed=global_seed,
      max_length=max_bins,
      masks_per_frame=0.0,
      multiplicity=mask_count,
      dtype=dtype,
      max_ratio=1.0)
  # The (batch, freq) mask is applied to every time step and channel.
  return tf.einsum('bxyc,by->bxyc', inputs, freq_mask)
def IsSpecialExample(task_ids, special_task_ids):
  """Tells, per example, whether its task id is one of `special_task_ids`.

  Args:
    task_ids: Task ids for the input batch. Tensor of shape [batch].
    special_task_ids: A list of specified task ids.

  Returns:
    A tensor of size [batch] that is True where the example's task id matches
    any entry of `special_task_ids`.
  """
  batch_size = py_utils.GetShape(task_ids)[0]
  num_special = len(special_task_ids)
  # [batch, 1] compared against [batch, num_special].
  task_column = tf.expand_dims(task_ids, -1)
  special_grid = tf.cast(
      tf.broadcast_to(special_task_ids, [batch_size, num_special]), tf.int32)
  return tf.reduce_any(tf.equal(task_column, special_grid), -1)
def _PaddedMeanFn(inp):
  """Apply padded mean using reduce_sum and dividing by # real points.

  Args:
    inp: Object with `features` and `padding`; `1 - padding` is used as the
      mask of real points and gains a trailing axis to line up with
      `features`.

  Returns:
    The numerics-checked mean of the features over the points axis (-2),
    counting only real points. Groups where every point is padded yield zeros.
  """
  # Replace all padded features with 0 by masking the padded features out.
  mask = 1 - inp.padding
  features = inp.features * mask[..., tf.newaxis]
  features = tf.reduce_sum(features, axis=-2)
  num_real_points = tf.reduce_sum(mask, axis=-1, keepdims=True)

  # Record which groups consist only of padded points. This must be done on
  # the raw count: once the divisor is clamped below it is never zero, and the
  # check would always be False.
  all_padded = tf.equal(num_real_points, 0.)

  # Prevent the divisor of our padded mean from ever being 0, so that
  # the gradient flowing back through this op doesn't give us NaNs.
  features = features / tf.maximum(num_real_points, 1)

  # Replace features of all padded points by zeros, so that we don't get any
  # downstream issue with NaNs. Note that inf * 0 = NaN, so masking alone does
  # not guarantee finite values for padded points.
  all_padded = tf.broadcast_to(all_padded, py_utils.GetShape(features))
  features = tf.where(all_padded, tf.zeros_like(features), features)
  return py_utils.CheckNumerics(features)
def testSpectrumAugmenterWithDynamicTimeWarping(self):
  """Checks dynamic time warping output against recorded golden values."""
  # pyformat: disable
  # pylint: disable=bad-whitespace,bad-continuation
  golden_rows = np.array([
      [0.0000000, 1.0000000, 2.0000000, 3.0000000, 4.0000000,
       5.0000000, 6.0000000, 7.0000000, 8.0000000, 9.0000000],
      [0.0000000, 0.8333333, 1.6666666, 2.5000000, 3.3333333,
       4.1666665, 5.0000000, 7.0000000, 8.0000000, 9.0000000],
      [0.0000000, 2.0000000, 2.8750000, 3.7500000, 4.6250000,
       5.5000000, 6.3750000, 7.2500000, 8.1250000, 9.0000000],
  ])
  # pylint: enable=bad-whitespace,bad-continuation
  # pyformat: enable
  # The layer output has two trailing singleton axes: shape [3, 10, 1, 1].
  expected_output = golden_rows[:, :, np.newaxis, np.newaxis]

  with self.session(use_gpu=False, graph=tf.Graph()):
    tf.random.set_seed(1234)
    ramp = tf.broadcast_to(tf.cast(tf.range(10), dtype=tf.float32), (3, 10))
    inputs = tf.expand_dims(tf.expand_dims(ramp, -1), -1)
    # Example i has 2 * i + 5 real frames followed by 5 - 2 * i padded ones.
    paddings = tf.concat([
        tf.concat([tf.zeros([1, 2 * i + 5]),
                   tf.ones([1, 5 - 2 * i])], axis=1) for i in range(3)
    ], axis=0)

    params = spectrum_augmenter.SpectrumAugmenter.Params()
    overrides = (
        ('name', 'specAug_layers'),
        ('freq_mask_max_bins', 0),
        ('time_mask_max_frames', 0),
        ('time_warp_max_ratio', 0.5),
        ('time_warp_bound', 'dynamic'),
        ('random_seed', 34567),
    )
    for key, value in overrides:
      setattr(params, key, value)
    layer = params.Instantiate()

    features, _ = layer.FPropDefaultTheta(inputs, paddings)
    actual_layer_output = self.evaluate(features)
    print(np.array_repr(actual_layer_output))
    self.assertAllClose(actual_layer_output, expected_output)
def _BatchScatter(default_tensor, indices, values):
  """Runs tf.tensor_scatter_nd_update independently on each batch row.

  Args:
    default_tensor: A float tensor of shape [batch, vocab] holding the default
      values.
    indices: An int tensor of shape [batch, k] with the k positions of each
      row of `default_tensor` to overwrite.
    values: A float tensor of shape [batch, k] with the replacement value for
      each corresponding element of `indices`.

  Returns:
    A tensor like `default_tensor` in which element (i, indices[i][j]) has
    been replaced with values[i][j].
  """
  num_rows = tf.shape(default_tensor)[0]
  # Pair every column index with the row it belongs to, so that each update
  # is addressed by a full (row, column) coordinate.
  row_ids = tf.expand_dims(tf.range(num_rows, dtype=indices.dtype), 1)
  row_ids = tf.broadcast_to(row_ids, tf.shape(indices))
  coordinates = tf.stack([row_ids, indices], axis=2)
  return tf.tensor_scatter_nd_update(default_tensor, coordinates, values)
def _IgnorePairsWhere(condition, labels):
  """Returns `labels` with IGNORE_PAIR_LABEL wherever `condition` holds."""
  ignore_labels = tf.broadcast_to(IGNORE_PAIR_LABEL, labels.shape)
  return tf.where(condition, ignore_labels, labels)
def _EncodeToIds(self, word):
  """Encodes `word` by repeatedly merging adjacent tokens.

  The word is split into UTF-8 characters, each character is mapped to a token
  (or to `self.unk_id` when the lookup yields NO_TOKEN), and adjacent tokens
  are then merged one pair per loop iteration.

  Args:
    word: The string to encode; it is passed to `tf.strings.unicode_split`.

  Returns:
    The tokens tensor left when the merge loop stops.
  """
  # Below:
  #   * a token is a wordpiece ID.
  #   * the tokens array will be merged in-place.
  #   * the candidates array is an array of size len(tokens) - 1.
  #     It contains the token for the merged wordpiece, if it exists,
  #     -1 otherwise. For instance, candidate[3] = id(token[3] + token[4]).
  # NOTE(review): the code below tests candidates against NO_TOKEN rather than
  # a literal -1 -- confirm which value actually marks "no merge".
  # First, split into basic UTF-8 characters (letters).
  chars = tf.strings.unicode_split(word, 'UTF-8')
  tokens = self._StringToToken(chars)
  tokens = tf.where(
      tf.equal(tokens, NO_TOKEN),
      # Unseen character.
      tf.broadcast_to(self.unk_id, tf.shape(tokens)),
      tokens)
  # Create initial candidate list.
  candidates = tf.map_fn(
      self._MergeTokens, (tokens[:-1], tokens[1:]), dtype=tokens.dtype)

  def _ShouldMerge(unused_tokens, candidates):
    """Merge until not possible, or we abort early according to merge_prob."""
    # A fresh uniform sample is drawn on every evaluation of the loop
    # condition, so each iteration may stop the merging early.
    return tf.logical_and(
        tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
        tf.random.uniform([]) < self._merge_prob)

  def _MergeOneToken(tokens, i):
    """Returns the merge of tokens[i] and tokens[i + 1] with a trailing axis."""
    return tf.expand_dims(
        self._MergeTokens((tokens[i], tokens[i + 1])), axis=-1)

  def _MergeCandidates(tokens, candidates):
    """Merge in the reverse binary tree."""
    # The candidate with the smallest value is the one that gets merged.
    best_id = tf.argmin(candidates, output_type=tf.int32)
    # Perform the merge at position best_id.
    tokens = tf.concat(
        [tokens[:best_id], [candidates[best_id]], tokens[best_id + 2:]],
        axis=0)
    # Recompute the merge candidates.
    # Only the neighbors of best_id need to be recomputed.
    empty = tf.zeros([0], dtype=candidates.dtype)

    def _MergeLeft():
      return tf.concat(
          [candidates[:best_id - 1],
           _MergeOneToken(tokens, best_id - 1)],
          axis=0)

    # There is no left neighbor when the merge happened at position 0.
    left_candidates = tf.cond(tf.equal(best_id, 0), lambda: empty, _MergeLeft)

    def _MergeRight():
      return tf.concat(
          [_MergeOneToken(tokens, best_id), candidates[best_id + 2:]], axis=0)

    # There is no right neighbor when the merged token is now the last one.
    right_candidates = tf.cond(
        tf.greater_equal(best_id,
                         tf.size(tokens) - 1), lambda: empty, _MergeRight)

    candidates = tf.concat([left_candidates, right_candidates], axis=0)
    return tokens, candidates

  # Only the tokens (element 0 of the loop variables) are returned.
  return tf.while_loop(
      _ShouldMerge,
      _MergeCandidates, (tokens, candidates),
      parallel_iterations=1,
      back_prop=False)[0]
def _StringToToken(self, tokstr):
  """Maps each string to its vocab id, or NO_TOKEN when it is not in vocab."""
  in_vocab = py_x_ops.token_in_vocab(tokstr, vocab=self._pieces)
  vocab_ids = py_x_ops.vocab_token_to_id(tokstr, vocab=self._pieces)
  missing = tf.broadcast_to(NO_TOKEN, tf.shape(tokstr))
  return tf.where(in_vocab, vocab_ids, missing)
def GatherK(selected_pos, values, k, num_devices=1):
  """Gather up to k elements from given tensors at selected pos under SPMD.

  Example::

    # Input
    k = 3

    selected_pos = [
        [0, 0, 1, 1],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],  # topk(k=3) largest indices are selected in this row.
    ]

    value_2d = [
        [1, 3, 5, 7],
        [9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31],
        [33, 35, 37, 39],
    ]

    # Output:
    output = [
        [0, 5, 7],
        [0, 11, 13],
        [0, 0, 0],
        [25, 27, 29],
        [35, 37, 39],
    ]

    # Output padding:
    output_padding = [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ]

  Args:
    selected_pos: a 0/1 2D tf.int32 tensor of shape [batch, time].
    values: a list of tensors, the rank of each is at least rank=2. [batch,
      time, ...]. The list passed by the caller is not modified.
    k: a scalar tf.int32 tensor or a Python int. On TPU, k must be a
      compile-time constant.
    num_devices: number of TPU devices used in xla_sharding SPMD.

  Returns:
    A tuple (output, padding).

    - output: a list of tensors of shape [batch, k, ...].
    - padding: a 2D 0/1 tensor of shape [batch, k], '1's are padded locations.
  """
  global_batch, seq_len = py_utils.GetShape(selected_pos, 2)
  if num_devices:
    device_batch = global_batch // num_devices
  else:
    device_batch = global_batch

  # Assert the first 2 dims of every value are [global_batch, seq_len]. The
  # checked tensors go into a new local list so that the caller's list is left
  # untouched.
  values = [
      py_utils.HasShape(v, [global_batch, seq_len], 2) for v in values
  ]
  # indices are 1-based for now, to distinguish between padding and selected
  # locations.
  indices = 1 + tf.range(tf.shape(values[0])[1], dtype=tf.int32)
  # [1, seq_len]
  indices = tf.expand_dims(indices, axis=0)

  # if 0, the position is not selected.
  # [1, seq_len] * [global_batch, seq_len] => [global_batch, t]
  # -- topk --> [global_batch, k]
  topk_indices, _ = tf.math.top_k(
      indices * tf.cast(selected_pos, indices.dtype), k)

  # [global_batch, k], sorted in ascending order.
  indices = tf.reverse(topk_indices, [-1])
  # [global_batch, k], padded positions are '1's.
  padding = tf.cast(tf.equal(indices, 0), values[0].dtype)
  padding = Split(padding, 0, num_devices)

  # [global_batch, k], zero_based_indices
  mp_idx = tf.maximum(0, indices - 1)
  mp_idx = Split(mp_idx, 0, num_devices)

  # [device_batch, k]
  if num_devices > 1 and py_utils.use_tpu():
    mp_idx = xla_sharding.auto_to_manual_spmd_partition(
        mp_idx, xla_sharding.get_op_sharding(mp_idx.op))
  # [device_batch, k, 1]
  mp_idx = tf.expand_dims(mp_idx, -1)

  # [device_batch]
  batch_ids = tf.range(device_batch, dtype=tf.int32)
  # [device_batch, 1, 1]
  batch_ids = tf.reshape(batch_ids, [device_batch, 1, 1])
  # [device_batch, k, 1]
  batch_ids = tf.broadcast_to(batch_ids, [device_batch, k, 1])

  # [device_batch, k, 2]
  final_indices = tf.concat([batch_ids, mp_idx], axis=-1)

  output = []
  for v in values:
    # Begin manually partition gather.
    v = Split(v, 0, num_devices)
    v_shape = v.shape.as_list()
    if num_devices > 1 and py_utils.use_tpu():
      op_sharding = xla_sharding.get_op_sharding(v.op)
      v = xla_sharding.auto_to_manual_spmd_partition(v, op_sharding)
    # Gathers k entries per row; the leading dim matches `final_indices`.
    v_out = tf.gather_nd(v, final_indices)

    if num_devices > 1 and py_utils.use_tpu():
      # `op_sharding` was set by the identical condition above.
      v_shape[1] = k
      v_out = xla_sharding.manual_to_auto_spmd_partition(
          v_out, op_sharding, full_shape=tf.TensorShape(v_shape))
    output.append(v_out)

  return output, padding
def _GetMask(self,
             batch_size,
             choose_range,
             mask_size,
             global_seed,
             max_length=None,
             masks_per_frame=0.0,
             multiplicity=1,
             dtype=tf.float32,
             max_ratio=1.0):
  """Returns fixed size multi-masks starting from random positions.

  A multi-mask is a mask obtained by applying multiple masks.

  This function when max_length is given:
    1) Sample random mask lengths less than max_length with shape
       (batch_size, multiplicity).
    2) Truncate lengths to a max of (choose_range * max_ratio),
       so that each mask is fully contained within the corresponding sequence.
    3) Random sample start points of shape (batch_size, multiplicity)
       with in (choose_range - lengths).
    4) For each batch, multiple masks (whose number is given by the
       multiplicity) are constructed.
    5) Return a mask of shape (batch_size, mask_size) where masks are
       obtained by composing the masks constructed in step 4).
       If masks_per_frame > 0, the number is given by
       min(masks_per_frame * choose_range, multiplicity).
       If not, all the masks are composed. The masked regions are set to zero.

  This function when max_length is not given:
    1) Sample random mask lengths less than (choose_range * max_ratio)
       with shape (batch_size, multiplicity).
    2) Proceed to steps 3), 4) and 5) of the above.

  Args:
    batch_size: Batch size. Integer number.
    choose_range: Range within which the masked entries must lie. Tensor of
      shape (batch_size,).
    mask_size: Size of the mask. Integer number.
    global_seed: an integer seed tensor for stateless random ops.
    max_length: Maximum number of allowed consecutive masked entries. Integer
      number or None.
    masks_per_frame: Number of masks per frame. Float number. If > 0, the
      multiplicity of the mask is set to be masks_per_frame * choose_range.
    multiplicity: Maximum number of total masks. Integer number.
    dtype: Data type.
    max_ratio: Maximum portion of the entire range allowed to be masked. Float
      number.

  Returns:
    mask: a fixed size multi-mask starting from a random position with shape
    (batch_size, mask_size). Entries are 0 inside masked regions and 1
    elsewhere; the mask is cast to p.fprop_dtype when that is set and differs
    from p.dtype.
  """
  p = self.params
  # Non-empty random seed values are only used for testing or when using
  # stateless random ops. seed_1 and seed_2 are set separately to avoid
  # correlation of mask size and mask position.
  if p.use_input_dependent_random_seed:
    seed_1 = global_seed + 1
    seed_2 = global_seed + 2
  elif p.random_seed:
    seed_1 = p.random_seed + 1
    seed_2 = 2 * p.random_seed
  else:
    # No seed configured: both ops receive the same falsy p.random_seed.
    seed_1 = p.random_seed
    seed_2 = p.random_seed
  # Sample lengths for multiple masks.
  if max_length and max_length > 0:
    max_length = tf.broadcast_to(tf.cast(max_length, dtype), (batch_size, ))
  else:
    max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
  random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
  masked_portion = random_uniform(
      shape=(batch_size, multiplicity),
      minval=0.0,
      maxval=1.0,
      dtype=dtype,
      seed=seed_1)
  masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
  masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
  # Make sure the sampled length was smaller than max_ratio * length_bound.
  # Note that sampling in this way was biased
  # (shorter sequence may over-masked.)
  choose_range = tf.expand_dims(choose_range, -1)
  choose_range = tf.tile(choose_range, [1, multiplicity])
  length_bound = tf.cast(choose_range, dtype=dtype)
  length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
  # The bound is floored at 1, so a sampled length of 1 is never truncated.
  length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))
  # Choose starting point.
  # NOTE(review): this sample is drawn without an explicit dtype, and `delta`
  # below is created without one, so the arithmetic presumably requires
  # `dtype` to be the default float type -- confirm.
  random_start = random_uniform(
      shape=(batch_size, multiplicity), maxval=1.0, seed=seed_2)
  start_with_in_valid_range = random_start * tf.cast(
      (choose_range - length + 1), dtype=dtype)
  start = tf.cast(start_with_in_valid_range, tf.int32)
  end = start + length - 1
  # Shift starting and end point by small value.
  # The strict comparisons below then include both integer end points.
  delta = tf.constant(0.1)
  start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
  start = tf.tile(start, [1, 1, mask_size])
  end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
  end = tf.tile(end, [1, 1, mask_size])
  # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
  diagonal = tf.expand_dims(
      tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
  diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
  pre_mask = tf.cast(
      tf.math.logical_and(diagonal < end, diagonal > start), dtype=dtype)
  # Sum masks with appropriate multiplicity.
  if masks_per_frame > 0:
    # Only masks whose index is below masks_per_frame * choose_range are
    # given weight 1; the remaining ones are dropped from the sum.
    multiplicity_weights = tf.tile(
        tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
        [batch_size, 1])
    multiplicity_tensor = masks_per_frame * tf.cast(choose_range, dtype=dtype)
    multiplicity_weights = tf.cast(
        multiplicity_weights < multiplicity_tensor, dtype=dtype)
    pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
  else:
    pre_mask = tf.reduce_sum(pre_mask, 1)
  # Positions covered by at least one mask become 0, all others 1.
  mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)
  if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
    mask = tf.cast(mask, p.fprop_dtype)
  return mask
def _ConstructWarpMatrix(self, batch_size, matrix_size, origin, destination,
                         choose_range, dtype):
  """Builds a batch of piecewise-linear warp matrices.

  Each matrix realizes the warp that keeps the points 0 and `choose_range`
  fixed and sends the anchor point `origin` to `destination`. To keep the
  matrix non-singular, `destination` is first clamped into
  1 <= destination <= choose_range - 1.

  The warp is described by three slopes:
    1) slope_0 = origin / destination.
    2) slope_1 = (choose_range - origin) / (choose_range - destination).
    3) slope_2 = 1.0.

  The origin point orig_i of the mapped coordinate i is:
    1) i < destination: orig_i = slope_0 * i.
    2) destination <= i < choose_range:
       orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
    3) i >= choose_range: orig_i = i.

  With n_i = ceil(orig_i), the matrix element warp[i][j] is:
    1) j = n_i: 1 - n_i + orig_i.
    2) j = n_i - 1: n_i - orig_i.
    3) Otherwise: 0.

  So applying the matrix to an array of pixels, i.e.
  warped_pixel[i] = sum_j warp[i][j] * pixel[j], gives
  warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

  Args:
    batch_size: Batch size. Integer number.
    matrix_size: Dimension of the vector space the warp matrix is applied to.
      Integer number.
    origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
      data type dtype.
    destination: Destination of the origin anchor point upon warping. Tensor
      of shape (batch_size,) and data type dtype.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,) data type dtype.
    dtype: Data type of origin, destination, choose_range and the output warp
      matrix.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size), cast to p.fprop_dtype when that is
    set and differs from `dtype`.
  """

  def _RepeatPerColumn(per_example):
    # (batch_size,) -> (batch_size, matrix_size), constant along each row.
    return tf.transpose(
        tf.broadcast_to(per_example, (matrix_size, batch_size)))

  # Clamp destination into [1, choose_range - 1] to avoid singular matrices.
  destination = tf.minimum(tf.maximum(destination, 1.0), choose_range - 1.0)
  destination_grid = _RepeatPerColumn(destination)
  choose_range_grid = _RepeatPerColumn(choose_range)

  # Slopes of the three linear pieces.
  slope_0 = origin / destination
  slope_1 = (choose_range - origin) / (choose_range - destination)
  slope_2 = 1.0

  # The piecewise-linear map is a sum of ramps: the base slope plus a change
  # of slope at `destination` and another one at `choose_range`.
  coords = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype), (batch_size, matrix_size))
  base_term = self.EinsumBBmBm(slope_0, coords)
  destination_term = self.EinsumBBmBm(slope_1 - slope_0,
                                      tf.nn.relu(coords - destination_grid))
  range_term = self.EinsumBBmBm(slope_2 - slope_1,
                                tf.nn.relu(coords - choose_range_grid))
  origin_coords = base_term + destination_term + range_term

  # origin_grid[b][i][j] = origin coordinate of i (same for every j).
  origin_grid = tf.transpose(
      tf.broadcast_to(origin_coords, (matrix_size, batch_size, matrix_size)),
      perm=[1, 2, 0])
  # column_grid[b][i][j] = j.
  column_grid = tf.broadcast_to(
      tf.cast(tf.range(matrix_size), dtype=dtype),
      (batch_size, matrix_size, matrix_size))

  # The hat function applied element-wise to the offsets yields the linear
  # interpolation weights listed in the docstring.
  warp_matrix = _hat(origin_grid - column_grid)

  p = self.params
  if p.fprop_dtype is not None and p.fprop_dtype != dtype:
    warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)
  return warp_matrix
def _AddNoise(self, batch):
  """Corrupts the source side of `batch` for denoising training.

  Follows the noise model of https://arxiv.org/pdf/1711.00043. The
  hyper-parameters are read from `self.params.denoise` (called `p` below):
    1) tokens are slightly shuffled, bounded by p.shuffle_tok_range;
    2) tokens are dropped with probability p.drop_tok_prob;
    3) tokens are replaced by p.blank_id with probability p.blank_tok_prob.
  A sentence takes its noised version with probability p.noise_sent_prob,
  and only when its task id is listed in p.task_ids; every other sentence
  keeps its original ids and paddings.

  `batch.src.ids`, `batch.src.paddings`, `batch.src.ids_indicator` and
  `batch.src.weights` are overwritten in place. Nothing is returned.

  Args:
    batch: a `.NestedMap` of the input batch.
  """

  def _InTaskSet(task_ids, special_task_ids):
    """Marks the examples whose task id is one of `special_task_ids`.

    Args:
      task_ids: Task ids for the input batch. Tensor of shape [batch].
      special_task_ids: A list of specified task ids.

    Returns:
      A boolean tensor of size [batch], true where the example's task id
      equals at least one entry of `special_task_ids`.
    """
    num_examples = py_utils.GetShape(task_ids)[0]
    ids_column = tf.expand_dims(task_ids, -1)
    candidates = tf.cast(
        tf.broadcast_to(special_task_ids,
                        [num_examples, len(special_task_ids)]), tf.int32)
    # Compare every example against every candidate, then reduce over the
    # candidate axis.
    return tf.reduce_any(tf.equal(ids_column, candidates), -1)

  def _NextPaddings(paddings):
    """Shifts paddings one step left in time; the freed last column is 1."""
    return tf.pad(paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)

  def _GatherRow(ids_and_order):
    """Reorders one row of ids (element 0) by its permutation (element 1)."""
    return tf.gather(ids_and_order[0], ids_and_order[1])

  p = self.params.denoise
  src = batch.src
  num_sents = tf.shape(src.ids)[0]
  max_len = tf.shape(src.ids)[1]
  grid = [num_sents, max_len]

  # --- 1) Shuffle ---------------------------------------------------------
  # Each position gets a random offset in [0, p.shuffle_tok_range + 1);
  # sorting "position + offset" yields a bounded local permutation.
  jitter = tf.random.uniform(grid, 0, p.shuffle_tok_range + 1)
  # Eos and padding must not move: a position whose successor is padded (or
  # that is the last column) receives the constant offset instead of a
  # random one.
  fixed_jitter = tf.fill(grid, float(p.shuffle_tok_range))
  next_paddings = _NextPaddings(src.paddings)
  jitter = tf.where(tf.equal(next_paddings, 0), jitter, fixed_jitter)
  positions = tf.broadcast_to(tf.range(max_len, dtype=tf.int32), grid)
  order = tf.argsort(tf.cast(positions, dtype=tf.float32) + jitter)
  # Pair every id row with its permutation so map_fn can gather row by row.
  ids_and_order = tf.stack([src.ids, order], axis=1)
  denoise_src_ids = tf.stack(tf.map_fn(_GatherRow, ids_and_order), axis=0)

  # --- 2) Drop ------------------------------------------------------------
  drop_draw = tf.random.uniform(grid)
  survives = tf.greater(drop_draw, p.drop_tok_prob)
  # The eos token is always kept, whatever the draw says.
  is_eos = tf.equal(denoise_src_ids, self._src_tokenizer.eos_id)
  keep = tf.math.logical_or(survives, is_eos)
  # Kept entries are packed to the left; the tail is refilled so the result
  # has the original shape (id 0 for ids, 1 for paddings).
  kept_ids = tf.ragged.boolean_mask(denoise_src_ids, keep)
  denoise_src_ids = kept_ids.to_tensor(
      default_value=0, shape=tf.shape(src.ids))
  kept_paddings = tf.ragged.boolean_mask(src.paddings, keep)
  denoise_src_paddings = kept_paddings.to_tensor(
      default_value=1, shape=tf.shape(src.ids))

  # --- 3) Blank -----------------------------------------------------------
  blank_draw = tf.random.uniform(grid)
  # Only positions followed by a non-padded position may be blanked, which
  # leaves the eos token and the padding untouched.
  next_paddings = _NextPaddings(denoise_src_paddings)
  blank = tf.math.logical_and(
      tf.less(blank_draw, p.blank_tok_prob), tf.equal(next_paddings, 0))
  blank_ids = tf.fill(grid, p.blank_id)
  denoise_src_ids = tf.where(blank, blank_ids, denoise_src_ids)

  # --- Per-sentence selection ---------------------------------------------
  # A sentence is noised with probability p.noise_sent_prob, provided its
  # task id belongs to p.task_ids.
  sent_draw = tf.random.uniform([num_sents])
  is_picked = tf.less(sent_draw, p.noise_sent_prob)
  task_ids = self._GetTaskIds(src.source_ids[:, 0])
  apply_noise = tf.math.logical_and(is_picked,
                                    _InTaskSet(task_ids, p.task_ids))

  src.ids = tf.where(apply_noise, denoise_src_ids, src.ids)
  src.paddings = tf.where(apply_noise, denoise_src_paddings, src.paddings)
  # Indicator and weights are recomputed from the (possibly new) paddings.
  src.ids_indicator = 1 - src.paddings
  src.weights = src.ids_indicator