def _PaddedMaxFn(inp):
  """Max-pools `inp.features` over the point axis, ignoring padded points.

  Padded points are pushed to -inf before the reduce_max so they can never be
  selected. A group in which every point is padded would therefore reduce to
  -inf; such outputs are overwritten with zeros so downstream code does not
  produce NaNs (inf * 0 = NaN).

  Args:
    inp: object with `features` and `padding` attributes. From the indexing
      below, `features` is reduced over axis -2 and `padding` over axis -1,
      i.e. features ~ [..., points, channels], padding ~ [..., points];
      padding > 0 marks a padded point.

  Returns:
    The pooled features, passed through py_utils.CheckNumerics.
  """
  padding = inp.padding
  # -inf at padded points; the padding value itself is kept elsewhere.
  padding_bias = tf.where(padding > 0, -np.inf * padding, padding)
  pooled = tf.reduce_max(
      inp.features + padding_bias[..., tf.newaxis], axis=-2)

  # Casting reduce_min(padding) to bool is True only when every point in the
  # group has non-zero padding, i.e. the whole group is padded.
  fully_padded = tf.cast(tf.reduce_min(padding, axis=-1), tf.bool)
  fully_padded = tf.broadcast_to(fully_padded[..., tf.newaxis],
                                 py_utils.GetShape(pooled))
  pooled = tf.where(fully_padded, tf.zeros_like(pooled), pooled)
  return py_utils.CheckNumerics(pooled)
def _GetFurthestPoint():
  """Returns, per batch element, the index of the point furthest from the selection.

  While real points remain (curr_idx < num_valid_points), padded points are
  given a distance of -1 so the argmax prefers real points; once real points
  are exhausted, the raw distances are used.

  Relies on enclosing-scope names: padding, distance_to_selected, batch_size,
  num_points, curr_idx, num_valid_points.
  """
  negative_ones = -1.0 * tf.ones((batch_size, num_points), dtype=tf.float32)
  real_only_distance = tf.where(
      tf.equal(padding, 0.0), distance_to_selected, negative_ones)

  still_have_real_points = tf.less(curr_idx, num_valid_points)
  scores = tf.where(still_have_real_points, real_only_distance,
                    distance_to_selected)
  return tf.argmax(scores, axis=-1, output_type=tf.int32)
def RandomPadOrTrimTo(tensor_list, num_points_out, seed=None):
  """Randomly pads or trims a list of Tensors along their first dimension.

  With more than `num_points_out` rows, a random subset of rows is kept.
  With fewer, the missing rows are random duplicates of existing rows. If
  there are no rows at all, the output is all zeros.

  Args:
    tensor_list: list of tf.Tensor objects to pad or trim along dim 0. All
      tensors are expected to share the same first dimension.
    num_points_out: int, the number of rows to trim/pad to.
    seed: random seed for the random generators.

  Returns:
    A tuple (output_tensors, padding):
      - output_tensors: list of padded or trimmed versions of `tensor_list`,
        each with first dimension `num_points_out`.
      - padding: tf.float32 tf.Tensor of shape [num_points_out]; 0 where the
        point is real, 1 where it was padded.
  """
  num_in = tf.shape(tensor_list[0])[0]
  out_positions = tf.range(num_points_out, dtype=tf.int32)
  padding_tensor = tf.where(out_positions < num_in,
                            tf.zeros([num_points_out], dtype=tf.float32),
                            tf.ones([num_points_out], dtype=tf.float32))

  def _TrimRandomly():
    # Keep a random subset of the existing rows.
    keep = tf.random_shuffle(tf.range(num_in), seed=seed)[:num_points_out]
    return [tf.gather(t, keep, axis=0) for t in tensor_list]

  def _PadWithDuplicates():
    # Append randomly chosen copies of existing rows.
    extra = tf.random_uniform([num_points_out - num_in],
                              minval=0,
                              maxval=num_in,
                              dtype=tf.int32,
                              seed=seed)
    return [
        tf.concat([t, tf.gather(t, extra, axis=0)], axis=0)
        for t in tensor_list
    ]

  def _PadWithZeros():
    # Nothing to duplicate: emit all-zero tensors of the requested size.
    return [
        tf.zeros(
            shape=tf.concat([[num_points_out], tf.shape(t)[1:]], axis=0),
            dtype=t.dtype) for t in tensor_list
    ]

  def _Pad():
    return tf.cond(tf.equal(num_in, 0), _PadWithZeros, _PadWithDuplicates)

  data = tf.cond(num_in > num_points_out, _TrimRandomly, _Pad)
  return (data, padding_tensor)
def Clipped():
  """Returns `x` clipped to the current [min, max] range when clipping is on.

  Clipping applies only where state[0] (the clip ratio) is >= 0; otherwise
  `x` is returned unchanged. The range bounds are wrapped in stop_gradient so
  no gradient flows into them.

  Relies on enclosing-scope names: self, state, start_cap, end_cap, bits, x.
  """
  clip_ratio = state[0]
  min_value, max_value = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                bits)
  min_value = tf.stop_gradient(min_value)
  max_value = tf.stop_gradient(max_value)
  # tf.where evaluates both branches anyway, so build them as plain tensors.
  clipped_x = tf.clip_by_value(x, min_value, max_value)
  return tf.where(clip_ratio >= 0.0, clipped_x, x)
def Value(self):
  """Updates and returns the decay factor, decaying when progress stalls.

  The reference step is the later of theta.ref_step and the best step. If the
  last step is at least p.window steps past it, the factor becomes
  max(p.min_factor, cur_factor * p.decay). Whenever the factor changes, the
  reference step moves to the last step.

  Returns:
    The assign op for vars.cur_factor, which runs after vars.ref_step has
    been updated.
  """
  p = self.params
  with tf.name_scope(p.name):
    step_pair = self._best_step
    best_step = step_pair[0]
    last_step = step_pair[1]
    ref_step = tf.maximum(self.theta.ref_step, best_step)
    factor = self.theta.cur_factor

    within_window = last_step - ref_step < p.window
    decayed_factor = tf.maximum(p.min_factor, factor * p.decay)
    new_factor = tf.where(within_window, factor, decayed_factor)

    # A changed factor means a decay happened: advance the reference step.
    new_ref_step = tf.where(tf.equal(new_factor, factor), ref_step, last_step)
    assign_ref_step = tf.assign(self.vars.ref_step, new_ref_step)
    with tf.control_dependencies([assign_ref_step]):
      return tf.assign(self.vars.cur_factor, new_factor)
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the monolithic
  TPUEmbeddingLayer.

  Looks up the TPU embedding registered in the py_utils.TPU_EMBEDDING
  collection and returns None if there is none. Otherwise, for each infeed
  host (one per TPU host when p.use_per_host_infeed, else a single one), it
  does the following:
    - splits every TPU-embedding input feature of `self._batch` across that
      host's cores;
    - converts each split from dense ids (-1 = padding) to sparse
      (indices, values);
    - generates the enqueue ops.
  All enqueue ops are grouped and appended to `self._tpu_infeed_op`.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  # Only one infeed host unless per-host infeed is enabled.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  enqueue_ops = []

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      # NOTE(review): depending on the logging backend, tf.logging.fatal may
      # only log at FATAL level rather than abort; confirm that this
      # misconfiguration actually stops execution.
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  # No TPU embedding registered: nothing to enqueue.
  if not tpu_embedding:
    return

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        # This reassigns self._batch, so the filtered batch is what callers
        # see afterwards.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      # One {feature_key: EnqueueData} dict per core on this host.
      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        # Split the feature evenly across this host's cores.
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(
              embedding_indices, sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(
          enqueue_dict_per_core)
  # All hosts' enqueue ops are grouped into a single infeed op.
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def CpuEmbLookup(self, ids_map, partition_strategy):
  """Embedding lookup used for CPU evaluation.

  Args:
    ids_map: dict of `input_key` string -> [batch, sequence] int32 Tensor,
      with -1 used as the padding id.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    A NestedMap of string -> float32 Tensor activations:
      - sequence embeddings (self.max_sequence_length > 0):
        [batch, max_sequence_length, embedding_dim], one vector per id;
      - non-sequence embeddings: [batch, 1, embedding_dim], the ids of each
        row merged with p.combiner.
  """
  p = self.params
  rets = py_utils.NestedMap()

  if self.max_sequence_length > 0:
    # Sequence embedding: no combiner, every id keeps its own vector.
    for key, ids in ids_map.items():
      flat_embs = tf.nn.embedding_lookup(
          self.theta.wm,
          tf.reshape(ids, [-1]),
          partition_strategy=partition_strategy)
      target_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
      rets[key] = tf.reshape(flat_embs, target_shape)
    return rets

  # Non-sequence embedding: ids of each row are merged with p.combiner.
  for key, ids in ids_map.items():
    # Dense -> sparse, dropping the -1 padding ids.
    dense_shape = tf.shape(ids, out_type=tf.int64)
    coords = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
    values = tf.cast(tf.gather_nd(ids, coords), tf.int64)
    sparse_ids = tf.SparseTensor(
        indices=coords, values=values, dense_shape=dense_shape)
    # Result is [?, embedding_dim]. dim0 is the smallest span starting at
    # row 0 that covers all results, so it can be smaller than the batch.
    embs = tf.nn.embedding_lookup_sparse(
        self.theta.wm,
        sparse_ids,
        None,  # sp_weights
        combiner=p.combiner,
        partition_strategy=partition_strategy)
    batch_size = dense_shape[0]
    # Pad trailing rows back so that dim0 == batch.
    missing_rows = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
    embs = tf.pad(embs, [[0, missing_rows], [0, 0]])
    embs = py_utils.HasShape(embs, [batch_size], ndims=1)
    rets[key] = tf.expand_dims(embs, 1)  # [batch, 1, embedding_dim]
  return rets
def FProp(self, theta, current_step):
  """Returns the current learning rate decay.

  Computes model_dim**-0.5 * min((s + 1) * warmup**-1.5, (s + 1)**-0.5),
  where warmup = p.warmup_steps * p.worker_replicas and s is `current_step`,
  frozen at p.decay_end when that is set.
  """
  p = self.params
  step = tf.cast(current_step, tf.float32)
  warmup = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
  if p.decay_end is not None:
    # Hold the schedule constant once decay_end is reached.
    step = tf.where(step < p.decay_end, step, tf.cast(p.decay_end, tf.float32))
  step_plus_one = step + 1
  warmup_branch = step_plus_one * warmup**-1.5
  decay_branch = step_plus_one**-0.5
  return p.model_dim**-0.5 * tf.minimum(warmup_branch, decay_branch)
def grad_fn(d_outputs):
  """Backward pass for the bisection entmax.

  Uses gppr = p_m**(2 - alpha) on the support (p_m > 0) and 0 elsewhere. The
  gradient is d_outputs * gppr minus gppr times the gppr-normalized sum of
  d_outputs * gppr along `axis`.

  Relies on enclosing-scope names: p_m, alpha, axis.
  """
  with tf.name_scope("entmax_grad"):
    on_support = p_m > 0
    gppr = tf.where(on_support, tf.math.pow(p_m, 2.0 - alpha),
                    tf.zeros_like(p_m))
    weighted = d_outputs * gppr
    correction = tf.math.reduce_sum(weighted, axis) / tf.math.reduce_sum(
        gppr, axis)
    weighted -= tf.expand_dims(correction, axis) * gppr
    # NOTE(review): the same tensor is returned as the second gradient
    # (presumably the one for `alpha`); confirm that is intended and
    # shape-compatible.
    return weighted, weighted
def NMSIndices(self,
               bboxes,
               scores,
               max_output_size,
               nms_iou_threshold=0.3,
               score_threshold=0.01):
  """Apply NMS to a series of 3d bounding boxes in 7-DOF format.

  The rotation angle is ignored: boxes are projected to axis-aligned 2D
  (x, y, dx, dy) boxes before the padded NMS op. The default IoU threshold is
  0.3 rather than 0.5 because rotated anchors can make 0.5 too high.

  Args:
    bboxes: A [num_boxes, 7] floating point Tensor of bounding boxes in [x, y,
      z, dx, dy, dz, phi] format.
    scores: A [num_boxes] floating point Tensor containing box scores.
    max_output_size: Maximum number of boxes to predict per input.
    nms_iou_threshold: IoU threshold to use when determining whether two boxes
      overlap for purposes of suppression.
    score_threshold: The score threshold passed to NMS that allows NMS to
      quickly ignore irrelevant boxes.

  Returns:
    The NMS indices and the mask of the padded indices. Indices are padded to
    max_output_size (padded slots are 0), and the mask is 1.0 for the valid
    slots and 0.0 for the padded ones.
  """
  bboxes = py_utils.HasShape(bboxes, [-1, 7])
  # Keep x, y, dx, dy; the available NMS op cannot handle rotation.
  xywh = tf.stack([bboxes[:, col] for col in (0, 1, 3, 4)], axis=-1)
  extrema = geometry.XYWHToBBoxes(xywh)

  # The padded variant has a fixed output size, so this works inside map_fn.
  padded_indices, num_valid = tf.image.non_max_suppression_padded(
      extrema,
      scores,
      iou_threshold=nms_iou_threshold,
      max_output_size=max_output_size,
      score_threshold=score_threshold,
      pad_to_max_output_size=True)

  # Expand the scalar count into a per-slot validity mask.
  valid_mask = tf.concat(
      [tf.ones([num_valid]),
       tf.zeros([max_output_size - num_valid])], axis=0)
  padded_indices = tf.where(valid_mask > 0, padded_indices,
                            tf.zeros_like(padded_indices))
  return padded_indices, valid_mask
def _ParseRecord(self, record):
  """Reads and parses a single masked-LM record.

  Restores the original (unmasked) ids from the masked LM positions and ids,
  then removes the first `p.eos_token_id` from every field, appending a 0 so
  lengths stay `p.max_sequence_length`.

  Returns:
    A NestedMap with ids, segment_ids, paddings and segment_pos, plus
    masked_ids and masked_pos unless p.remove_mask is set.
  """
  p = self.params
  seq_len = p.max_sequence_length
  num_preds = p.max_predictions_per_seq
  feature_spec = {
      'input_ids': tf.io.FixedLenFeature([seq_len], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_len], tf.int64),
      'masked_lm_positions': tf.io.FixedLenFeature([num_preds], tf.int64),
      'masked_lm_ids': tf.io.FixedLenFeature([num_preds], tf.int64),
      'masked_lm_weights': tf.io.FixedLenFeature([num_preds], tf.float32),
  }
  parsed = tf.io.parse_single_example(record, feature_spec)

  # The summed weights give the number of real prediction slots (presumably
  # 0/1 weights, with the real slots first).
  num_masked = tf.cast(
      tf.reduce_sum(parsed['masked_lm_weights']), dtype=tf.int32)
  masked_positions = tf.slice(parsed['masked_lm_positions'], [0],
                              [num_masked])
  original_ids = tf.cast(
      tf.slice(parsed['masked_lm_ids'], [0], [num_masked]), dtype=tf.int32)
  scatter_indices = tf.reshape(masked_positions, [-1, 1])

  ret = py_utils.NestedMap()
  ret.masked_ids = tf.cast(parsed['input_ids'], dtype=tf.int32)
  # Undo the masking by writing the original ids back at masked positions.
  ret.ids = tf.tensor_scatter_nd_update(
      tensor=ret.masked_ids, indices=scatter_indices, updates=original_ids)
  ret.masked_pos = tf.tensor_scatter_nd_update(
      tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
      indices=scatter_indices,
      updates=tf.ones_like(original_ids, dtype=tf.float32))
  ret.segment_ids = tf.cast(parsed['input_mask'], dtype=tf.float32)

  # NOTE(review): indexing [0][0] fails at runtime if ret.ids contains no
  # eos_token_id; confirm every record has one.
  eos_idx = tf.where(tf.math.equal(ret.ids, p.eos_token_id))[0][0]

  def _DropEos(x):
    # Remove x[eos_idx] and append a 0 so the length is unchanged.
    tail_pad = tf.constant(0, shape=(1, ), dtype=x.dtype)
    return tf.concat([x[:eos_idx], x[eos_idx + 1:], tail_pad], axis=0)

  ret = ret.Transform(_DropEos)
  ret.paddings = 1.0 - ret.segment_ids
  pos_f = tf.cast(tf.range(seq_len), dtype=tf.float32)
  ret.segment_pos = tf.cast(ret.segment_ids * pos_f, dtype=tf.int32)

  if p.remove_mask:
    del ret.masked_pos
    del ret.masked_ids
  return ret
def Value(self):
  """Returns the current learning rate decay.

  Computes model_dim**-0.5 * min((s + 1) * warmup**(decay_factor - 1),
  (s + 1)**decay_factor), where s is the global step (frozen at p.decay_end
  when set) and warmup = p.warmup_steps * p.worker_replicas.
  """
  p = self.params
  step = tf.cast(py_utils.GetGlobalStep(), tf.float32)
  warmup = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
  if p.decay_end is not None:
    step = tf.where(step < p.decay_end, step, tf.cast(p.decay_end, tf.float32))
  warmup_term = (step + 1) * warmup**(p.decay_factor - 1.0)
  decay_term = (step + 1)**tf.cast(p.decay_factor, tf.float32)
  return p.model_dim**-0.5 * tf.minimum(warmup_term, decay_term)
def _Lookup(ids):
  """Combines the embeddings of dense `ids`, treating -1 as padding.

  Relies on enclosing-scope names: self, partition_strategy.
  """
  full_shape = tf.shape(ids, out_type=tf.int64)
  coords = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
  values = tf.cast(tf.gather_nd(ids, coords), tf.int64)
  sparse_ids = tf.SparseTensor(
      indices=coords, values=values, dense_shape=full_shape)
  return self._CombinerEmbLookup(sparse_ids, partition_strategy)
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step.

  Runs pre_step_callback, samples the next id from the (optionally top-k)
  logits with a stateless sampler seeded by (random_seed, timestep), maps an
  end-of-chunk id to target_eos_id on the last chunk when configured, and
  runs post_step_callback.

  Relies on enclosing-scope names: p, decoder_theta, pre_step_callback,
  post_step_callback.

  Returns:
    (state1, empty NestedMap) when p.use_recurrent. Otherwise
    (recurrent_theta, state1, inputs), with this step's ids and logits
    written into the inputs.ids / inputs.logits TensorArrays at
    state0.timestep.
  """
  if p.use_recurrent:
    del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        decoder_theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=p.num_hyps_per_beam)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs

    if p.top_k > 0:
      # Restrict sampling to the top-k logits, optionally renormalized.
      topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k)
      sample_logits = tf.nn.log_softmax(
          topk_logits) if p.top_k_renormalize else topk_logits
    else:
      sample_logits = state1.logits

    # Sample ids from logits. [batch].
    ids = tf.reshape(
        tf.random.stateless_categorical(
            sample_logits / p.temperature,
            num_samples=1,
            seed=tf.stack(
                [recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    # With top-k, `ids` index into topk_ids; map them back to vocabulary ids.
    state1.ids = tf.gather(topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids

    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      # On the last chunk, replace an end-of-chunk id with the EOS id.
      state1.ids = tf.where(
          tf.math.logical_and(
              bs_result.is_last_chunk,
              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(
        decoder_theta, recurrent_theta.encoder_outputs, state1.ids, bs_state1)
  if p.use_recurrent:
    return state1, py_utils.NestedMap()
  else:
    # Record this step's outputs in the TensorArrays carried in `inputs`.
    inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
    inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
    return (recurrent_theta, state1, inputs)
def FProp(self, theta, current_step):
  """Returns the current learning rate decay.

  Computes
    model_dim**-0.5 * min(min(s + 1, (s + 1)**-0.5), peak),
  where peak = (decay_start * worker_replicas)**-0.5 and s is `current_step`,
  frozen at decay_end when set. For s >= 0 the inner min equals
  (s + 1)**-0.5, so the rate stays constant at the peak until the
  inverse-square-root decay falls below it.

  Args:
    theta: unused.
    current_step: the current training step.

  Returns:
    A float32 scalar learning-rate multiplier.
  """
  params = self.params
  # tf.cast replaces the deprecated tf.to_float (removed from the TF2 API);
  # this matches the sibling Value() implementation.
  warmup_steps = tf.cast(params.decay_start * params.worker_replicas,
                         tf.float32)
  current_step = tf.cast(current_step, tf.float32)
  if params.decay_end is not None:
    current_step = tf.where(current_step < params.decay_end, current_step,
                            tf.cast(params.decay_end, tf.float32))
  peak_learning_rate = (warmup_steps**-0.5)
  return (params.model_dim**-0.5) * tf.minimum(
      tf.minimum((current_step + 1), (current_step + 1)**-0.5),
      peak_learning_rate)
def Polynomial(x):
  """Polynomial function of x.

  With (x0, y0) = p.start, (x1, y1) = p.limit and
  ratio = (x - x0) / (x1 - x0), returns y0 + f(ratio) * (y1 - y0), where
  f(r) = r**p.power for p.origin == 'start' and 1 - (1 - r)**p.power for
  p.origin == 'limit'. Returns y0 for x < x0 and y1 for x >= x1.

  Relies on the enclosing-scope name `self`.

  Raises:
    ValueError: if not x0 < x1, or if p.origin is neither 'start' nor
      'limit'.
  """
  p = self.params
  x0, y0 = p.start
  x1, y1 = p.limit
  # Explicit check rather than `assert` (stripped under `python -O`);
  # ValueError matches the other polynomial schedule's validation.
  if not x0 < x1:
    raise ValueError('%s must be < %s' % (x0, x1))
  x0 = tf.cast(x0, dtype=x.dtype)
  x1 = tf.cast(x1, dtype=x.dtype)
  y0 = tf.cast(y0, dtype=x.dtype)
  y1 = tf.cast(y1, dtype=x.dtype)
  ratio = (x - x0) / (x1 - x0)
  if p.origin == 'start':
    f_x = ratio**p.power
  elif p.origin == 'limit':
    f_x = 1 - (1 - ratio)**p.power
  else:
    raise ValueError('Invalid parameter origin: %s' % p.origin)
  y = y0 + f_x * (y1 - y0)
  return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
def Value(self):
  """Returns the current learning rate decay.

  Computes model_dim**-0.5 * min(min(s + 1, (s + 1)**-0.5), peak), where
  peak = (decay_start * worker_replicas)**-0.5 and s is the global step,
  frozen at decay_end when set.
  """
  params = self.params
  warmup = tf.cast(params.decay_start * params.worker_replicas, tf.float32)
  step = tf.cast(py_utils.GetGlobalStep(), tf.float32)
  if params.decay_end is not None:
    step = tf.where(step < params.decay_end, step,
                    tf.cast(params.decay_end, tf.float32))
  peak = warmup**-0.5
  step_plus_one = step + 1
  decayed = tf.minimum(step_plus_one, step_plus_one**-0.5)
  return (params.model_dim**-0.5) * tf.minimum(decayed, peak)
def CombineStates(self, state0, state1, switch_cond):
  """Selects per-batch-element RNN states based on a switch condition.

  Args:
    state0: NestedMap of states used where switch_cond is true.
    state1: NestedMap of states used where switch_cond is false.
    switch_cond: bool tensor of shape [batch] on which to switch.

  Returns:
    A NestedMap {'rnn': [...]} holding one {'c', 'm'} NestedMap per layer.
  """

  def _SelectLayer(layer_idx):
    first = state0.rnn[layer_idx]
    second = state1.rnn[layer_idx]
    return py_utils.NestedMap({
        'c': tf.where(switch_cond, first.c, second.c),
        'm': tf.where(switch_cond, first.m, second.m),
    })

  layers = [_SelectLayer(i) for i in range(self.params.rnns.num_layers)]
  return py_utils.NestedMap({'rnn': layers})
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step.

  Computes the batch mean and variance from the sufficient statistics
  accumulated over microbatches and moves the moving averages toward them by
  (1 - p.decay). The update is skipped (zero delta) where counts <= 0.5. The
  accumulators are then reset.

  Args:
    global_step: unused.

  Returns:
    A grouped op that applies both moving-average updates and runs the
    numeric checks on the updated variables.
  """
  p = self.params
  # Get sufficient stats that accumulate over microbatches.
  counts = self.accumulators.counts.GetValue()
  mean_ss = self.accumulators.mean_ss.GetValue()
  variance_ss = self.accumulators.variance_ss.GetValue()
  # Compute batch mean and batch variance from sufficient stats.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
  # Update moving_mean, moving_variance from batch mean and batch variance.
  with tf.name_scope(p.name) as scope:
    with tf.colocate_with(self.vars.moving_mean):
      mean_update = tf.assign_sub(
          self.vars.moving_mean,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
              tf.zeros_like(self.vars.moving_mean)),
          name='moving_mean_update')
    with tf.colocate_with(self.vars.moving_variance):
      var_update = tf.assign_sub(
          self.vars.moving_variance,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
              decay, tf.zeros_like(self.vars.moving_variance)),
          name='moving_variance_update')
  # CheckNumerics returns a checked tensor. In graph mode an op whose output
  # is never consumed is never executed, so read the variables after the
  # updates and group the checks into the returned op.
  with tf.control_dependencies([mean_update, var_update]):
    checked_mean = py_utils.CheckNumerics(
        tf.identity(self.vars.moving_mean),
        'moving mean of {} failed numeric check'.format(scope))
    checked_variance = py_utils.CheckNumerics(
        tf.identity(self.vars.moving_variance),
        'moving variance of {} failed numeric check'.format(scope))
  self.accumulators.counts.Reset()
  self.accumulators.mean_ss.Reset()
  self.accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update, checked_mean, checked_variance)
def QuantizeTensors(self, t_name, ts, eval_only=False):
  """Fake-quantizes the tensors `ts`, which share the range state `t_name`.

  At eval/inference time (self.do_eval), the memorized min/max state
  variables are used. At training time, the batch min/max over all of `ts`
  (always including 0) is pushed into the accumulator for `t_name` and used
  for fake quantization. When `eval_only` is set, the range is still recorded
  but the tensors pass through unquantized.

  Args:
    t_name: name of the tracked quantized tensor.
    ts: list of tensors that share one range.
    eval_only: if True, record ranges during training but do not quantize.

  Returns:
    List of (possibly fake-quantized) tensors, one per entry of `ts`.
  """
  p = self.params
  if self.do_eval:
    # At eval/inference time, use the memorized range.
    # Important: Don't capture these variables in training mode so as to
    # avoid extra/unnecessary captures.
    min_var = self._GetQStateVar(t_name, 'min')
    max_var = self._GetQStateVar(t_name, 'max')
    return [
        self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits)
        for t in ts
    ]

  # At training time, use the batch calculated min/max.
  accumulator_name = self._GetAccumulatorNameForTensor(t_name)
  # Start at 0 so the range always straddles a real zero point.
  batch_min = 0.0
  batch_max = 0.0
  for t in ts:
    batch_min = tf.minimum(tf.reduce_min(t), batch_min)
    batch_max = tf.maximum(tf.reduce_max(t), batch_max)

  # New state: [count, min, max].
  state1 = tf.stack([1.0, batch_min, batch_max])
  self.accumulators[accumulator_name].Update(state1)

  ts_out = []
  for i, t in enumerate(ts):
    if eval_only:
      # Only quantizing at eval time: ranges are recorded above, but the
      # tensor passes through unchanged.
      quant_t = t
    else:
      # If quantizing during training, skip quantization if it produces
      # NANs. Sometimes early in the training process, things are unstable
      # and ranges can produce numerical instability that makes it
      # impossible to perform a fake_quant. The fallback only applies here;
      # on the eval_only path it would be an identity.
      quant_t = self._MaybeFakeQuant(t, batch_min, batch_max, num_bits=p.bits)
      # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
      quant_t_has_nans = tf.math.is_nan(quant_t)
      quant_t = tf.where(quant_t_has_nans, t, quant_t)
    ts_out.append(quant_t)
    summary_utils.histogram(
        '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t)
  return ts_out
def FProp(self, theta, inputs, paddings):
  """Quantized projection: tanh(inputs @ w), zeroed at padded positions.

  Weights go through fns.qweight and the AQT weight wrappers, and inputs are
  tagged with QTensor('inputs'). The matmul output is tracked as the
  'transformed' qtensor, and the tanh is range-decorated by fns.qtanh.

  Args:
    theta: layer variables; theta.w is the [input_dim, output_dim] weight
      (shape inferred from the reshape/matmul below; confirm).
    inputs: tensor whose last dimension is p.input_dim.
    paddings: padding tensor broadcastable to the output; values > 0 mark
      padded positions.

  Returns:
    Tensor with the leading dims of `inputs` and last dimension p.output_dim.
  """
  p = self.params
  fns = self.fns

  # Tag weights (via fns.qweight) and activations (via QTensor) for
  # quantization. QTensor's name should match a TrackQTensor call made in the
  # constructor.
  # TODO(shivaniagrawal): change this to ToAqtWeight and FromAqtWeight.
  weight = self.ToAqtWeight(
      'aqt_w',
      fns.qweight(theta.w),
      feature_axis=-1,
      expected_scale_shape=(1, p.output_dim))
  inputs = self.QTensor('inputs', inputs)

  # The library qmatmul tracks its result as the 'transformed' qtensor.
  flat_inputs = tf.reshape(inputs, [-1, p.input_dim])
  projected = fns.qmatmul(flat_inputs, weight, qt='transformed')
  projected = self.FromAqtWeight('aqt_w', projected, feature_axis=-1)
  out_shape = tf.concat([tf.shape(inputs)[:-1], [p.output_dim]], 0)
  activated = fns.qtanh(tf.reshape(projected, out_shape))

  # Mask with tf.where instead of `out *= 1.0 - paddings`. Paddings may live
  # in a different numeric range than the activations, so arithmetic between
  # them is avoided. QRPadding guards the padding range, which mainly matters
  # when padding is dynamic at inference time.
  paddings = self.QRPadding(paddings)
  paddings = paddings * tf.ones_like(activated)  # Broadcast to output size.
  return tf.where(paddings > 0.0, tf.zeros_like(activated), activated)
def _MaybeFakeQuant(self, inputs, min_v, max_v, num_bits):
  """Fake-quantizes `inputs` to [min_v, max_v], honoring p.delay_start_steps.

  - delay_start_steps == 0, or at eval time: always fake-quantize.
  - delay_start_steps == -1 (training): never fake-quantize.
  - otherwise (training): fake-quantize once global_step >= delay_start_steps.
  """
  p = self.params

  def _FakeQuant():
    return tf.quantization.fake_quant_with_min_max_vars(
        inputs, min_v, max_v, num_bits=num_bits)

  if p.delay_start_steps == 0 or self.do_eval:
    return _FakeQuant()
  if p.delay_start_steps == -1:
    return inputs
  started = self.theta.global_step >= p.delay_start_steps
  return tf.where(started, _FakeQuant(), inputs)
def forward(inputs, alpha):
  """Forward pass of bisection-based entmax.

  Scales `inputs` by (alpha - 1), then bisects a threshold tau between
  tau_lo = max - 1 and tau_hi = max - (1/d)**(alpha - 1) for n_iter steps.
  Each step keeps the half whose f = sum(p) - 1 has a sign change. p is
  computed by _calculate_probability (defined elsewhere; presumably
  [x]_+**(1/(alpha-1)) — confirm).

  Relies on enclosing-scope names: axis, n_iter, ensure_sum_one,
  _calculate_probability.

  Returns:
    (p_m, grad_fn): the probabilities and a gradient function (presumably
    consumed by tf.custom_gradient — confirm).
  """
  with tf.name_scope("entmax_loss"):
    # Broadcast alpha to the input shape with `axis` collapsed to 1.
    alpha_shape = inputs.get_shape().as_list()
    alpha_shape[axis] = 1
    alpha = tf.fill(alpha_shape, alpha)
    alpha = tf.cast(alpha, dtype=inputs.dtype)
    # Static size of the normalized axis.
    d = inputs.get_shape().as_list()[axis]
    alpha_m1 = alpha - 1.0
    inputs = inputs * alpha_m1
    max_val = tf.math.reduce_max(inputs, axis=axis, keepdims=True)
    # Initial bisection bracket for tau.
    tau_lo = max_val - tf.ones(alpha.get_shape().as_list(), dtype=inputs.dtype)
    tau_hi = max_val - tf.math.pow(
        tf.cast((1.0 / d), dtype=inputs.dtype), alpha_m1)
    f_lo = tf.math.reduce_sum(
        _calculate_probability(tf.math.subtract(inputs, tau_lo), alpha),
        axis) - 1.0
    dm = tau_hi - tau_lo
    # NOTE(review): if n_iter == 0, p_m is never assigned and the code below
    # raises NameError.
    for _ in range(n_iter):
      dm /= 2
      tau_m = tau_lo + dm
      p_m = _calculate_probability(inputs - tau_m, alpha)
      f_m = tf.math.reduce_sum(p_m, axis) - 1.0
      # Move tau_lo up to tau_m where f_m has the same sign as f_lo.
      mask = tf.expand_dims(tf.math.greater(f_m * f_lo, 0), axis)
      tau_lo = tf.where(mask, tau_m, tau_lo)
    if ensure_sum_one:
      # Renormalize so the probabilities sum exactly to one along `axis`.
      p_m /= tf.expand_dims(tf.math.reduce_sum(p_m, axis), axis)

  def grad_fn(d_outputs):
    # gppr = p_m**(2 - alpha) on the support, 0 elsewhere.
    with tf.name_scope("entmax_grad"):
      gppr = tf.where(p_m > 0, tf.math.pow(p_m, 2.0 - alpha),
                      tf.zeros_like(p_m))
      d_inputs = d_outputs * gppr
      q = tf.math.reduce_sum(d_inputs, axis) / tf.math.reduce_sum(
          gppr, axis)
      q = tf.expand_dims(q, axis)
      d_inputs -= q * gppr
      # NOTE(review): d_inputs is also returned as the second gradient
      # (presumably for `alpha`); confirm this is intended and shape-safe.
      return d_inputs, d_inputs

  return p_m, grad_fn
def Value(self):
  """Returns a polynomial ramp of the global step.

  With (x0, y0) = p.start, (x1, y1) = p.limit and
  ratio = (step - x0) / (x1 - x0), returns y0 + f(ratio) * (y1 - y0), where
  f(r) = r**p.power for p.origin == 'start' and 1 - (1 - r)**p.power for
  p.origin == 'limit'. The value is held at y0 before x0 and at y1 from x1 on.

  Raises:
    ValueError: if x0 >= x1 or p.origin is neither 'start' nor 'limit'.
  """
  p = self.params
  step = tf.cast(py_utils.GetGlobalStep(), dtype=p.dtype)
  start_x, start_y = p.start
  limit_x, limit_y = p.limit
  if start_x >= limit_x:
    raise ValueError(f'{start_x} must be < {limit_x}')
  x0, x1, y0, y1 = (
      tf.cast(v, dtype=step.dtype)
      for v in (start_x, limit_x, start_y, limit_y))

  ratio = (step - x0) / (x1 - x0)
  if p.origin == 'start':
    shaped = ratio**p.power
  elif p.origin == 'limit':
    shaped = 1 - (1 - ratio)**p.power
  else:
    raise ValueError('Invalid parameter origin: %s' % p.origin)
  ramp = y0 + shaped * (y1 - y0)
  return tf.where(step < x0, y0, tf.where(step >= x1, y1, ramp))
def _RecordTensor(self, t_name):
  """Folds the accumulated batch range of `t_name` into its min/max vars.

  Each var moves toward the batch value by (1 - p.ema_decay) when the
  accumulated count is > 0. The accumulator is reset after being read.

  Returns:
    [] at eval time; otherwise the assign ops for the min and max vars.
  """
  p = self.params
  if self.do_eval:
    return []

  accumulator = self.accumulators[self._GetAccumulatorNameForTensor(t_name)]
  min_var = self._GetQStateVar(t_name, 'min')
  max_var = self._GetQStateVar(t_name, 'max')

  # Accumulator state layout: [count, min, max].
  state = accumulator.GetValue()
  count, batch_min, batch_max = state[0], state[1], state[2]
  accumulator.Reset()

  def _Moved(variable, target):
    step = (1.0 - p.ema_decay) * (variable - target)
    return variable - tf.where(count > 0., step, 0.)

  # Re-apply the zero-straddling clamp: small floating point drift can push a
  # range that should begin or end at zero slightly past it, which would fail
  # downstream checks.
  return [
      tf.assign(min_var, tf.minimum(0., _Moved(min_var, batch_min))),
      tf.assign(max_var, tf.maximum(0., _Moved(max_var, batch_max))),
  ]
def _GetRandomRealPoint():
  """Select the first point: a random real (non-padded) point per batch row.

  Each point draws a uniform score in [0, 1). Padded points get padding * 10
  instead, which exceeds every draw when padding is 1. The argmin over points
  then picks a random real point.

  Relies on enclosing-scope names: batch_size, num_points, random_seed,
  padding.
  """
  scores = tf.random.uniform((batch_size, num_points),
                             minval=0,
                             maxval=1,
                             dtype=tf.float32,
                             seed=random_seed)
  is_real = tf.equal(padding, 0.0)
  scores = tf.where(is_real, scores, padding * 10)
  return tf.argmin(scores, axis=1, output_type=tf.int32)
def FillPaddingPos(ids: tf.Tensor, id_len: tf.Tensor,
                   padding_value: int) -> tf.Tensor:
  """Given a batch of sequences, fills the padding pos with `padding_value`.

  Args:
    ids: a [B, max_len] int tensor (any integer dtype).
    id_len: a [B, ] int tensor.
    padding_value: an int.

  Returns:
    new_ids: new ids with the property.
      - new_ids[b, :id_len[b]] = ids[b, :id_len[b]]
      - new_ids[b, id_len[b]:] = padding_value
  """
  mask = py_utils.SequencePaddings(id_len, maxlen=tf.shape(ids)[1])
  mask = tf.cast(mask, dtype=tf.bool)
  # Cast the fill value to ids.dtype. A bare Python int would make tf.fill
  # produce int32, and tf.where rejects the mismatch for e.g. int64 ids.
  fill = tf.fill(tf.shape(ids), tf.cast(padding_value, ids.dtype))
  new_ids = tf.where(mask, fill, ids)
  return new_ids
def QuantizeWeight(self, w):
  """Fake-quantizes `w` over its own range, widened to at least +/-epsilon.

  During training, entries where fake quantization produced NaN fall back to
  the raw weight.
  """
  p = self.params
  eps = p.quantize_weight_epsilon
  # Force a small non-zero range: all-zero weights can otherwise make
  # downstream inference engines blow up.
  lower = tf.minimum(tf.reduce_min(w), -eps)
  upper = tf.maximum(tf.reduce_max(w), eps)
  quantized = self._MaybeFakeQuant(w, lower, upper, num_bits=p.bits)
  if self.do_eval:
    return quantized
  # Early training can be unstable enough that fake_quant yields NaNs; keep
  # the unquantized weight in those positions.
  return tf.where(tf.math.is_nan(quantized), w, quantized)
def _PaddedMeanFn(inp):
  """Apply padded mean using reduce_sum and dividing by # real points.

  Padded points are masked to zero before summing over the point axis. The
  sum is divided by the number of real points, clamped to at least 1. Groups
  with no real points are set to zero.

  Args:
    inp: object with `features` (reduced over axis -2, i.e. ~ [..., points,
      channels]) and `padding` (~ [..., points], 1 = padded, presumably 0/1).

  Returns:
    The mean features, passed through py_utils.CheckNumerics.
  """
  # Replace all padded features with 0 by masking the padded features out.
  mask = 1 - inp.padding
  features = inp.features * mask[..., tf.newaxis]
  features = tf.reduce_sum(features, axis=-2)
  num_real_points = tf.reduce_sum(mask, axis=-1, keepdims=True)
  # Decide "all padded" from the unclamped count; after the clamp below the
  # count can never be zero.
  all_padded = tf.equal(num_real_points, 0.)
  # Prevent the divisor of our padded mean from ever being 0, so that
  # the gradient flowing back through this op doesn't give us NaNs.
  num_real_points = tf.maximum(num_real_points, 1)
  features = features / num_real_points

  # Replace features of all padded points by zeros. We set the features to be
  # zero so that we don't get any downstream issue with NaNs.
  # Note that inf * 0 = NaN.
  all_padded = tf.broadcast_to(all_padded, py_utils.GetShape(features))
  features = tf.where(all_padded, tf.zeros_like(features), features)
  return py_utils.CheckNumerics(features)
def ComputeWer(hyps, refs):
  """Computes word errors in hypotheses relative to reference transcripts.

  Whitespace is stripped and collapsed to single spaces before comparing.

  Args:
    hyps: Hypotheses, represented as string tensors of shape [N].
    refs: References, represented as string tensors of shape [N].

  Returns:
    An int64 tensor, word_errs, of size [N, 2] where word_errs[i, 0]
    corresponds to the number of word errors in hyps[i] relative to refs[i];
    word_errs[i, 1] corresponds to the number of words in refs[i].
  """

  def _CollapseSpaces(text):
    return tf.strings.regex_replace(tf.strings.strip(text), r'\s+', ' ')

  hyps, refs = _CollapseSpaces(hyps), _CollapseSpaces(refs)
  hyps = py_utils.HasRank(hyps, 1)
  refs = py_utils.HasRank(refs, 1)
  hyps = py_utils.HasShape(hyps, tf.shape(refs))

  errors = tf.cast(
      tf.edit_distance(
          tf.string_split(hyps), tf.string_split(refs), normalize=False),
      tf.int64)

  # After collapsing, words = spaces + 1 ...
  num_spaces = tf.strings.length(tf.strings.regex_replace(refs, '[^ ]', ''))
  ref_words = tf.cast(num_spaces + 1, tf.int64)
  # ... except that an empty reference has zero words.
  ref_words = tf.where(
      tf.equal(refs, ''), tf.zeros_like(ref_words, tf.int64), ref_words)

  return tf.stack([errors, ref_words], axis=1)