def QuantizeTensors(self, t_name, ts, eval_only=False):
  p = self.params
  # Always straddle a real zero point.
  if self.do_eval:
    # At eval/inference time, use the memorized range.
    # Important: Don't capture these variables in training mode so as to
    # avoid extra/unnecessary captures.
    min_var = self._GetQStateVar(t_name, 'min')
    max_var = self._GetQStateVar(t_name, 'max')
    return [
        self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits) for t in ts
    ]
  else:
    # At training time, use the batch calculated min/max.
    accumulator_name = self._GetAccumulatorNameForTensor(t_name)
    # Calculate min/max for all tensors.
    batch_min = 0.0
    batch_max = 0.0
    for t in ts:
      batch_min = tf.minimum(tf.reduce_min(t), batch_min)
      batch_max = tf.maximum(tf.reduce_max(t), batch_max)

    # New state.
    state1 = tf.stack([1.0, batch_min, batch_max])
    self.accumulators[accumulator_name].Update(state1)

    # Results.
    ts_out = []
    for i, t in enumerate(ts):
      if eval_only:
        # If only quantizing at eval time, still record ranges as above
        # but don't quantize.
        quant_t = t
      else:
        # If quantizing during training, skip quantization if it produces
        # NANs. Sometimes early in the training process, things are unstable
        # and ranges can produce numerical instability that makes it
        # impossible to perform a fake_quant.
        quant_t = self._MaybeFakeQuant(t, batch_min, batch_max,
                                       num_bits=p.bits)
        # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
        quant_t_has_nans = tf.math.is_nan(quant_t)
        quant_t = tf.where(quant_t_has_nans, t, quant_t)
      ts_out.append(quant_t)
      summary_utils.histogram(
          '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t)
    return ts_out

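
# Illustrative sketch, not library code: _MaybeFakeQuant presumably wraps a
# TensorFlow fake-quant op, which clamps values to [min, max] and rounds them
# onto a num_bits grid while keeping the float dtype. A minimal standalone
# example of that behavior, assuming a hypothetical 8-bit range of [0.0, 6.0]:
import tensorflow as tf

x = tf.constant([-1.0, 0.5, 3.3, 7.0])
y = tf.quantization.fake_quant_with_min_max_args(x, min=0.0, max=6.0,
                                                 num_bits=8)
# y stays float32, but every element now lies on the 8-bit grid inside
# [0.0, 6.0]; out-of-range inputs (-1.0 and 7.0) are clamped to the endpoints.
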
def FProp(self, theta, input_batch, state0=None):
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    # Reshape to [t, b]
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    # Setup streaming states.
    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, state1.rnn[i] = layer.FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id,
        state=state1)

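
# Illustrative sketch, not library code: the reshape at the top of FProp
# converts batch-major ids of shape [b, t] into time-major [t, b], and the
# paddings are expanded to [t, b, 1] so they broadcast against per-step
# activations. A standalone example of the same shape manipulation:
import tensorflow as tf

ids = tf.zeros([4, 7], tf.int32)           # [batch, time]
paddings = tf.zeros([4, 7], tf.float32)    # [batch, time]
inputs = tf.transpose(ids)                       # [time, batch] == [7, 4]
ps = tf.expand_dims(tf.transpose(paddings), 2)   # [time, batch, 1]
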
def FProp(self, theta, inputs, *args):
  p = self.params
  with tf.name_scope(p.name) as scope:
    expert_dist = self._GetExpertDist(theta, inputs, *args)
    if not self.do_eval:
      summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

    # Excludes non-variable extra_theta like global_step.
    var_set = set([key for key, _ in self.body.vars.FlattenItems()])
    values = []
    for key, value in theta.body.FlattenItems():
      if key in var_set and value is not None:
        # Weighted average for all variables created in the body layer.
        value = tf.einsum('i,i...->...', expert_dist, value)
      values.append(value)
    weighted_theta = theta.body.Pack(values)
    return self.body.FProp(weighted_theta, inputs, *args)

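
# Illustrative sketch, not library code: the einsum 'i,i...->...' above
# computes a weighted average over the leading "expert" axis of each stacked
# variable. A standalone example with three experts and a [3, 2, 2] weight:
import tensorflow as tf

expert_dist = tf.constant([0.2, 0.3, 0.5])                     # shape [3]
stacked_w = tf.reshape(tf.range(12, dtype=tf.float32), [3, 2, 2])
avg_w = tf.einsum('i,i...->...', expert_dist, stacked_w)       # shape [2, 2]
# avg_w[j, k] == sum_i expert_dist[i] * stacked_w[i, j, k]
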
def _DataSourceToInputBatch(self):
  """The current input batch as a `.NestedMap` of input tensors."""
  ret, _ = self._BuildDataSource()
  self._Pack(ret)

  if 'weights' not in ret.src or 'weights' not in ret.tgt:
    ret.src.weights = ret.src.ids_indicator
    ret.tgt.weights = ret.tgt.ids_indicator
  if 'paddings' not in ret.src or 'paddings' not in ret.tgt:
    ret.src.paddings = 1 - ret.src.weights
    ret.tgt.paddings = 1 - ret.tgt.weights
  del ret.src.ids_indicator
  del ret.tgt.ids_indicator

  if self.params.pad_to_max_seq_length:
    assert self.params.source_max_length

    def _EnsureSrcShape(x):
      if x.dtype == tf.string:
        return tf.ensure_shape(x, [self._ScaledBatchSize()])
      return tf.ensure_shape(
          x, [self._ScaledBatchSize(), self.params.source_max_length])

    def _EnsureTgtShape(x):
      if x.dtype == tf.string:
        return tf.ensure_shape(x, [self._ScaledBatchSize()])
      return tf.ensure_shape(
          x, [self._ScaledBatchSize(), self.params.target_max_length])

    ret.src = ret.src.Transform(_EnsureSrcShape)
    ret.tgt = ret.tgt.Transform(_EnsureTgtShape)

  summary_utils.histogram('source_token_ids', ret.src.ids)
  summary_utils.histogram('target_token_ids', ret.tgt.ids)

  # Casts floating point tensors to fprop_dtype before returning.
  return ret.Transform(self.Cast)

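
# Illustrative sketch, not library code: the paddings filled in above are
# simply the complement of the per-token weights / ids_indicator, and
# tf.ensure_shape only attaches a static shape assertion without changing
# values. A standalone example:
import tensorflow as tf

weights = tf.constant([[1., 1., 0., 0.]])   # 1 = real token, 0 = padding
paddings = 1 - weights                      # [[0., 0., 1., 1.]]
ids = tf.zeros([1, 4], tf.int32)
ids = tf.ensure_shape(ids, [1, 4])          # raises if the shape ever differs
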
def FProp(self, theta, input_batch):
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)
    else:
      src_segment_id = None
    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys = layer.FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)

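
# Illustrative sketch, not library code: layers with index >= residual_start
# add their output to their input (a residual skip); earlier layers simply
# replace it. A toy standalone version of that control flow, with a scalar
# stand-in for the RNN layers:
num_lstm_layers = 4
residual_start = 2
xs = 1.0
for i in range(num_lstm_layers):
  ys = xs * 0.5              # stand-in for self.rnn[i].FProp(...)
  if i >= residual_start:
    xs = xs + ys             # residual skip
  else:
    xs = ys
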
def _Pack(self, batch):
  """Packs a given batch.

  Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """
  src_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
  tgt_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    batch.src.segment_ids = batch.src.ids_indicator
    batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
    batch.tgt.segment_ids = batch.tgt.ids_indicator
    batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  uniq_src_indices_in_input = tf.unique(
      tf.reshape(src_indices_in_input, [-1])).y
  uniq_tgt_indices_in_input = tf.unique(
      tf.reshape(tgt_indices_in_input, [-1])).y
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for
  # all fields.
  def ApplyPackingToSource(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
    return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

  batch.src = batch.src.Transform(ApplyPackingToSource)
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  def ApplyPackingToTarget(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
    return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

  batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos

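
# Illustrative sketch, not library code: after packing, several examples can
# share one row. Assuming the convention suggested by the code above,
# segment_ids numbers the examples within a row (0 on padding) and segment_pos
# gives each token's offset within its own example. For two sources of
# lengths 3 and 2 packed into a row of length 6, the fields would look like:
import tensorflow as tf

packed_ids = tf.constant([[11, 12, 13, 21, 22, 0]])
segment_ids = tf.constant([[1., 1., 1., 2., 2., 0.]])
segment_pos = tf.constant([[0, 1, 2, 0, 1, 0]])
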
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
  """Computes moments and updates state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
  ], inputs)
  with tf.name_scope(p.name):
    if self.do_eval:
      # The mean and variance used for normalization.
      norm_mean, norm_variance = (self.vars.moving_mean,
                                  self.vars.moving_variance)
    else:
      mean, variance = self._Moments(inputs, 1.0 - paddings,
                                     p.enable_cross_replica_sum_on_tpu)

      py_utils.UpdateBatchNormVars(self.vars.moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(self.vars.moving_variance, variance,
                                   self._decay)
      # Add some summaries for visualization.
      summary_utils.histogram('%s_mean' % p.name, tf.cast(mean, tf.float32))
      summary_utils.histogram('%s_variance' % p.name,
                              tf.cast(variance, tf.float32))
      summary_utils.histogram('%s_moving_mean' % p.name,
                              tf.cast(self.vars.moving_mean, tf.float32))
      summary_utils.histogram('%s_moving_variance' % p.name,
                              tf.cast(self.vars.moving_variance, tf.float32))
      summary_utils.histogram(
          '%s_mean_diff' % p.name,
          tf.cast(mean - self.vars.moving_mean, tf.float32))
      summary_utils.histogram(
          '%s_variance_diff' % p.name,
          tf.cast(variance - self.vars.moving_variance, tf.float32))
      if p.use_moving_avg_in_training:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], self.vars.moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   self.vars.moving_variance)
      else:
        # Use the batch statistics for normalization.
        norm_mean = mean
        norm_variance = variance

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    if p.use_moving_avg_in_training:
      beta = 0.0
      gamma = 1.0
    else:
      beta = theta.beta
      gamma = theta.gamma
    return norm_mean, norm_variance, beta, gamma

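
# Illustrative sketch, not library code: UpdateBatchNormVars presumably
# applies the standard exponential-moving-average update to the stored
# statistics. In scalar form, with a hypothetical decay of 0.999:
decay = 0.999
moving_mean = 0.0
batch_mean = 1.0
moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean   # -> 0.001
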