Exemple #1
0
    def QuantizeTensors(self, t_name, ts, eval_only=False):
        """Fake-quantizes a group of tensors against one shared range.

        In eval mode the range is read from the 'min'/'max' quantization state
        variables registered under `t_name`. Otherwise the range is computed
        from the current batch (it always contains 0.0), pushed into the
        accumulator associated with `t_name`, and a histogram summary is
        emitted for every input tensor.

        Args:
          t_name: Name used to look up the state variables / accumulator and
            to label the histogram summaries.
          ts: Sequence of tensors sharing one quantization range.
          eval_only: If True, the training path still records the range but
            returns the tensors unquantized.

        Returns:
          A list with one output tensor per element of `ts`.
        """
        params = self.params

        if self.do_eval:
            # Inference: reuse the memorized range. The state variables are
            # only looked up on this branch so that training mode does not
            # capture them unnecessarily.
            range_min = self._GetQStateVar(t_name, 'min')
            range_max = self._GetQStateVar(t_name, 'max')
            quantized = []
            for tensor in ts:
                quantized.append(
                    self._MaybeFakeQuant(tensor,
                                         range_min,
                                         range_max,
                                         num_bits=params.bits))
            return quantized

        # Training: derive the range from this batch. Seeding both bounds with
        # 0.0 guarantees the range always straddles a real zero point.
        acc_name = self._GetAccumulatorNameForTensor(t_name)
        lo, hi = 0.0, 0.0
        for tensor in ts:
            lo = tf.minimum(tf.reduce_min(tensor), lo)
            hi = tf.maximum(tf.reduce_max(tensor), hi)

        # Record the new state as [1.0, min, max].
        # NOTE(review): the leading 1.0 is presumably a per-batch count --
        # confirm against the accumulator implementation.
        self.accumulators[acc_name].Update(tf.stack([1.0, lo, hi]))

        outputs = []
        for index, tensor in enumerate(ts):
            if eval_only:
                # Quantization is deferred to eval time; the range has already
                # been recorded above, so pass the tensor through untouched.
                result = tensor
            else:
                # Early in training the range can be numerically unstable and
                # the fake quant may yield NaNs; wherever that happens, fall
                # back to the unquantized values.
                candidate = self._MaybeFakeQuant(tensor,
                                                 lo,
                                                 hi,
                                                 num_bits=params.bits)
                # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
                nan_mask = tf.math.is_nan(candidate)
                result = tf.where(nan_mask, tensor, candidate)
            outputs.append(result)
            summary_utils.histogram(
                '%s/%s_%d' % (self._qvars_scope.name, t_name, index), tensor)
        return outputs
    def FProp(self, theta, input_batch, state0=None):
        """Runs the embedding and the stacked RNN layers, threading state.

        Args:
          theta: Weights for this layer. Reads `.emb`, `.dropout`, `.rnn[i]`
            and, when `p.is_transparent` is set, `.transparent_merger`.
          input_batch: Object exposing `.ids` and `.paddings`; both are
            asserted to be rank 2 with identical shapes.
          state0: Optional previous state exposing a `.rnn` list with one
            entry per RNN layer. When falsy, `self.zero_state` supplies it.

        Returns:
          A `.NestedMap` with:
            encoded: the output of the last layer, or of the transparent
              merger when `p.is_transparent` is set.
            padding: the transposed paddings with the trailing dim squeezed.
            segment_id: always None for this encoder.
            state: a `.NestedMap` whose `.rnn` list holds each layer's new
              state.
        """
        p = self.params
        with tf.name_scope(p.name):
            # Validate the input shapes, then transpose ids to [t, b].
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            inputs = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Streaming state: start from zeros unless the caller passed one.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(inputs)[1])
            state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)

            # Stacked RNN layers.
            per_layer_outputs = []
            for idx in range(p.num_lstm_layers):
                layer_out, state1.rnn[idx] = self.rnn[idx].FProp(
                    theta.rnn[idx], xs, paddings, state0=state0.rnn[idx])
                layer_out = self.dropout.FProp(theta.dropout, layer_out)
                if idx < p.residual_start:
                    xs = layer_out
                else:
                    # Residual skip connection, re-clipped after the addition.
                    xs += layer_out
                    xs = self.ApplyClipping(theta, xs)
                per_layer_outputs.append(xs)
                summary_utils.histogram('layer_out_%s' % idx, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   per_layer_outputs)

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(paddings, [2]),
                                      segment_id=None,
                                      state=state1)
Exemple #3
0
  def FProp(self, theta, inputs, *args):
    """Runs the body layer with expert-weighted variables.

    Args:
      theta: Weights for this layer; `theta.body` holds the body layer's
        values.
      inputs: Forwarded to `_GetExpertDist` and to the body layer's FProp.
      *args: Extra positional arguments forwarded alongside `inputs`.

    Returns:
      Whatever `self.body.FProp` returns when called with the weighted theta.
    """
    p = self.params
    with tf.name_scope(p.name) as scope:
      expert_dist = self._GetExpertDist(theta, inputs, *args)
      if not self.do_eval:
        summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

      # Keys of the variables created in the body layer. Entries of
      # theta.body outside this set (non-variable extra_theta such as
      # global_step) are passed through untouched.
      body_var_keys = {key for key, _ in self.body.vars.FlattenItems()}

      def _Mix(key, value):
        """Contracts the leading axis of `value` against `expert_dist`."""
        if value is None or key not in body_var_keys:
          return value
        return tf.einsum('i,i...->...', expert_dist, value)

      mixed = [_Mix(key, value) for key, value in theta.body.FlattenItems()]
      weighted_theta = theta.body.Pack(mixed)
      return self.body.FProp(weighted_theta, inputs, *args)
Exemple #4
0
    def _DataSourceToInputBatch(self):
        """Builds the current input batch as a `.NestedMap` of input tensors.

        The batch from `_BuildDataSource` is packed via `_Pack`, given
        `weights` / `paddings` fields where missing, stripped of
        `ids_indicator`, optionally pinned to static shapes, and finally
        passed through `self.Cast`.

        Returns:
          The batch after `Transform(self.Cast)`.
        """
        batch, _ = self._BuildDataSource()
        self._Pack(batch)

        # Both sides are (re)filled together as soon as either one lacks the
        # field.
        if not ('weights' in batch.src and 'weights' in batch.tgt):
            batch.src.weights = batch.src.ids_indicator
            batch.tgt.weights = batch.tgt.ids_indicator
        if not ('paddings' in batch.src and 'paddings' in batch.tgt):
            batch.src.paddings = 1 - batch.src.weights
            batch.tgt.paddings = 1 - batch.tgt.weights
        del batch.src.ids_indicator
        del batch.tgt.ids_indicator

        if self.params.pad_to_max_seq_length:
            assert self.params.source_max_length

            def _PinSrcShape(x):
                # String tensors are pinned to [batch]; everything else to
                # [batch, source_max_length].
                dims = [self._ScaledBatchSize()]
                if x.dtype != tf.string:
                    dims.append(self.params.source_max_length)
                return tf.ensure_shape(x, dims)

            def _PinTgtShape(x):
                # Same as above, with target_max_length as the second dim.
                dims = [self._ScaledBatchSize()]
                if x.dtype != tf.string:
                    dims.append(self.params.target_max_length)
                return tf.ensure_shape(x, dims)

            batch.src = batch.src.Transform(_PinSrcShape)
            batch.tgt = batch.tgt.Transform(_PinTgtShape)

        summary_utils.histogram('source_token_ids', batch.src.ids)
        summary_utils.histogram('target_token_ids', batch.tgt.ids)

        # Casts floating point tensors to fprop_dtype before returning.
        return batch.Transform(self.Cast)
    def FProp(self, theta, input_batch):
        """Runs the embedding and the stacked RNN layers over a batch.

        Args:
          theta: Weights for this layer. Reads `.emb`, `.dropout`, `.rnn[i]`
            and, depending on params, `.transparent_merger` and
            `.final_proj`.
          input_batch: Object exposing `.ids` and `.paddings` (asserted to be
            rank 2 with identical shapes) and, when `p.packed_input` is set,
            `.segment_ids`.

        Returns:
          A `.NestedMap` with:
            encoded: the final layer output, after the optional transparent
              merge and the optional final projection.
            padding: the transposed paddings with the trailing dim squeezed.
            segment_id: the transposed segment ids, or None when
              `p.packed_input` is not set.
        """
        p = self.params
        with tf.name_scope(p.name):
            # Validate the input shapes, then transpose ids.
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            inputs = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Segment ids only exist for packed input.
            src_segment_id = None
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)

            # Stacked RNN layers.
            per_layer_outputs = []
            for idx in range(p.num_lstm_layers):
                layer_out = self.rnn[idx].FProp(theta.rnn[idx],
                                                xs,
                                                paddings,
                                                segment_id=src_segment_id)
                layer_out = self.dropout.FProp(theta.dropout, layer_out)
                if idx < p.residual_start:
                    xs = layer_out
                else:
                    # Residual skip connection, re-clipped after the addition.
                    xs += layer_out
                    xs = self.ApplyClipping(theta, xs)
                per_layer_outputs.append(xs)
                summary_utils.histogram('layer_out_%s' % idx, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   per_layer_outputs)

            needs_projection = p.lstm_cell_size * 2 != p.encoder_out_dim
            if needs_projection:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, paddings)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                # Drop the trailing dim that was added above.
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(paddings, [2]),
                                      segment_id=src_segment_id)
Exemple #6
0
    def _Pack(self, batch):
        """Packs `batch` in place and attaches segment ids and positions.

        Note that packing may change the batch size.

        After this call `batch.src` and `batch.tgt` both carry `.segment_ids`
        and `.segment_pos`. When `params.packing_factor` is falsy no packing
        happens and those fields are derived from `.ids_indicator` alone.

        Args:
          batch: a `.NestedMap` of input tensors with `src` and `tgt`
            sub-maps, each holding `ids_indicator`. It is modified in place.
        """

        def _SeqLengths(indicator):
            """Sums the indicator along axis 1 as int32."""
            return tf.math.reduce_sum(tf.cast(indicator, tf.int32), axis=1)

        src_lens = _SeqLengths(batch.src.ids_indicator)
        tgt_lens = _SeqLengths(batch.tgt.ids_indicator)
        summary_utils.histogram('source_seq_lengths', src_lens)
        summary_utils.histogram('target_seq_lengths', tgt_lens)

        if not self.params.packing_factor:
            # Supply segment_ids and segment_pos with no packing.
            for side in (batch.src, batch.tgt):
                side.segment_ids = side.ids_indicator
                side.segment_pos = _GetSegmentPos(side.ids_indicator)
            return

        packed = ops.pack_sequences(src_lens, tgt_lens,
                                    self._ScaledBatchSize(),
                                    self.params.source_max_length,
                                    self.params.target_max_length)
        (src_seg_ids, src_seg_pos, src_indices, tgt_seg_ids, tgt_seg_pos,
         tgt_indices) = packed

        def _UniqueFlat(indices):
            """Returns the distinct values of `indices` after flattening."""
            return tf.unique(tf.reshape(indices, [-1])).y

        uniq_src_indices = _UniqueFlat(src_indices)
        uniq_tgt_indices = _UniqueFlat(tgt_indices)
        summary_utils.histogram('packed_source_seq_lengths',
                                tf.gather(src_lens, uniq_src_indices, axis=0))
        summary_utils.histogram('packed_target_seq_lengths',
                                tf.gather(tgt_lens, uniq_tgt_indices, axis=0))

        # .paddings was deliberately not added yet: working only with its
        # complement .ids_indicator lets every field be packed with a padding
        # value of 0 (or '\t' for strings).
        def _MakePacker(segment_ids, indices_in_input):
            """Builds a per-tensor packing function for one side."""

            def _ApplyPacking(x):
                pad_value = '\t' if x.dtype == tf.string else 0
                return ops.apply_packing(x, pad_value, segment_ids,
                                         indices_in_input)

            return _ApplyPacking

        batch.src = batch.src.Transform(_MakePacker(src_seg_ids, src_indices))
        batch.src.segment_ids = tf.cast(src_seg_ids, tf.float32)
        batch.src.segment_pos = src_seg_pos

        batch.tgt = batch.tgt.Transform(_MakePacker(tgt_seg_ids, tgt_indices))
        batch.tgt.segment_ids = tf.cast(tgt_seg_ids, tf.float32)
        batch.tgt.segment_pos = tgt_seg_pos
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
        """Computes the normalization moments and updates the moving state.

        In eval mode the stored moving mean/variance are returned as-is.
        Otherwise batch moments are computed, the moving mean/variance are
        updated from them, and histogram summaries are emitted.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: The inputs tensor.  Shaped [..., dim].
          paddings: The paddings tensor.  Shaped [..., 1], with the same rank
            as the input tensor. Defaults to `_GetDefaultPaddings(inputs)`.

        Returns:
          Tuple of (mean, variance, beta, gamma). beta/gamma are the
          constants 0.0/1.0 when `p.use_moving_avg_in_training` is set, and
          `theta.beta`/`theta.gamma` otherwise.
        """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        # The last dimension of paddings must be exactly 1.
        last_dim_check = py_utils.assert_shape_match(
            [tf.shape(paddings)[-1]], [1])
        inputs = py_utils.with_dependencies([last_dim_check], inputs)

        def _Hist(name, tensor):
            """Emits a histogram summary of `tensor` cast to float32."""
            summary_utils.histogram(name, tf.cast(tensor, tf.float32))

        with tf.name_scope(p.name):
            if self.do_eval:
                # The mean and variance used for normalization.
                norm_mean = self.vars.moving_mean
                norm_variance = self.vars.moving_variance
            else:
                mean, variance = self._Moments(
                    inputs, 1.0 - paddings, p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(self.vars.moving_mean, mean,
                                             self._decay)
                py_utils.UpdateBatchNormVars(self.vars.moving_variance,
                                             variance, self._decay)

                # Add some summaries for visualization.
                _Hist('%s_mean' % p.name, mean)
                _Hist('%s_variance' % p.name, variance)
                _Hist('%s_moving_mean' % p.name, self.vars.moving_mean)
                _Hist('%s_moving_variance' % p.name,
                      self.vars.moving_variance)
                _Hist('%s_mean_diff' % p.name, mean - self.vars.moving_mean)
                _Hist('%s_variance_diff' % p.name,
                      variance - self.vars.moving_variance)

                if not p.use_moving_avg_in_training:
                    # Use the batch statistics for normalization.
                    norm_mean, norm_variance = mean, variance
                else:
                    # Use the global statistics for normalization. The
                    # control dependencies on mean and variance make sure
                    # moving_mean and moving_variance are updated on every
                    # training step.
                    norm_mean = py_utils.with_dependencies(
                        [mean], self.vars.moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], self.vars.moving_variance)

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            if p.use_moving_avg_in_training:
                return norm_mean, norm_variance, 0.0, 1.0
            return norm_mean, norm_variance, theta.beta, theta.gamma