示例#1
0
    def QuantizeTensors(self, t_name, ts, eval_only=False):
        """Fake-quantizes a group of tensors that share a single range.

        Eval mode quantizes with the memorized min/max state variables for
        `t_name`. Training mode computes the range from the current batch
        (seeded with 0.0 on both ends), feeds it to the accumulator for
        `t_name`, and - unless `eval_only` - fake-quantizes with that range.

        Args:
          t_name: Name of the tensor group; selects the state variables and
            the accumulator, and is used in the histogram summary names.
          ts: List of tensors to quantize.
          eval_only: Only consulted in training mode. If True, the batch range
            is still recorded but the tensors are returned unquantized.

        Returns:
          A list with one output tensor per element of `ts`.
        """
        p = self.params
        if self.do_eval:
            # Eval/inference: quantize with the memorized range. The state
            # variables are fetched only on this path so that training mode
            # does not create extra/unnecessary captures of them.
            memorized_min = self._GetQStateVar(t_name, 'min')
            memorized_max = self._GetQStateVar(t_name, 'max')
            return [
                self._MaybeFakeQuant(x, memorized_min, memorized_max,
                                     num_bits=p.bits) for x in ts
            ]

        # Training: use the range of the current batch. Seeding both bounds
        # with 0.0 makes the range always straddle a real zero point.
        acc_name = self._GetAccumulatorNameForTensor(t_name)
        lo = 0.0
        hi = 0.0
        for x in ts:
            lo = tf.minimum(tf.reduce_min(x), lo)
            hi = tf.maximum(tf.reduce_max(x), hi)

        # NOTE(review): the accumulator update is [1.0, min, max]; the meaning
        # of the leading 1.0 is defined by the accumulator class (not visible
        # here) - confirm there.
        self.accumulators[acc_name].Update(tf.stack([1.0, lo, hi]))

        outputs = []
        for idx, x in enumerate(ts):
            if eval_only:
                # Quantization is deferred to eval: the range is recorded
                # above, the value passes through unchanged.
                y = x
            else:
                quantized = self._MaybeFakeQuant(x, lo, hi, num_bits=p.bits)
                # Early in training the range can be numerically unstable and
                # fake quant may produce NaNs; fall back to the unquantized
                # value at those positions.
                # TODO(laurenzo): Plumb the NaN mask through state and report.
                y = tf.where(tf.math.is_nan(quantized), x, quantized)
            outputs.append(y)
            summary_utils.histogram(
                '%s/%s_%d' % (self._qvars_scope.name, t_name, idx), x)
        return outputs
示例#2
0
    def FProp(self, theta, input_batch, state0=None):
        """Runs the streaming RNN encoder on one chunk of input.

        Args:
          theta: A `.NestedMap` of this layer's weights and its children's.
          input_batch: A `.NestedMap` with rank-2 `ids` and same-shaped
            `paddings` (both checked at runtime).
          state0: Optional recurrent state from a previous call. When falsy, a
            zero state is created via `self.zero_state`.

        Returns:
          A `.NestedMap` with `encoded`, `padding`, `segment_id` (always None)
          and `state`, whose `rnn[i]` is the new state returned by layer i.
        """
        p = self.params
        with tf.name_scope(p.name):
            ids = input_batch.ids
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
                py_utils.assert_shape_match(tf.shape(ids),
                                            tf.shape(input_batch.paddings)),
            ]
            # Transpose to [t, b].
            inputs = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(ids))
            ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Streaming state: start from zeros when none is supplied.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(inputs)[1])
            new_rnn_states = [None] * p.num_lstm_layers

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)

            layer_outputs = []
            for i in range(p.num_lstm_layers):
                ys, new_rnn_states[i] = self.rnn[i].FProp(
                    theta.rnn[i], xs, ps, state0=state0.rnn[i])
                ys = self.dropout.FProp(theta.dropout, ys)
                if i < p.residual_start:
                    xs = ys
                else:
                    # Residual skip, followed by clipping.
                    xs = self.ApplyClipping(theta, xs + ys)
                layer_outputs.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   layer_outputs)

            return py_utils.NestedMap(
                encoded=xs,
                padding=tf.squeeze(ps, [2]),
                segment_id=None,
                state=py_utils.NestedMap(rnn=new_rnn_states))
示例#3
0
    def FProp(self, theta, inputs, *args):
        """Runs the body layer with its variables mixed by the expert dist.

        Every body variable in `theta.body` is replaced by the weighted sum
        over its leading (expert) axis, weighted by `self._GetExpertDist`. The
        body's FProp is then run with the mixed theta.

        Args:
          theta: A `.NestedMap` of this layer's weights and its children's.
          inputs: Inputs forwarded to `_GetExpertDist` and to the body.
          *args: Extra arguments forwarded to `_GetExpertDist` and the body.

        Returns:
          Whatever `self.body.FProp` returns.
        """
        p = self.params
        with tf.name_scope(p.name) as scope:
            expert_dist = self._GetExpertDist(theta, inputs, *args)
            if not self.do_eval:
                summary_utils.histogram('soft_cond_{}'.format(scope),
                                        expert_dist)

            # Only real body variables are mixed; non-variable extra_theta
            # entries such as global_step pass through untouched.
            body_var_keys = {key for key, _ in self.body.vars.FlattenItems()}

            def _Mix(key, value):
                if value is None or key not in body_var_keys:
                    return value
                return tf.einsum('i,i...->...', expert_dist, value)

            mixed_values = [
                _Mix(key, value) for key, value in theta.body.FlattenItems()
            ]
            return self.body.FProp(theta.body.Pack(mixed_values), inputs,
                                   *args)
示例#4
0
  def _DataSourceToInputBatch(self):
    """The current input batch as a `.NestedMap` of input tensors.

    Builds the batch from the data source and packs it via `self._Pack`.
    Where the data source did not provide them, it fills in `weights` and
    `paddings` derived from `ids_indicator`. It then optionally pins static
    shapes and casts floating point tensors with `self.Cast`.

    Returns:
      A `.NestedMap` with `src` and `tgt` sub-maps; `ids_indicator` has been
      removed from both.
    """
    ret, _ = self._BuildDataSource()
    self._Pack(ret)

    # Derive missing weights/paddings for each side independently. A field the
    # data source provided for one side is then never overwritten just because
    # the other side lacked it.
    for side in (ret.src, ret.tgt):
      if 'weights' not in side:
        side.weights = side.ids_indicator
      if 'paddings' not in side:
        side.paddings = 1 - side.weights
      del side.ids_indicator

    if self.params.pad_to_max_seq_length:
      assert self.params.source_max_length
      # NOTE(review): target_max_length is not validated here. If it is None,
      # the target check below leaves the second dimension unconstrained.

      def _MakeShapeEnforcer(max_length):
        """Returns a fn pinning static shapes for one side of the batch.

        String tensors must be [batch]; all other tensors [batch, max_length].
        """

        def _EnsureShape(x):
          if x.dtype == tf.string:
            return tf.ensure_shape(x, [self._ScaledBatchSize()])
          return tf.ensure_shape(x, [self._ScaledBatchSize(), max_length])

        return _EnsureShape

      ret.src = ret.src.Transform(
          _MakeShapeEnforcer(self.params.source_max_length))
      ret.tgt = ret.tgt.Transform(
          _MakeShapeEnforcer(self.params.target_max_length))

    summary_utils.histogram('source_token_ids', ret.src.ids)
    summary_utils.histogram('target_token_ids', ret.tgt.ids)

    # Casts floating point tensors to fprop_dtype before returning.
    return ret.Transform(self.Cast)
示例#5
0
    def FProp(self, theta, input_batch):
        """Encodes a batch of token ids with the stacked RNN layers.

        Args:
          theta: A `.NestedMap` of this layer's weights and its children's.
          input_batch: A `.NestedMap` with rank-2 `ids` and same-shaped
            `paddings` (both checked at runtime), plus `segment_ids` when
            `p.packed_input` is set. These are transposed before use
            (presumably [batch, time] -> [time, batch]; confirm with caller).

        Returns:
          A `.NestedMap` with `encoded` (top output, optionally merged across
          layers and projected), `padding`, and `segment_id` (None unless
          `p.packed_input`).
        """
        p = self.params
        with tf.name_scope(p.name):
            ids = input_batch.ids
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
                py_utils.assert_shape_match(tf.shape(ids),
                                            tf.shape(input_batch.paddings)),
            ]
            inputs = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(ids))
            ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            segment_ids = None
            if p.packed_input:
                segment_ids = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)

            layer_outputs = []
            for i in range(p.num_lstm_layers):
                ys = self.rnn[i].FProp(theta.rnn[i], xs, ps,
                                       segment_id=segment_ids)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i < p.residual_start:
                    xs = ys
                else:
                    # Residual skip, followed by clipping.
                    xs = self.ApplyClipping(theta, xs + ys)
                layer_outputs.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   layer_outputs)

            # Project to encoder_out_dim when it differs from
            # 2 * lstm_cell_size.
            if p.encoder_out_dim != 2 * p.lstm_cell_size:
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            if segment_ids is not None:
                segment_ids = tf.squeeze(segment_ids, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=segment_ids)
示例#6
0
  def _Pack(self, batch):
    """Packs a given batch in place.

    Note that this may change the batch size.

    Adds `.segment_ids` and `.segment_pos` to both `batch.src` and
    `batch.tgt`. When `params.packing_factor` is falsy no packing happens, and
    segment ids come straight from `ids_indicator`. Otherwise
    `ops.pack_sequences` decides the packing and every field of `src`/`tgt` is
    repacked with `ops.apply_packing`.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
    p = self.params

    def _ActualLengths(ids_indicator):
      # Count of indicated positions along axis 1 of ids_indicator.
      return tf.math.reduce_sum(tf.cast(ids_indicator, tf.int32), axis=1)

    src_lengths = _ActualLengths(batch.src.ids_indicator)
    tgt_lengths = _ActualLengths(batch.tgt.ids_indicator)
    summary_utils.histogram('source_seq_lengths', src_lengths)
    summary_utils.histogram('target_seq_lengths', tgt_lengths)

    if not p.packing_factor:
      # No packing: segment ids/positions are derived from ids_indicator.
      for side in (batch.src, batch.tgt):
        side.segment_ids = side.ids_indicator
        side.segment_pos = _GetSegmentPos(side.ids_indicator)
      return

    packed = ops.pack_sequences(src_lengths, tgt_lengths,
                                self._ScaledBatchSize(), p.source_max_length,
                                p.target_max_length)
    (src_segment_ids, src_segment_pos, src_indices, tgt_segment_ids,
     tgt_segment_pos, tgt_indices) = packed

    # Lengths of the distinct input rows referenced by the packing.
    src_used_rows = tf.unique(tf.reshape(src_indices, [-1])).y
    tgt_used_rows = tf.unique(tf.reshape(tgt_indices, [-1])).y
    summary_utils.histogram('packed_source_seq_lengths',
                            tf.gather(src_lengths, src_used_rows, axis=0))
    summary_utils.histogram('packed_target_seq_lengths',
                            tf.gather(tgt_lengths, tgt_used_rows, axis=0))

    # Adding .paddings was deferred: only its complement .ids_indicator is
    # present, so numeric fields can all be packed with a pad value of 0.
    def _MakePacker(segment_ids, indices_in_input):
      """Returns a fn packing one tensor (a tab for strings, else 0 as pad)."""

      def _Apply(x):
        pad_value = '\t' if x.dtype == tf.string else 0
        return ops.apply_packing(x, pad_value, segment_ids, indices_in_input)

      return _Apply

    batch.src = batch.src.Transform(_MakePacker(src_segment_ids, src_indices))
    batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
    batch.src.segment_pos = src_segment_pos

    batch.tgt = batch.tgt.Transform(_MakePacker(tgt_segment_ids, tgt_indices))
    batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
    batch.tgt.segment_pos = tgt_segment_pos
  def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
    """Computes the normalization moments and, in training, updates state.

    In eval mode, the moving mean/variance variables are used directly. In
    training mode, batch moments are computed over `1.0 - paddings`, passed to
    `py_utils.UpdateBatchNormVars` and summarized as histograms. Either way
    the chosen moments go through a numerics check.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor. If None, `self._GetDefaultPaddings(inputs)` is used.

    Returns:
      Tuple of (mean, variance, beta, gamma). When
      `p.use_moving_avg_in_training` is set, beta and gamma are the constants
      0.0 and 1.0 instead of `theta.beta` / `theta.gamma`.
    """
    p = self.params
    if paddings is None:
      paddings = self._GetDefaultPaddings(inputs)
    # Runtime check that the last dimension of paddings is 1.
    inputs = py_utils.with_dependencies(
        [py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1])], inputs)
    with tf.name_scope(p.name):
      moving_mean = self.vars.moving_mean
      moving_variance = self.vars.moving_variance
      if self.do_eval:
        # Normalize with the stored moving statistics.
        norm_mean, norm_variance = moving_mean, moving_variance
      else:
        mean, variance = self._Moments(inputs, 1.0 - paddings,
                                       p.enable_cross_replica_sum_on_tpu)
        py_utils.UpdateBatchNormVars(moving_mean, mean, self._decay)
        py_utils.UpdateBatchNormVars(moving_variance, variance, self._decay)

        # Summaries for visualization.
        for fmt, value in (('%s_mean', mean),
                           ('%s_variance', variance),
                           ('%s_moving_mean', moving_mean),
                           ('%s_moving_variance', moving_variance)):
          summary_utils.histogram(fmt % p.name, tf.cast(value, tf.float32))
        for fmt, batch_stat, moving_stat in (
            ('%s_mean_diff', mean, moving_mean),
            ('%s_variance_diff', variance, moving_variance)):
          summary_utils.histogram(
              fmt % p.name, tf.cast(batch_stat - moving_stat, tf.float32))

        if p.use_moving_avg_in_training:
          # Normalize with the global statistics. The control dependencies on
          # mean and variance make sure moving_mean and variance will be
          # updated for every training step.
          norm_mean = py_utils.with_dependencies([mean], moving_mean)
          norm_variance = py_utils.with_dependencies([variance],
                                                     moving_variance)
        else:
          # Normalize with the batch statistics.
          norm_mean, norm_variance = mean, variance

      norm_mean = py_utils.CheckNumerics(
          norm_mean, 'mean of %s failed numeric check' % p.name)
      norm_variance = py_utils.CheckNumerics(
          norm_variance, 'variance of %s failed numeric check' % p.name)

      # With moving averages in training, theta.beta/theta.gamma are not used.
      if p.use_moving_avg_in_training:
        return norm_mean, norm_variance, 0.0, 1.0
      return norm_mean, norm_variance, theta.beta, theta.gamma