def _KeepTopP(sorted_log_probs, p):
  """Keeps the top-p probability mass of `sorted_log_probs`.

  For each row, elements that are not included in the first `p` probability
  mass are set to `LARGE_NEGATIVE_NUMBER`. The first element is always kept
  as-is.

  Args:
    sorted_log_probs: A float tensor of shape [batch, k] that represents
      log-probabilities sorted in descending order. The probabilities do not
      need to sum to 1.
    p: A float tensor of shape [batch] that represents a probability threshold
      for each batch item.

  Returns:
    A tensor like `sorted_log_probs` where elements outside the top-p
    probability mass are set to `LARGE_NEGATIVE_NUMBER`.
  """
  sorted_cum_probs = tf.math.cumsum(
      tf.exp(sorted_log_probs), exclusive=True, axis=-1)
  mask = tf.less(sorted_cum_probs, tf.expand_dims(p, axis=1))
  # Set mask[:, 0] = True to always keep the first element.
  batch_size = tf.shape(mask)[0]
  true = tf.ones([batch_size, 1], dtype=tf.bool)
  mask = tf.concat([true, mask[:, 1:]], axis=1)
  filtered_sorted_log_probs = tf.where(
      mask, sorted_log_probs,
      tf.fill(
          tf.shape(sorted_log_probs),
          tf.constant(LARGE_NEGATIVE_NUMBER, dtype=sorted_log_probs.dtype)))
  return filtered_sorted_log_probs
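
# A minimal sketch of how the filter above behaves, assuming
# LARGE_NEGATIVE_NUMBER is a module-level constant (e.g. -1e9); the inputs
# here are illustrative only.
import tensorflow as tf

LARGE_NEGATIVE_NUMBER = -1e9  # Assumed constant.

# One row of log-probs sorted in descending order: probs = [0.5, 0.3, 0.2].
sorted_log_probs = tf.math.log(tf.constant([[0.5, 0.3, 0.2]]))
p = tf.constant([0.6])
filtered = _KeepTopP(sorted_log_probs, p)
# The exclusive cumsum of probs is [0.0, 0.5, 0.8], so the first two entries
# stay and the third is replaced by LARGE_NEGATIVE_NUMBER.
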
def BuildDataSource(self, data_source_from_file_pattern_fn):
  """Read and return input batch.

  Args:
    data_source_from_file_pattern_fn: a function to read and return input
      batch from a string file_pattern

  Returns:
    A NestedMap containing:
      data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor

  Raises:
    ValueError: inconsistent sizes between boundaries and datasource_params,
      specification of unsupported datasources, or out of order boundaries.
  """
  p = self.params

  if len(p.datasource_params) != len(p.boundaries) + 1:
    raise ValueError(
        'Expected p.datasource_params to have one more entry than '
        'p.boundaries. Found %d datasource_params, and %d boundaries' %
        (len(p.datasource_params), len(p.boundaries)))

  for ds_p in p.datasource_params:
    if 'bprop_variable_filters' in ds_p:
      if any(filter for filter in ds_p.bprop_variable_filters):
        raise ValueError('CurriculumDataSource does not support distinct '
                         'bprop_variable_filters per stage.')

  for idx in range(len(p.boundaries) - 1):
    if p.boundaries[idx] > p.boundaries[idx + 1]:
      raise ValueError('Expected p.boundaries to monotonically increase, but '
                       'found %d > %d at position %d' %
                       (p.boundaries[idx], p.boundaries[idx + 1], idx))

  global_step = py_utils.GetGlobalStep()

  datasources = [ds_p.Instantiate() for ds_p in p.datasource_params]

  def GetDatasourceFn(idx):

    def DatasourceFn():
      datasource = datasources[idx].BuildDataSource(
          data_source_from_file_pattern_fn)
      datasource.pop('bprop_variable_filters', None)
      return datasource

    return DatasourceFn

  cases = []
  for idx in range(len(p.boundaries)):
    cases.append(
        (tf.less(global_step,
                 tf.constant(p.boundaries[idx], dtype=global_step.dtype)),
         GetDatasourceFn(idx)))

  ret = tf.case(cases, default=GetDatasourceFn(-1))
  ret.bprop_variable_filters = p.bprop_variable_filters
  return ret
def _Notvisible(seg_id, seg_pos):
  """Returns a float mask where entry [.., i, j] is 1.0 iff j is NOT visible
  from i: a future position within the same segment (causal masking), a
  position in a different segment, or a padded position (seg_id == 0)."""
  a, b = tf.expand_dims(seg_id, -1), tf.expand_dims(seg_id, -2)
  return tf.cast(
      tf.math.logical_or(
          tf.less(tf.expand_dims(seg_pos, -1), tf.expand_dims(seg_pos, -2)),
          tf.math.logical_or(
              tf.not_equal(a, b),
              tf.math.logical_not(
                  tf.math.logical_or(
                      tf.cast(a, tf.bool), tf.cast(b, tf.bool))))),
      tf.float32)
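
# A small sketch of the mask above on a packed row, assuming seg_id 0 marks
# padding (illustrative values only).
import tensorflow as tf

seg_id = tf.constant([[1, 1, 2, 0]])   # two sequences packed, then padding
seg_pos = tf.constant([[0, 1, 0, 0]])  # positions within each sequence
not_visible = _Notvisible(seg_id, seg_pos)
# not_visible[0, i, j] == 1.0 when key j is hidden from query i: future
# positions within a segment (causal), other segments, and padding.
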
def GetNext(self):
  p = self.params
  global_step = py_utils.GetGlobalStep()
  cases = []
  for idx in range(len(p.boundaries)):
    cases.append(
        (tf.less(global_step,
                 tf.constant(p.boundaries[idx], dtype=global_step.dtype)),
         self.sub[idx].GetNext))
  return tf.case(cases, default=self.sub[-1].GetNext)
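
# Both curriculum snippets above use the same tf.case-over-boundaries
# dispatch. A standalone toy version of that pattern (eager TF2, fake step):
import tensorflow as tf

boundaries = [100, 200]
stages = [
    lambda: tf.constant('stage0'),
    lambda: tf.constant('stage1'),
    lambda: tf.constant('stage2'),
]

def PickStage(global_step):
  cases = [(tf.less(global_step, b), stages[i])
           for i, b in enumerate(boundaries)]
  # tf.case picks the first true predicate; the default covers steps past
  # the last boundary.
  return tf.case(cases, default=stages[-1])

# PickStage(tf.constant(50))  -> b'stage0'
# PickStage(tf.constant(150)) -> b'stage1'
# PickStage(tf.constant(250)) -> b'stage2'
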
def _Value(self, current_step):
  """Returns the current clipping cap."""
  p = self.params
  start_step = tf.cast(p.start_step, tf.float32)
  end_step = tf.cast(p.end_step, tf.float32)
  current_step = tf.cast(current_step, tf.float32)
  steps_ratio = (
      tf.minimum(end_step - start_step, current_step - start_step) /
      (end_step - start_step))
  rmax_tensor = (
      steps_ratio * p.end_cap + (1.0 - steps_ratio) * p.start_cap)
  return tf.cond(
      tf.less(current_step, p.start_step),
      lambda: tf.cast(p.start_cap, tf.float32),
      lambda: tf.cast(rmax_tensor, tf.float32))
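
# A standalone version of the same linear-ramp schedule, for illustration
# (the method above reads its endpoints from self.params):
import tensorflow as tf

def LinearCapSchedule(step, start_step, end_step, start_cap, end_cap):
  """start_cap until start_step, linear ramp to end_cap, then flat."""
  step = tf.cast(step, tf.float32)
  ratio = tf.minimum(end_step - start_step, step - start_step) / (
      end_step - start_step)
  ramp = ratio * end_cap + (1.0 - ratio) * start_cap
  return tf.where(step < start_step, start_cap, ramp)

# LinearCapSchedule(0., 100., 200., 1., 5.)   -> 1.0
# LinearCapSchedule(150., 100., 200., 1., 5.) -> 3.0
# LinearCapSchedule(300., 100., 200., 1., 5.) -> 5.0
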
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(
          tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
          tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))
  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
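
# A tiny worked example of the length-aware concat above, with pad token 0:
import tensorflow as tf

x = tf.constant([[1, 2, 0]])             # real length 2
x_paddings = tf.constant([[0., 0., 1.]])
y = tf.constant([[3, 0, 0]])             # real length 1
y_paddings = tf.constant([[0., 1., 1.]])

xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
# xy          -> [[1, 2, 3, 0, 0, 0]]
# xy_paddings -> [[0., 0., 0., 1., 1., 1.]]
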
def maybe_update_masks():
  with tf.name_scope(self._spec.name):
    is_step_within_pruning_range = tf.logical_and(
        tf.greater_equal(self._global_step, self._spec.begin_pruning_step),
        # If end_pruning_step is negative, keep pruning forever!
        tf.logical_or(
            tf.less_equal(self._global_step, self._spec.end_pruning_step),
            tf.less(self._spec.end_pruning_step, 0)))
    is_pruning_step = tf.less_equal(
        tf.add(self._last_update_step, self._spec.pruning_frequency),
        self._global_step)
    return tf.logical_and(is_step_within_pruning_range, is_pruning_step)
def _GetFurthestPoint():
  """Get point that is furthest from those already selected.

  We also bias the sampling towards real points by setting the distance to
  padded points negative until we are out of real points.
  """
  # Set padded points distance to negative so they aren't selected.
  padding_masked_distance_to_selected = tf.where(
      tf.equal(padding, 0.0), distance_to_selected,
      -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
  # But only do this when we still have valid points left.
  padding_masked_distance_to_selected = tf.where(
      tf.less(curr_idx, num_valid_points),
      padding_masked_distance_to_selected, distance_to_selected)
  return tf.argmax(
      padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)
def _Extract(self, features):
  """Returns the laser Tensor."""
  p = self.params
  ret = super()._Extract(features)

  all_vxyz = []
  all_classes = []
  for lidar in p.lidar_names:
    for ri in p.lidar_returns:
      feature_name = 'laser_%s_%s' % (lidar, ri)
      laser_data = tf.reshape(
          _Dense(features[feature_name]), [-1, 3 + p.num_features])
      num = py_utils.GetShape(laser_data)[0]

      # We expect lidar_$lidar_$ri and lidar_$lidar_$ri_flow to have the
      # same number of points.
      feature_name += '_flow'
      laser_data = tf.reshape(_Dense(features[feature_name]), [num, 3 + 1])
      points_vxyz = laser_data[..., 0:3]
      points_classes = laser_data[..., 3]

      all_vxyz += [points_vxyz]
      all_classes += [points_classes]

  # Stack all of the points along the major dimension
  points_vxyz = tf.concat(all_vxyz, axis=0)
  points_class = tf.concat(all_classes, axis=0)

  # The precomputed class uses -1 to mean 5 in our current code.
  points_class = tf.where(
      tf.less(points_class, 0), 5. * tf.ones_like(points_class),
      points_class)

  if p.max_num_points is not None:
    assert 'points_padding' in ret
    points_vxyz = py_utils.PadOrTrimTo(points_vxyz, [p.max_num_points, 3])
    points_class = py_utils.PadOrTrimTo(points_class, [p.max_num_points])

  assert 'points_xyz' in ret
  ret.world_flow = points_vxyz
  ret.pointwise_class = tf.cast(points_class, tf.int32)
  return ret
def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
  """Loop body for farthest point sampler."""

  def _GetRandomRealPoint():
    """Select the first point.

    For the first point, we want any random real (non padded) point, so we
    create random values per point, and then set all padded ones to some
    large value (more than the maxval). We then take the min per batch
    element to get the first points.

    Returns:
      Tensor containing the index of a random point selected for each example
      in the batch.
    """
    random_values = tf.random.uniform((batch_size, num_points),
                                      minval=0,
                                      maxval=1,
                                      dtype=tf.float32,
                                      seed=random_seed)
    random_values = tf.where(
        tf.equal(padding, 0.0), random_values, padding * 10)
    return tf.argmin(random_values, axis=1, output_type=tf.int32)

  def _GetFurthestPoint():
    """Get point that is furthest from those already selected.

    We also bias the sampling towards real points by setting the distance to
    padded points negative until we are out of real points.

    Returns:
      Tensor containing the index of the next farthest point selected for
      each example in the batch.
    """
    # Set padded points distance to negative so they aren't selected.
    padding_masked_distance_to_selected = tf.where(
        tf.equal(padding, 0.0), distance_to_selected,
        -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
    # But only do this when we still have valid points left.
    padding_masked_distance_to_selected = tf.where(
        tf.less(curr_idx, num_valid_points),
        padding_masked_distance_to_selected, distance_to_selected)
    return tf.argmax(
        padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

  def _GetSeededPoint():
    """Select a seeded point.

    Seeded points are assumed to be at the beginning of the original points.

    Returns:
      Tensor containing the index of the next seeded point to select for each
      example in the batch.
    """
    return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

  # Select indices for this loop iteration.
  def _Seeded():
    return tf.cond(
        tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
        _GetFurthestPoint)

  def _Real():
    return tf.cond(
        tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

  new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
  sampled_idx = sampled_idx.write(curr_idx, new_selected)

  # Extract the distance to the latest point selected to update
  # distance_to_selected.
  new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                     axis=1)
  if precomputed_squared_distance is not None:
    new_distance = tf.gather_nd(precomputed_squared_distance,
                                new_selected_gather_idx)
  else:
    new_points = tf.reshape(
        tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
    new_distance = tf.reshape(
        SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

  is_newly_closest = tf.less(new_distance, distance_to_selected)
  distance_to_selected = tf.minimum(distance_to_selected, new_distance)

  # Track the index to the closest selected point.
  new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
  closest_idx = tf.cond(
      tf.equal(curr_idx, 0),
      # At the first loop iteration, the init points are the closest.
      lambda: new_selected_tiled,
      # Otherwise, update with the new points based on the distances.
      lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
  return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding to a selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add furthest point to selected.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point
      is real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      distances between each point. If None, distances will be computed on
      the fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. This will have range of [0, P1].

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part
    of the pooling operation: each point is assigned to the closest sampled
    point and a max is taken over them. This will have a range of [0, P2]
    with the index of the closest sampled point that remains.
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or points so all
  # points are unique. This will also ensure true repeated points
  # like padded points are only selected after all valid points are selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get number of valid points (1 is padded, so num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler."""

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create random values per point, and then set all padded ones to some
      large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each
        example in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for
        each example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected,
          -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1,
          output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original
      points.

      Returns:
        Tensor containing the index of the next seeded point to select for
        each example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx),
          [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points),
          [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # num_sampled_points x n
  sampled_idx = tf.transpose(sampled_idx, [1, 0])

  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
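
# A minimal usage sketch for the sampler above, with no padded points; the
# shapes are illustrative.
import tensorflow as tf

points = tf.random.uniform((2, 128, 3))  # [batch, num_points, dims]
padding = tf.zeros((2, 128))             # all points are real
sampled_idx, closest_idx = FarthestPointSampler(
    points, padding, num_sampled_points=16)
# Gather the sampled coordinates; tf.gather with batch_dims=1 matches the
# tf.batch_gather call mentioned in the docstring.
sampled_points = tf.gather(points, sampled_idx, batch_dims=1)
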
def ProcessFeatures(self, features):
  """Process extracted features.

  Args:
    features: A dict of extracted Tensors from the records.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        # Filter out extracted features from other extractors.
        filtered_features = {}
        if self.params.record_type == 'TEXT':
          # Text extractors only produce {'line': record} and their
          # FeatureMap() is empty, so don't do any filtering.
          filtered_features = features
        else:
          filtered_keys = e.FeatureMap().keys() | e.ContextMap().keys()
          filtered_features = {
              k: v for k, v in features.items() if k in filtered_keys
          }
        try:
          if self.params.batched_input:
            extracted = e.ExtractBatch(filtered_features)
          else:
            extracted = e.Extract(filtered_features)
        except Exception as exc:  # pylint:disable=bare-except
          # Raise exception with context about which extractor failed.
          raise RuntimeError('Failed running extractor '
                             f'{e.params.name}. '
                             'See above exception for details.') from exc
      with tf.name_scope('filter'):
        if self.params.batched_input:
          bucket = e.FilterBatch(extracted)
        else:
          bucket = e.Filter(extracted)
      return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the output of both branches yields the
    same structure.

    Returns:
      A structure with the same Tensor dtype as the output of Preprocess.
    """
    shapes = self.Shape()
    rets = []
    for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
      if shape.is_fully_defined():
        rets += [tf.zeros(dtype=dtype, shape=shape)]
      else:
        rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        if self.params.batched_input:
          extracted = preprocessor.TransformBatchedFeatures(extracted)
        else:
          extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  #
  # Note that the NullLike branch may return tensors with shapes different
  # from self.Shape().
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
def ComputePredictions(self, encoder_outputs, pronunciations,
                       is_inference=False):
  """Computes the predictions from the encoder_outputs, updating losses.

  Despite the name, this function does the bulk of the decoding and loss
  computation, incrementing the loss at each time step.

  Args:
    encoder_outputs: a NestedMap consisting of outputs of the
      FeatureNeighborhoodEncoder with
        encoded - encoding of the input spelling
        neighbor_pronunciations_encoded - encodings of the neighbor prons
        neighbor_spellings_encoded - encodings of the neighbor spellings
        state - encoder state to which has been added dec_input - seed output
          for the decoder [*, 1] tensor consisting of sentence start indices
          (corresponding to "<s>")
    pronunciations: NestedMap with pronunciations -
      [*, max_pronunciation_len] tensor of pronunciations
    is_inference: If False then uses teacher forcing else does autoregression.

  Returns:
    NestedMap with loss, per_sequence_losses, labels, a
    [*, max_pronunciation_len] tensor of predictions, and attention
    ([*, max_pronunciation_len, max_spelling_len]), and neighbor_attention
    ([*, max_pronunciation_len, max_neighbors]) tensors, along with the raw
    batch passed through from the encoder.
  """
  p = self.params
  targets = pronunciations.pronunciations
  t_len = int(targets.get_shape().as_list()[1])
  t_idx = tf.constant(0)
  attention = tf.TensorArray(dtype=tf.float32, size=t_len)
  neighbor_attention = tf.TensorArray(dtype=tf.float32, size=t_len)

  outputs = tf.TensorArray(dtype=tf.float32, size=t_len)

  loop_cond = lambda t_idx, ts, *_: tf.less(t_idx, t_len)

  dec_input = tf.convert_to_tensor([p.start] * p.input.batch_size)
  state = encoder_outputs.state

  # pylint: disable=missing-docstring
  def loop_body(t_idx, dec_input, attention, neighbor_attention, state,
                outputs):
    decoder_result = self.Decode(encoder_outputs, dec_input, state)

    outputs = outputs.write(t_idx, decoder_result.predictions)
    attention = attention.write(t_idx, decoder_result.attention_weights)
    neighbor_attention = neighbor_attention.write(
        t_idx,
        tf.cast(decoder_result.neighbor_attention_weights,
                dtype=tf.float32))

    if is_inference:
      dec_input = tf.cast(tf.argmax(decoder_result.predictions, 1), tf.int32)
    else:
      dec_input = targets[:, t_idx]
    t_idx = t_idx + 1
    state = decoder_result.state
    return t_idx, dec_input, attention, neighbor_attention, state, outputs

  _, _, attention, neighbor_attention, state, outputs = tf.while_loop(
      loop_cond,
      loop_body,
      loop_vars=[
          t_idx, dec_input, attention, neighbor_attention, state, outputs
      ])

  outputs = tf.transpose(outputs.stack(), [1, 0, 2])
  labels = tf.argmax(outputs, axis=-1)
  mask = tf.cast(
      tf.math.logical_not(tf.math.equal(targets, 0)), dtype=tf.float32)
  loss = self._loss_object(targets, outputs, sample_weight=mask)
  loss = tf.reduce_sum(loss, axis=1)
  per_sequence_losses = (loss / t_len)
  loss = tf.reduce_mean(per_sequence_losses)
  predictions = py_utils.NestedMap()
  predictions.loss = loss
  predictions.per_sequence_losses = per_sequence_losses
  predictions.labels = labels
  predictions.attention = tf.transpose(
      tf.squeeze(attention.stack()), perm=[1, 0, 2])
  if p.use_neighbors:
    predictions.neighbor_attention = tf.transpose(
        tf.squeeze(neighbor_attention.stack()), perm=[1, 0, 2])
  else:
    predictions.neighbor_attention = tf.squeeze(neighbor_attention.stack())
  # Expose this for subsequent data analysis
  predictions.batch = encoder_outputs.batch
  return predictions
def ExtractUsingExtractors(self, record):
  """Extracts Tensors from a tf.Example record using self.extractors.

  Args:
    record: A tf.Example input to pass to tf.parse_single_example.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """
  feature_map = {}
  context_map = {}
  self._extractors.Transform(lambda e: feature_map.update(e.FeatureMap()))
  if self.params.record_type == 'SEQUENCE_EXAMPLE':
    self._extractors.Transform(lambda e: context_map.update(e.ContextMap()))

  if self.params.record_type not in _PARSING_FUNCTIONS:
    raise ValueError('Invalid record_type: {}'.format(
        self.params.record_type))

  parsing_fn = _PARSING_FUNCTIONS[self.params.record_type]
  if self.params.record_type == 'SEQUENCE_EXAMPLE':
    features = parsing_fn(record, feature_map, context_map)
  else:
    features = parsing_fn(record, feature_map)

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        extracted = e.Extract(features)
      with tf.name_scope('filter'):
        bucket = e.Filter(extracted)
      return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the output of both branches yields the
    same structure.

    Returns:
      A structure with the same Tensor dtype and shape as the output of
      Preprocess.
    """
    shapes = self.Shape()
    rets = [
        tf.zeros(dtype=dtype, shape=shape)
        for (dtype, shape) in zip(self.DType().Flatten(), shapes.Flatten())
    ]
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                    prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                          dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # Extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # Then top beam_size extensions are removed from the buffer and used
    # in the next decoder iteration. And top 'ext_size' remaining extensions
    # are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                          dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len
    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    # [ P P P P P P - - - - - - P* - - - ]   ^
    # [ P P P P P P P P P P - - P* - - - ]   | batch
    # [ P - - - - - - - - - - - P* - - - ]   V
    # |<---- prefix len ---->  |<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.
    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
    pfx_time = tf.range(pfx_max)
    pfx_pad = tf.cast(
        tf.less(tf.expand_dims(pfx_time, 0), tf.expand_dims(pfx_len - 1, 1)),
        tf.int32)
    pfx_id = pfx * pfx_pad
    pfx_last = einsum_i32(
        'BT,BT->B', pfx, tf.one_hot(pfx_len - 1, pfx_max, dtype=fprop_dtype))

    buf_time = tf.range(buf_size)
    pfx_time_mask = tf.cast(
        tf.less_equal(tf.expand_dims(buf_time, 0),
                      tf.expand_dims(pfx_time, 1)), fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_time * pfx_pad

    if debug:
      tpu_summary.tensor('pfx_id', pfx_id)
      tpu_summary.tensor('pfx_len', pfx_len)
      tpu_summary.tensor('pfx_pos', pfx_pos)
      tpu_summary.tensor('pfx_last', pfx_last)

    # Now call decoder with prefix minus P*:
    # 'dec_state' now shall contain the key/value cache for prefix tokens
    # (for transformer models), and 'logits' we can either discard or
    # roll into the initial hyp_score. Discard is simpler.
    with tf.name_scope('prefix_fprop'):
      # TODO(krikun): remove extra type checks
      assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
      assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
      assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
      assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
      assert (t.dtype == tf.int32), (t.dtype)
      logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                       pfx_mask, dec_state, t)
      del logits

    # Now construct the initial state for the rest of the beam search loop.
    # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
    # 'tgt_pos' is different for each batch row and is equal to prefix_len
    # 'tgt_segment_id' always 1 (no packing)
    # 'hyp_score' is 0 for beam=0 and negative for beam>=1
    tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        pfx_last, 1)
    tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
        (pfx_len - 1), 1)
    hyp_score = tf.zeros(
        [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
            tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

    # TODO(krikun) Here we make initial 't' constant and determined by the
    # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
    # as t ~ max(pfx_len) / beam_size; this would allow more steps for beam
    # search, however 'max' results in a very slow all-to-all on 16x16, and
    # a variable number of decoder steps may result in bad latency.
    t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

    # Initial tgt_mask is such that each token P* has attention on itself
    # (as usual) and on all prefix tokens before it, which are not padding.
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.cast(
        tf.expand_dims(
            tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
        fprop_dtype)
    tgt_mask += tf.one_hot(
        tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    if debug:
      tpu_summary.tensor('tgt_id', tgt_id)
      tpu_summary.tensor('tgt_pos', tgt_pos)
      tpu_summary.tensor('tgt_mask', tgt_mask)
      tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0]
    # and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

  hist = (h_tgt_id, h_tgt_pos)
  tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
               hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                     tgt_mask, dec_state, t)
    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)

    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    nbest_hyps = update_nbest(nbest_hyps,
                              (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)

    top_score += tf.expand_dims(hyp_score, -1)
    top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

    top_score = tf.reshape(top_score, [batch_size, beam_size * k])
    top_id = tf.reshape(top_id, [batch_size, beam_size * k])
    top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

    tgt_pos = tf.reduce_sum(tgt_mask, -1)
    tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
   hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
               hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])

  def unmask(h, m):
    with tf.name_scope('unmask'):
      tpu_summary.tensor('unmask_h', h)
      tpu_summary.tensor('unmask_m', m)
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)
  topk_score = nbest_score_norm

  nbest = (topk_ids, topk_len, topk_score)

  return loop_vars, dec_state, nbest
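
# The dec_callback contract is implied by the call sites above: it consumes
# (tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t) and returns
# (logits, dec_state). A shape-only stub (not a real model) might look like:
import tensorflow as tf

VOCAB_SIZE = 32  # Illustrative only.

def dummy_dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state,
                       t):
  del tgt_segment_id, tgt_pos, tgt_mask, t  # Unused by this stub.
  # [batch, beam] ids -> [batch, beam, vocab] logits (zeros here).
  logits = tf.zeros(tf.concat([tf.shape(tgt_id), [VOCAB_SIZE]], axis=0))
  return logits, dec_state
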
def _AddNoise(self, batch):
  """Adds noise to the src (see https://arxiv.org/pdf/1711.00043).

  This function implements 3 types of noise (hyperparams defined in
  self.params.denoise):

  1) slightly shuffle the sentence following p.shuffle_tok_range
  2) randomly drop tokens with probability p.drop_tok_prob
  3) randomly mask tokens with probability p.blank_tok_prob

  The noises are added to the input with probability p.noise_sent_prob.

  Args:
    batch: a `.NestedMap` of the input batch.
  """

  def IsSpecialExample(task_ids, special_task_ids):
    """A utility function indicating whether inputs belong to specific tasks.

    Args:
      task_ids: Task ids for the input batch. Tensor of shape [batch].
      special_task_ids: A list of specified task ids.

    Returns:
      A tensor of shape [batch] indicating whether each sample in the batch
      belongs to the specified tasks.
    """
    batch_size = py_utils.GetShape(task_ids)[0]
    return tf.reduce_any(
        tf.equal(
            tf.expand_dims(task_ids, -1),
            tf.cast(
                tf.broadcast_to(special_task_ids,
                                [batch_size, len(special_task_ids)]),
                tf.int32)), -1)

  p = self.params.denoise
  batch_size = tf.shape(batch.src.ids)[0]
  source_max_len = tf.shape(batch.src.ids)[1]

  # Shuffle tokens according to p.shuffle_tok_range
  noise = tf.random.uniform([batch_size, source_max_len], 0,
                            p.shuffle_tok_range + 1)

  # Don't shuffle eos or padding
  shuffle_tok_range = tf.fill([batch_size, source_max_len],
                              float(p.shuffle_tok_range))
  shifted_paddings = tf.pad(
      batch.src.paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  noise = tf.where(tf.equal(shifted_paddings, 0), noise, shuffle_tok_range)
  indices = tf.broadcast_to(
      tf.range(source_max_len, dtype=tf.int32),
      [batch_size, source_max_len])
  noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
  permutations = tf.argsort(noisy_indices)
  stacked = tf.stack([batch.src.ids, permutations], axis=1)
  denoise_src_ids = tf.stack(
      tf.map_fn(lambda x: tf.gather(x[0], x[1]), stacked), axis=0)

  # Select tokens to drop with probability=p.drop_tok_prob
  random_drop_tok = tf.random.uniform([batch_size, source_max_len])
  # Don't drop eos token
  is_keep_tok = tf.math.logical_or(
      tf.greater(random_drop_tok, p.drop_tok_prob),
      tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
  denoise_src_ids = tf.ragged.boolean_mask(
      denoise_src_ids, is_keep_tok).to_tensor(
          default_value=0, shape=tf.shape(batch.src.ids))
  denoise_src_paddings = tf.ragged.boolean_mask(
      batch.src.paddings, is_keep_tok).to_tensor(
          default_value=1, shape=tf.shape(batch.src.ids))

  # Select tokens to blank with probability=p.blank_tok_prob
  # Don't blank eos token
  random_blank_tok = tf.random.uniform([batch_size, source_max_len])
  shifted_paddings = tf.pad(
      denoise_src_paddings[:, 1:], [[0, 0], [0, 1]], constant_values=1)
  is_blank_tok = tf.math.logical_and(
      tf.less(random_blank_tok, p.blank_tok_prob),
      tf.equal(shifted_paddings, 0))
  blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
  denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

  # Select denoising task examples with probability=p.denoise_sent_prob
  random_uniform_sent = tf.random.uniform([batch_size])
  is_denoise_sent = tf.math.logical_and(
      tf.less(random_uniform_sent, p.noise_sent_prob),
      IsSpecialExample(
          self._GetTaskIds(batch.src.source_ids[:, 0]), p.task_ids))
  batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids, batch.src.ids)
  batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                batch.src.paddings)
  batch.src.ids_indicator = 1 - batch.src.paddings
  batch.src.weights = batch.src.ids_indicator
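
# The shuffle step above is the usual 'add bounded noise to positions, then
# argsort' trick from the paper; a standalone sketch with window k=3:
import tensorflow as tf

ids = tf.constant([[10, 11, 12, 13, 14]])
k = 3.0  # analogous to p.shuffle_tok_range
positions = tf.cast(tf.range(tf.shape(ids)[1]), tf.float32)
noise = tf.random.uniform(tf.shape(ids), 0.0, k + 1.0)
permutation = tf.argsort(positions + noise)
shuffled = tf.gather(ids, permutation, batch_dims=1)
# Each token ends up at most ~k positions away from where it started.
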
def _Seeded():
  return tf.cond(
      tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
      _GetFurthestPoint)
def ProcessFeatures(self, features):
  """Process extracted features.

  Args:
    features: A dict of extracted Tensors from the records.

  Returns:
    A tuple of tensors:

    - bucket_id: A scalar int Tensor.
    - extracted: a NestedMap of Tensors extracted.
  """

  def ExtractAndFilter(e):
    with tf.name_scope(e.params.name):
      with tf.name_scope('extract'):
        extracted = e.Extract(features)
      with tf.name_scope('filter'):
        bucket = e.Filter(extracted)
      return bucket, extracted

  bucket_extracted = self._extractors.Transform(ExtractAndFilter)
  buckets = bucket_extracted.Transform(lambda x: x[0])
  extracted = bucket_extracted.Transform(lambda x: x[1])

  # Return the maximum bucket id so that any extractor can decide whether
  # to filter the entire example.
  max_bucket = tf.reduce_max(buckets.Flatten())

  def NullLike():
    """A function to return the same Tensor signature as Preprocess.

    This is necessary for the tf.cond() to avoid executing the preprocessor
    for examples that are going to be dropped because they exceed the bucket
    limit; tf.cond() requires that the output of both branches yields the
    same structure.

    Returns:
      A structure with the same Tensor dtype as the output of Preprocess.
    """
    shapes = self.Shape()
    rets = []
    for dtype, shape in zip(self.DType().Flatten(), shapes.Flatten()):
      if shape.is_fully_defined():
        rets += [tf.zeros(dtype=dtype, shape=shape)]
      else:
        rets += [tf.zeros(dtype=dtype, shape=[])]  # Our best guess.
    return shapes.Pack(rets)

  def Preprocess(extracted):
    for key, preprocessor in zip(self.params.preprocessors_order,
                                 self.preprocessors):
      with tf.name_scope(key), tf.name_scope(preprocessor.params.name):
        extracted = preprocessor.TransformFeatures(extracted)
    return extracted

  # If the extractor wants to filter the example, don't run the preprocessor.
  #
  # Preprocessors can then assume that only examples that pass filtering will
  # be executed.
  #
  # Note that the NullLike branch may return tensors with shapes different
  # from self.Shape().
  final_output = tf.cond(
      tf.less(max_bucket, BUCKET_UPPER_BOUND), lambda: Preprocess(extracted),
      NullLike)

  return max_bucket, final_output
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function
     are automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one
     tpu core. All computations below are carried out on one tpu core only.
     This function tries to dispatch examples across tpu cores in such a way
     that each expert is assigned no more than 'expert_capacity_dim' number
     of examples.

  Below ` indicates common way of splitting along mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch original batch BLM is reshaped into GSM, each
  group `g = 0...G-1` is being dispatched independently.

  Args:
    inputs: G`SM Tensor.
    paddings: G`S Tensor.
    logits: G`SE Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch(group) per expert.
      Each example is typically a vector of size input_dim, representing
      embedded token or an element of Transformer layer output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the
      xla_sharding case.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim
      where `group_size` is the size of G dimension of `inputs`. If the
      value of expert_capacity_dim is already big enough no change is made.

  TODO(lepikhin): get rid of the legacy_mtf_behavior flag.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
      experts.
  """
  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    auto_expert_capacity = int(
        (group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
      tpu_summary.scalar('expert_capacity', expert_capacity_dim)

  # top first and second gate value and expert index for each input
  #
  # GSK Tensors, K=2
  def _MaybeSplit(x):
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity,
                                      name):
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity),
            mask.dtype))
    over_capacity_ratio = over_capacity / tf.reduce_sum(mask)
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  # - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  # - add more tests for policy="random"
  #
  # - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  importance = tf.ones_like(mask_1[:, :, 0])

  if paddings is not None:
    importance = 1.0 - paddings
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)

  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We directly sample the 2nd expert index from the softmax over of the
    # 2nd expert by getting rid of the 1st expert already selected above. To
    # do so, we set a very negative value to the logit corresponding to the
    # 1st expert. Then we sample from the softmax (categorical) distribution
    # using the Gumbel max trick.
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    noise = -tf.math.log(-tf.math.log(noise))
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1,
                             output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False,
    # and is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # We reshape the mask as [X*S, E], and compute cumulative sums of
  # assignment indicators for each expert index e \in 0..E-1 independently.
  # First occurrence of assignment indicator is excluded, see exclusive=True
  # flag below.
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)

  # GS Tensor
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  mask_1 *= tf.cast(
      tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1,
                                   mask_1)

  # How many examples in this sequence go to this expert
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  if second_expert_policy == 'all' or second_expert_policy == 'sampling':
    pass
  elif second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
  else:
    raise ValueError(second_expert_policy)

  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2,
                                   mask_2)
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor
  a = tf.expand_dims(gate_1 * mask_1_flat, -1) * tf.one_hot(
      index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor
  a = tf.expand_dims(gate_2 * mask_2_flat, -1) * tf.one_hot(
      index_2, experts_dim, dtype=fprop_dtype)
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor
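
# A shape-level usage sketch, taking the local (non-sharded) path so the
# Split()/xla_sharding dependency is never exercised; all sizes are
# illustrative.
import tensorflow as tf

G, S, E, C, M = 2, 8, 4, 4, 16
inputs = tf.random.normal([G, S, M])
paddings = tf.zeros([G, S])
logits = tf.random.normal([G, S, E])

aux_loss, combine_tensor, dispatch_tensor = Top2GatingOnLogits(
    inputs, paddings, logits,
    num_devices=1,
    experts_dim=E,
    expert_capacity_dim=C,
    fprop_dtype=tf.float32,
    use_xla_sharding=False)

# dispatch_tensor is a [G, S, E, C] 0/1 routing map; per-expert inputs can
# then be formed with an einsum such as:
dispatched = tf.einsum('GSEC,GSM->EGCM', dispatch_tensor, inputs)
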