Example #1
    def _PaddedMaxFn(inp):
      """Apply padded max using reduce_max with paddings replaced by neginf."""
      # Replace all padded features with -inf.
      neginf_padding = tf.where(
          inp.padding > 0, -np.inf * inp.padding, inp.padding)
      features = inp.features + neginf_padding[..., tf.newaxis]
      features = tf.reduce_max(features, axis=-2)

      # Replace features of all padded points by zeros. If a batch of points are
      # all padded, then reduce_min over the padding will be 1. We set the
      # features to be zero, so that we don't get any downstream issue with
      # NaNs. Note that inf * 0 = NaN.
      all_padded = tf.cast(tf.reduce_min(inp.padding, axis=-1), tf.bool)
      all_padded = tf.broadcast_to(all_padded[..., tf.newaxis],
                                   py_utils.GetShape(features))
      features = tf.where(all_padded, tf.zeros_like(features), features)
      return py_utils.CheckNumerics(features)
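A minimal standalone sketch of the same masking trick in plain TensorFlow (no Lingvo py_utils; toy shapes assumed): padded points get -inf added to their features, so reduce_max ignores them.

    import numpy as np
    import tensorflow as tf

    features = tf.constant([[[1.0, 5.0], [9.0, 2.0], [0.0, 0.0]]])  # [1, 3, 2]
    padding = tf.constant([[0.0, 0.0, 1.0]])  # third point is padded
    neginf_padding = tf.where(padding > 0, -np.inf * padding, padding)
    padded_max = tf.reduce_max(features + neginf_padding[..., tf.newaxis],
                               axis=-2)
    # padded_max == [[9.0, 5.0]]: the padded third point never wins the max.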
Example #2
        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

            We also bias the sampling towards real points by setting the
            distance to padded points negative until we are out of real points.
            """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)
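The distance masking above in isolation (a toy sketch): padded points get distance -1, so tf.argmax never prefers them over real points.

    import tensorflow as tf

    distance_to_selected = tf.constant([[0.5, 2.0, 0.1]])
    padding = tf.constant([[0.0, 0.0, 1.0]])  # third point is padded
    masked = tf.where(tf.equal(padding, 0.0), distance_to_selected,
                      -tf.ones_like(distance_to_selected))
    idx = tf.argmax(masked, axis=-1, output_type=tf.int32)
    # idx == [1]: the furthest real point.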
Example #3
def RandomPadOrTrimTo(tensor_list, num_points_out, seed=None):
    """Pads or Trims a list of Tensors on the major dimension.

    Slices if there are more points, or pads if not enough.

    In this implementation:
      Padded points are random duplications of real points.
      Sliced points are a random subset of the real points.

    Args:
      tensor_list: A list of tf.Tensor objects to pad or trim along first dim.
        All tensors are expected to have the same first dimension.
      num_points_out: An int for the requested number of points to trim/pad to.
      seed: Random seed to use for random generators.

    Returns:
      A tuple of output_tensors and a padding indicator.

      - output_tensors: A list of padded or trimmed versions of our tensor_list
        input tensors, all with the same first dimension.
      - padding: A tf.float32 tf.Tensor of shape [num_points_out] with 0 if the
        point is real, 1 if it is padded.
    """
    actual_num = tf.shape(tensor_list[0])[0]
    point_idx = tf.range(num_points_out, dtype=tf.int32)
    padding_tensor = tf.where(point_idx < actual_num,
                              tf.zeros([num_points_out], dtype=tf.float32),
                              tf.ones([num_points_out], dtype=tf.float32))

    def _Slicing():
        # Choose a random set of indices.
        indices = tf.range(actual_num)
        indices = tf.random_shuffle(indices, seed=seed)[:num_points_out]
        return [tf.gather(t, indices, axis=0) for t in tensor_list]

    def _Padding():
        indices = tf.random_uniform([num_points_out - actual_num],
                                    minval=0,
                                    maxval=actual_num,
                                    dtype=tf.int32,
                                    seed=seed)
        padded = []
        for t in tensor_list:
            padded.append(tf.concat([t, tf.gather(t, indices, axis=0)],
                                    axis=0))
        return padded

    def _PadZeros():
        padded = []
        for t in tensor_list:
            shape = tf.concat([[num_points_out], tf.shape(t)[1:]], axis=0)
            padded.append(tf.zeros(shape=shape, dtype=t.dtype))
        return padded

    data = tf.cond(
        actual_num > num_points_out, _Slicing,
        lambda: tf.cond(tf.equal(actual_num, 0), _PadZeros, _Padding))
    return (data, padding_tensor)
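Hypothetical usage of RandomPadOrTrimTo, assuming the TF1-style tf.random_shuffle/tf.random_uniform aliases it calls are available (as in Lingvo's tf compat layer):

    points = tf.random.uniform([3, 4])
    data, padding = RandomPadOrTrimTo([points], num_points_out=5, seed=0)
    padded = data[0]
    # padded has shape [5, 4]: the 3 real points plus 2 random duplicates.
    # padding == [0., 0., 0., 1., 1.]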
Example #4
 def Clipped():
   clip_ratio = state[0]
   min_value, max_value = self._GetCurrentMinMax(state, start_cap, end_cap,
                                                 bits)
   min_value = tf.stop_gradient(min_value)
   max_value = tf.stop_gradient(max_value)
   return tf.where(clip_ratio >= 0.0,
                   (lambda: tf.clip_by_value(x, min_value, max_value))(),
                   (lambda: x)())
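The immediately-invoked lambdas above reduce to passing the tensors directly, since tf.where evaluates both operands either way. A toy sketch of the equivalent form, assuming TF2 tf.where broadcasting for the scalar condition:

    import tensorflow as tf

    x = tf.constant([-2.0, 0.5, 3.0])
    clip_ratio = tf.constant(0.1)
    clipped = tf.where(clip_ratio >= 0.0,
                       tf.clip_by_value(x, -1.0, 1.0),
                       x)
    # clipped == [-1.0, 0.5, 1.0]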
Example #5
    def Value(self):
        p = self.params
        with tf.name_scope(p.name):
            steps = self._best_step
            best_step = steps[0]
            last_step = steps[1]

            ref_step = tf.maximum(self.theta.ref_step, best_step)
            f = self.theta.cur_factor

            # Decay if no improvement within window.
            new_factor = tf.where(last_step - ref_step < p.window, f,
                                  tf.maximum(p.min_factor, f * p.decay))
            # Update ref_step if we decayed.
            new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
            update_step = tf.assign(self.vars.ref_step, new_step)
            with tf.control_dependencies([update_step]):
                return tf.assign(self.vars.cur_factor, new_factor)
Example #6
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
Example #7
    def CpuEmbLookup(self, ids_map, partition_strategy):
        """CPU evaluation embedding lookup.

    Args:
      ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
        -1 is used as a padding id.
      partition_strategy: See TPUEmbeddingLayer partition_strategy param.

    Returns:
      An activations dict of string -> float32 Tensor.
      For non-sequence embeddings: [batch, 1, embedding_dim]
      For sequence embeddings: [batch, max_sequence_length, embedding_dim]

    """
        p = self.params
        rets = py_utils.NestedMap()
        if self.max_sequence_length > 0:
            # "Sequence embedding", no combiner case
            for k, ids in ids_map.items():
                embs = tf.nn.embedding_lookup(
                    self.theta.wm,
                    tf.reshape(ids, [-1]),
                    partition_strategy=partition_strategy)
                out_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
                rets[k] = tf.reshape(embs, out_shape)
        else:
            # Non-"Sequence embedding", combiner case
            for k, ids in ids_map.items():
                # Dense to sparse.
                dense_shape = tf.shape(ids, out_type=tf.int64)
                sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)),
                                         tf.int64)
                embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices),
                                            tf.int64)
                sparse_ids = tf.SparseTensor(indices=sample_indices,
                                             values=embedding_indices,
                                             dense_shape=dense_shape)
                # [?, embedding_dim]
                # For tf.nn.embedding_lookup_sparse, output.dim0 might be different from
                # sparse_ids.dense_shape.dim0.
                # In fact, the '?' is the smallest span starting from the index=0 that
                # covers all the results.
                embs = tf.nn.embedding_lookup_sparse(
                    self.theta.wm,
                    sparse_ids,
                    None,  # sp_weights
                    combiner=p.combiner,
                    partition_strategy=partition_strategy)
                batch_size = dense_shape[0]
                # Explicitly pad results to maintain dim0=batch.
                dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
                embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
                # [batch, 1, embedding_dim]
                embs = py_utils.HasShape(embs, [batch_size], ndims=1)
                rets[k] = tf.expand_dims(embs, 1)
        return rets
Example #8
 def FProp(self, theta, current_step):
   """Returns the current learning rate decay."""
   p = self.params
   current_step = tf.cast(current_step, tf.float32)
   warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
   if p.decay_end is not None:
     current_step = tf.where(current_step < p.decay_end, current_step,
                             tf.cast(p.decay_end, tf.float32))
   return p.model_dim**-0.5 * tf.minimum(
       (current_step + 1) * warmup_steps**-1.5, (current_step + 1)**-0.5)
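A quick numeric check of this schedule (a sketch, assuming model_dim=512 and an effective warmup_steps of 4000, as in the original Transformer setup):

    import tensorflow as tf

    def NoamLr(step, model_dim=512.0, warmup_steps=4000.0):
      step = tf.cast(step, tf.float32)
      return model_dim**-0.5 * tf.minimum(
          (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)

    # NoamLr(0) ~= 1.7e-7, rising linearly to a peak of ~7.0e-4 at step 3999,
    # then decaying as (step + 1)**-0.5.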
Example #9
 def grad_fn(d_outputs):
     with tf.name_scope("entmax_grad"):
         gppr = tf.where(p_m > 0, tf.math.pow(p_m, 2.0 - alpha),
                         tf.zeros_like(p_m))
         d_inputs = d_outputs * gppr
         q = tf.math.reduce_sum(d_inputs, axis) / tf.math.reduce_sum(
             gppr, axis)
         q = tf.expand_dims(q, axis)
         d_inputs -= q * gppr
         return d_inputs, d_inputs
Example #10
    def NMSIndices(self,
                   bboxes,
                   scores,
                   max_output_size,
                   nms_iou_threshold=0.3,
                   score_threshold=0.01):
        """Apply NMS to a series of 3d bounding boxes in 7-DOF format.

    Args:
      bboxes: A [num_boxes, 7] floating point Tensor of bounding boxes in [x, y,
        z, dx, dy, dz, phi] format.
      scores: A [num_boxes] floating point Tensor containing box
        scores.
      max_output_size: Maximum number of boxes to predict per input.
      nms_iou_threshold: IoU threshold to use when determining whether two boxes
        overlap for purposes of suppression.
      score_threshold: The score threshold passed to NMS that allows NMS to
        quickly ignore irrelevant boxes.

    Returns:
      The NMS indices and the mask of the padded indices.
    """
        bboxes = py_utils.HasShape(bboxes, [-1, 7])

        # Extract x, y, w, h, then convert to extrema.
        #
        # Note that we drop the rotation angle because we don't have an NMS
        # operation that takes rotation into account.
        bboxes_2d = tf.stack(
            [bboxes[:, 0], bboxes[:, 1], bboxes[:, 3], bboxes[:, 4]], axis=-1)
        bboxes_extrema = geometry.XYWHToBBoxes(bboxes_2d)

        # Compute NMS with padding; we use the padded version so this function can
        # be used in a map_fn.  This function returns the scalar number of boxes
        # for each example.
        #
        # We use an IoU threshold of 0.3 since our anchor boxes have rotations
        # that make the default IoU threshold of 0.5 possibly too high.
        nms_index_padded, num_valid = tf.image.non_max_suppression_padded(
            bboxes_extrema,
            scores,
            iou_threshold=nms_iou_threshold,
            max_output_size=max_output_size,
            score_threshold=score_threshold,
            pad_to_max_output_size=True)

        # Return the mask of valid indices instead of just a scalar number.
        mask = tf.concat(
            [tf.ones([num_valid]),
             tf.zeros([max_output_size - num_valid])],
            axis=0)

        nms_index_padded = tf.where(mask > 0, nms_index_padded,
                                    tf.zeros_like(nms_index_padded))
        return nms_index_padded, mask
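The ones/zeros concat above builds a validity mask; an equivalent one-liner (a sketch, relying on tf.sequence_mask accepting a scalar length) would be:

    mask = tf.cast(tf.sequence_mask(num_valid, max_output_size), tf.float32)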
Example #11
    def _ParseRecord(self, record):
        """Reads and parses a single record."""
        p = self.params
        name_to_features = {
            'input_ids':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
            'input_mask':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
            'masked_lm_positions':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
            'masked_lm_ids':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
            'masked_lm_weights':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.float32),
        }
        example = tf.io.parse_single_example(record, name_to_features)
        mask_length = tf.cast(tf.reduce_sum(example['masked_lm_weights']),
                              dtype=tf.int32)
        masked_lm_positions = tf.slice(example['masked_lm_positions'], [0],
                                       [mask_length])
        masked_lm_ids = tf.cast(tf.slice(example['masked_lm_ids'], [0],
                                         [mask_length]),
                                dtype=tf.int32)
        ret = py_utils.NestedMap()
        ret.masked_ids = tf.cast(example['input_ids'], dtype=tf.int32)
        # Get back non-masked, original ids.
        ret.ids = tf.tensor_scatter_nd_update(tensor=ret.masked_ids,
                                              indices=tf.reshape(
                                                  masked_lm_positions,
                                                  [-1, 1]),
                                              updates=masked_lm_ids)
        ret.masked_pos = tf.tensor_scatter_nd_update(
            tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
            indices=tf.reshape(masked_lm_positions, [-1, 1]),
            updates=tf.ones_like(masked_lm_ids, dtype=tf.float32))
        ret.segment_ids = tf.cast(example['input_mask'], dtype=tf.float32)

        first_eos_idx = tf.where(tf.math.equal(ret.ids, p.eos_token_id))[0][0]

        def _RemoveFirstEos(x):
            # We remove the element at position `first_eos_idx`, and pad with 0
            # to keep length unchanged.
            zero = tf.constant(0, shape=(1, ), dtype=x.dtype)
            return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero],
                             axis=0)

        ret = ret.Transform(_RemoveFirstEos)
        ret.paddings = 1.0 - ret.segment_ids
        pos = tf.cast(tf.range(p.max_sequence_length), dtype=tf.float32)
        ret.segment_pos = tf.cast(ret.segment_ids * pos, dtype=tf.int32)

        if p.remove_mask:
            del ret.masked_pos
            del ret.masked_ids
        return ret
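A toy check of the tensor_scatter_nd_update pattern used above to restore the original ids at the masked positions (103 stands in for a hypothetical [MASK] id):

    import tensorflow as tf

    masked_ids = tf.constant([7, 103, 9, 103], dtype=tf.int32)
    positions = tf.constant([[1], [3]])
    originals = tf.constant([42, 8], dtype=tf.int32)
    ids = tf.tensor_scatter_nd_update(masked_ids, positions, originals)
    # ids == [7, 42, 9, 8]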
Example #12
 def Value(self):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(py_utils.GetGlobalStep(), tf.float32)
     warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
     if p.decay_end is not None:
         current_step = tf.where(current_step < p.decay_end, current_step,
                                 tf.cast(p.decay_end, tf.float32))
     return p.model_dim**-0.5 * tf.minimum(
         (current_step + 1) * warmup_steps**(p.decay_factor - 1.0),
         (current_step + 1)**tf.cast(p.decay_factor, tf.float32))
Example #13
 def _Lookup(ids):
   # Dense to sparse.
   dense_shape = tf.shape(ids, out_type=tf.int64)
   sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
   embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
   # [?, embedding_dim]
   sparse_ids = tf.SparseTensor(
       indices=sample_indices,
       values=embedding_indices,
       dense_shape=dense_shape)
   return self._CombinerEmbLookup(sparse_ids, partition_strategy)
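The same dense-to-sparse conversion on toy ids, with -1 as the padding id (a standalone sketch, no embedding table involved):

    import tensorflow as tf

    ids = tf.constant([[3, 1, -1], [-1, -1, -1]])
    sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
    values = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
    sparse_ids = tf.SparseTensor(indices=sample_indices, values=values,
                                 dense_shape=tf.shape(ids, out_type=tf.int64))
    # sparse_ids keeps only positions (0, 0) and (0, 1); row 1 is empty, which
    # is why the combiner lookup result may need re-padding back to
    # dim0 == batch (see Example #7).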
Example #14
        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)
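The end-of-chunk rewrite above in isolation (a toy sketch; assume id 2 is the EOC token and id 1 the EOS token):

    import tensorflow as tf

    ids = tf.constant([2, 5, 2])
    is_last_chunk = tf.constant([True, False, False])
    ids = tf.where(tf.math.logical_and(is_last_chunk, tf.equal(ids, 2)),
                   tf.fill(tf.shape(ids), 1), ids)
    # ids == [1, 5, 2]: only an EOC sampled on a last chunk becomes EOS.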
Example #15
 def FProp(self, theta, current_step):
   """Returns the current learning rate decay."""
   params = self.params
   warmup_steps = tf.to_float(params.decay_start * params.worker_replicas)
   current_step = tf.to_float(current_step)
   if params.decay_end is not None:
     current_step = tf.where(current_step < params.decay_end, current_step,
                             tf.to_float(params.decay_end))
   peak_learning_rate = (warmup_steps**-0.5)
   return (params.model_dim**-0.5) * tf.minimum(
       tf.minimum((current_step + 1),
                  (current_step + 1)**-0.5), peak_learning_rate)
Example #16
    def Polynomial(x):
      """Polynomial function of x."""
      p = self.params
      x0, y0 = p.start
      x1, y1 = p.limit

      assert x0 < x1, '%s must be < %s' % (x0, x1)

      x0 = tf.cast(x0, dtype=x.dtype)
      x1 = tf.cast(x1, dtype=x.dtype)
      y0 = tf.cast(y0, dtype=x.dtype)
      y1 = tf.cast(y1, dtype=x.dtype)

      ratio = (x - x0) / (x1 - x0)
      if p.origin == 'start':
        f_x = ratio**p.power
      elif p.origin == 'limit':
        f_x = 1 - (1 - ratio)**p.power
      else:
        raise ValueError('Invalid parameter origin: %s' % p.origin)
      y = y0 + f_x * (y1 - y0)
      return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
Example #17
 def Value(self):
   """Returns the current learning rate decay."""
   params = self.params
   warmup_steps = tf.cast(params.decay_start * params.worker_replicas,
                          tf.float32)
   current_step = tf.cast(py_utils.GetGlobalStep(), tf.float32)
   if params.decay_end is not None:
     current_step = tf.where(current_step < params.decay_end, current_step,
                             tf.cast(params.decay_end, tf.float32))
   peak_learning_rate = (warmup_steps**-0.5)
   return (params.model_dim**-0.5) * tf.minimum(
       tf.minimum((current_step + 1),
                  (current_step + 1)**-0.5), peak_learning_rate)
Example #18
  def CombineStates(self, state0, state1, switch_cond):
    """Combines states based on a switch conditional.

    Args:
      state0: a NestedMap of states to use for batch elements where switch_cond
        is true.
      state1: a NestedMap of states to use for batch elements where switch_cond
        is false.
      switch_cond: bool tensor of shape [batch] on which to switch.

    Returns:
      state_combined: a NestedMap of states.
    """
    updated_rnn_states = []
    for i in range(self.params.rnns.num_layers):
      updated_rnn_states.append(
          py_utils.NestedMap({
              'c': tf.where(switch_cond, state0.rnn[i].c, state1.rnn[i].c),
              'm': tf.where(switch_cond, state0.rnn[i].m, state1.rnn[i].m)
          }))
    combined_state = py_utils.NestedMap({'rnn': updated_rnn_states})
    return combined_state
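The per-example switch in isolation. TF1's tf.where accepts a [batch] vector condition against [batch, dims] operands, which is what the code above relies on; under TF2 broadcasting semantics the condition needs an explicit trailing axis (a sketch):

    import tensorflow as tf

    c0 = tf.constant([[1., 1.], [2., 2.]])
    c1 = tf.constant([[9., 9.], [8., 8.]])
    switch_cond = tf.constant([True, False])
    combined = tf.where(switch_cond[:, tf.newaxis], c0, c1)
    # combined == [[1., 1.], [8., 8.]]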
Example #19
 def PostTrainingStepUpdate(self, global_step):
     """Updates moving_mean, moving_variance after each training step."""
     p = self.params
     # Get sufficient stats that accumulates over microbatches.
     counts = self.accumulators.counts.GetValue()
     mean_ss = self.accumulators.mean_ss.GetValue()
     variance_ss = self.accumulators.variance_ss.GetValue()
     # Compute batch mean and batch variance from sufficient stats
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
     # Update moving_mean, moving_variance from  batch mean and batch variance.
     with tf.name_scope(p.name) as scope:
         with tf.colocate_with(self.vars.moving_mean):
             mean_update = tf.assign_sub(
                 self.vars.moving_mean,
                 tf.where(tf.greater(counts, 0.5),
                          (self.vars.moving_mean - tf.cast(mean, p.dtype)) *
                          decay, tf.zeros_like(self.vars.moving_mean)),
                 name='moving_mean_update')
         with tf.colocate_with(self.vars.moving_variance):
             var_update = tf.assign_sub(
                 self.vars.moving_variance,
                 tf.where(tf.greater(counts, 0.5),
                          (self.vars.moving_variance -
                           tf.cast(variance, p.dtype)) * decay,
                          tf.zeros_like(self.vars.moving_variance)),
                 name='moving_variance_update')
         py_utils.CheckNumerics(
             self.vars.moving_mean,
             'moving mean of {} failed numeric check'.format(scope))
         py_utils.CheckNumerics(
             self.vars.moving_variance,
             'moving variance of {} failed numeric check'.format(scope))
     self.accumulators.counts.Reset()
     self.accumulators.mean_ss.Reset()
     self.accumulators.variance_ss.Reset()
     return tf.group(mean_update, var_update)
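The guarded EMA update in isolation (a toy sketch): when no real data was accumulated, the tf.where yields zeros and assign_sub leaves the moving statistic untouched.

    import tensorflow as tf

    moving_mean = tf.Variable([1.0, 1.0])
    batch_mean = tf.constant([3.0, 5.0])
    counts = tf.constant(1.0)
    decay = 0.01  # stands in for (1.0 - p.decay) above
    delta = tf.where(tf.greater(counts, 0.5),
                     (moving_mean - batch_mean) * decay,
                     tf.zeros_like(moving_mean))
    moving_mean.assign_sub(delta)
    # moving_mean -> [1.02, 1.04]: nudged toward the batch mean; with
    # counts == 0 the delta is all zeros and nothing changes.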
Example #20
    def QuantizeTensors(self, t_name, ts, eval_only=False):
        p = self.params
        # Always straddle a real zero point.
        if self.do_eval:
            # At eval/inference time, use the memorized range.
            # Important: Don't capture these variables in training mode so as to
            # avoid extra/unnecessary captures.
            min_var = self._GetQStateVar(t_name, 'min')
            max_var = self._GetQStateVar(t_name, 'max')
            return [
                self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits)
                for t in ts
            ]
        else:
            # At training time, use the batch calculated min/max.
            accumulator_name = self._GetAccumulatorNameForTensor(t_name)
            # Calculate min/max for all tensors.
            batch_min = 0.0
            batch_max = 0.0
            for t in ts:
                batch_min = tf.minimum(tf.reduce_min(t), batch_min)
                batch_max = tf.maximum(tf.reduce_max(t), batch_max)

            # New state.
            state1 = tf.stack([1.0, batch_min, batch_max])
            self.accumulators[accumulator_name].Update(state1)

            # Results.
            ts_out = []
            for i, t in enumerate(ts):
                if eval_only:
                    # If only quantizing at eval time, still record ranges as above
                    # but don't quantize.
                    quant_t = t
                else:
                    # If quantizing during training, skip quantization if it produces
                    # NANs. Sometimes early in the training process, things are unstable
                    # and ranges can produce numerical instability that makes it
                    # impossible to perform a fake_quant.
                    quant_t = self._MaybeFakeQuant(t,
                                                   batch_min,
                                                   batch_max,
                                                   num_bits=p.bits)
                    # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
                    quant_t_has_nans = tf.math.is_nan(quant_t)
                    quant_t = tf.where(quant_t_has_nans, t, quant_t)
                ts_out.append(quant_t)
                summary_utils.histogram(
                    '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t)
            return ts_out
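The NaN guard above in isolation (a toy sketch): wherever fake quantization produced NaN, keep the unquantized value.

    import tensorflow as tf

    t = tf.constant([1.0, 2.0, 3.0])
    quant_t = tf.constant([1.0, float('nan'), 3.0])
    safe = tf.where(tf.math.is_nan(quant_t), t, quant_t)
    # safe == [1.0, 2.0, 3.0]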
Example #21
    def FProp(self, theta, inputs, paddings):
        p = self.params
        fns = self.fns

        # It is most important that weights and top-level activations
        # be tagged for quantization:
        #   - Weights use the self.QWeight() decorator.
        #   - Inputs/activations are decorated with self.QTensor(). In general,
        #     the provided name should match a call to self.TrackQTensor in the
        #     constructor. This creates a tensor that is individually accounted
        #     for.
        w = fns.qweight(theta.w)

        # TODO(shivaniagrawal): change this to ToAqtWeight and FromAqtWeight.
        w = self.ToAqtWeight('aqt_w',
                             w,
                             feature_axis=-1,
                             expected_scale_shape=(1, p.output_dim))

        inputs = self.QTensor('inputs', inputs)

        # Note the use of the qmatmul from the function library. This will
        # automatically track the output against the qtensor 'transformed'.
        out = fns.qmatmul(tf.reshape(inputs, [-1, p.input_dim]),
                          w,
                          qt='transformed')
        out = self.FromAqtWeight('aqt_w', out, feature_axis=-1)

        out = tf.reshape(out,
                         tf.concat([tf.shape(inputs)[:-1], [p.output_dim]], 0))

        # Decorate outputs of simple activation functions with their corresponding
        # range decorator. This will ensure that the result does not exceed the
        # precision of the underlying representation.
        out = fns.qtanh(out)

        # Perform padding manipulation via booleans instead of:
        #   out *= 1.0 - paddings
        # Because the paddings can exist in entirely different numeric ranges than
        # the tensor they are being applied to, it is best to not perform
        # arithmetic directly between them. Instead, broadcast them to the needed
        # size (if different) and perform an exact mask with tf.where.
        # For added numeric range protection, the QRPadding decorator ensures
        # the correct range. This is mostly needed for cases where padding is
        # dynamic at inference time.
        paddings = self.QRPadding(paddings)
        paddings *= tf.ones_like(out)  # Broadcast to 'out' size.
        out = tf.where(paddings > 0.0, tf.zeros_like(out), out)

        return out
Example #22
  def _MaybeFakeQuant(self, inputs, min_v, max_v, num_bits):
    p = self.params

    def Apply():
      return tf.quantization.fake_quant_with_min_max_vars(
          inputs, min_v, max_v, num_bits=num_bits)

    if p.delay_start_steps != 0 and not self.do_eval:
      if p.delay_start_steps == -1:
        return inputs
      return tf.where(self.theta.global_step >= p.delay_start_steps, Apply(),
                      inputs)
    else:
      return Apply()
Example #23
    def forward(inputs, alpha):
        with tf.name_scope("entmax_loss"):
            alpha_shape = inputs.get_shape().as_list()

            alpha_shape[axis] = 1
            alpha = tf.fill(alpha_shape, alpha)
            alpha = tf.cast(alpha, dtype=inputs.dtype)

            d = inputs.get_shape().as_list()[axis]
            alpha_m1 = alpha - 1.0

            inputs = inputs * alpha_m1

            max_val = tf.math.reduce_max(inputs, axis=axis, keepdims=True)
            tau_lo = max_val - tf.ones(alpha.get_shape().as_list(),
                                       dtype=inputs.dtype)
            tau_hi = max_val - tf.math.pow(
                tf.cast((1.0 / d), dtype=inputs.dtype), alpha_m1)

            f_lo = tf.math.reduce_sum(
                _calculate_probability(tf.math.subtract(inputs, tau_lo),
                                       alpha), axis) - 1.0

            dm = tau_hi - tau_lo

            for _ in range(n_iter):
                dm /= 2
                tau_m = tau_lo + dm
                p_m = _calculate_probability(inputs - tau_m, alpha)
                f_m = tf.math.reduce_sum(p_m, axis) - 1.0

                mask = tf.expand_dims(tf.math.greater(f_m * f_lo, 0), axis)
                tau_lo = tf.where(mask, tau_m, tau_lo)

            if ensure_sum_one:
                p_m /= tf.expand_dims(tf.math.reduce_sum(p_m, axis), axis)

        def grad_fn(d_outputs):
            with tf.name_scope("entmax_grad"):
                gppr = tf.where(p_m > 0, tf.math.pow(p_m, 2.0 - alpha),
                                tf.zeros_like(p_m))
                d_inputs = d_outputs * gppr
                q = tf.math.reduce_sum(d_inputs, axis) / tf.math.reduce_sum(
                    gppr, axis)
                q = tf.expand_dims(q, axis)
                d_inputs -= q * gppr
                return d_inputs, d_inputs

        return p_m, grad_fn
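A standalone numeric check of the bisection for alpha = 1.5 (a sketch; _calculate_probability(z, alpha) is taken to be max(z, 0) ** (1 / (alpha - 1)), the entmax mapping the code above assumes):

    import tensorflow as tf

    z = tf.constant([[2.0, 1.0, -1.0]])
    alpha = 1.5
    zs = z * (alpha - 1.0)
    prob = lambda tau: tf.maximum(zs - tau, 0.0)**(1.0 / (alpha - 1.0))
    tau_lo = tf.reduce_max(zs, -1, keepdims=True) - 1.0
    tau_hi = tf.reduce_max(zs, -1, keepdims=True) - (1.0 / 3.0)**(alpha - 1.0)
    f_lo = tf.reduce_sum(prob(tau_lo), -1) - 1.0
    dm = tau_hi - tau_lo
    for _ in range(30):
      dm /= 2
      tau_m = tau_lo + dm
      f_m = tf.reduce_sum(prob(tau_m), -1) - 1.0
      tau_lo = tf.where((f_m * f_lo > 0)[..., tf.newaxis], tau_m, tau_lo)
    # prob(tau_lo + dm) ~= [[0.83, 0.17, 0.0]], and it sums to 1.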
Example #24
    def Value(self):
        p = self.params
        x = tf.cast(py_utils.GetGlobalStep(), dtype=p.dtype)
        x0, y0 = p.start
        x1, y1 = p.limit

        if x0 >= x1:
            raise ValueError(f'{x0} must be < {x1}')

        x0 = tf.cast(x0, dtype=x.dtype)
        x1 = tf.cast(x1, dtype=x.dtype)
        y0 = tf.cast(y0, dtype=x.dtype)
        y1 = tf.cast(y1, dtype=x.dtype)

        ratio = (x - x0) / (x1 - x0)
        if p.origin == 'start':
            f_x = ratio**p.power
        elif p.origin == 'limit':
            f_x = 1 - (1 - ratio)**p.power
        else:
            raise ValueError('Invalid parameter origin: %s' % p.origin)

        y = y0 + f_x * (y1 - y0)
        return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
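A toy evaluation of the polynomial schedule (a sketch, assuming start=(0, 0.1), limit=(100, 0.0) and power=1, i.e. a linear ramp-down):

    import tensorflow as tf

    x0, y0, x1, y1, power = 0.0, 0.1, 100.0, 0.0, 1.0
    x = tf.constant(50.0)
    ratio = (x - x0) / (x1 - x0)
    y = y0 + ratio**power * (y1 - y0)
    value = tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
    # value == 0.05; outside [x0, x1] the nested tf.where clamps to y0 or y1.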
Example #25
  def _RecordTensor(self, t_name):
    p = self.params
    if self.do_eval:
      return []

    accumulator_name = self._GetAccumulatorNameForTensor(t_name)
    accumulator = self.accumulators[accumulator_name]
    min_var = self._GetQStateVar(t_name, 'min')
    max_var = self._GetQStateVar(t_name, 'max')

    # Unpack state tensor.
    current_value = accumulator.GetValue()
    count = current_value[0]
    min_value = current_value[1]
    max_value = current_value[2]
    accumulator.Reset()

    def Ema(variable, value):
      return (1.0 - p.ema_decay) * (variable - value)

    # Note that small floating point issues can cause ranges that naturally
    # begin or end at zero to move slightly past, causing hard failures
    # downstream (checks that all ranges straddle zero). We therefore repeat
    # the straddling constraint here.
    return [
        tf.assign(
            min_var,
            tf.minimum(
                0.,
                min_var - tf.where(count > 0., Ema(min_var, min_value), 0.))),
        tf.assign(
            max_var,
            tf.maximum(
                0.,
                max_var - tf.where(count > 0., Ema(max_var, max_value), 0.))),
    ]
Example #26
        def _GetRandomRealPoint():
            """Select the first point.

            For the first point, we want any random real (non-padded) point,
            so we create a random value per point, and then set all padded
            ones to some large value (more than the maxval). We then take the
            min per batch element to get the first point.
            """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)
Example #27
def FillPaddingPos(ids: tf.Tensor, id_len: tf.Tensor,
                   padding_value: int) -> tf.Tensor:
    """Given a batch of sequences, fills the padding pos with `padding_value`.

    Args:
      ids: a [B, max_len] int tensor.
      id_len: a [B, ] int tensor.
      padding_value: an int.

    Returns:
      new_ids: new ids with the property:
        - new_ids[b, :id_len[b]] = ids[b, :id_len[b]]
        - new_ids[b, id_len[b]:] = padding_value
    """
    mask = py_utils.SequencePaddings(id_len, maxlen=tf.shape(ids)[1])
    mask = tf.cast(mask, dtype=tf.bool)
    new_ids = tf.where(mask, tf.fill(tf.shape(ids), padding_value), ids)
    return new_ids
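Hypothetical usage of FillPaddingPos (assuming Lingvo's py_utils.SequencePaddings, which returns 1.0 at padded positions):

    ids = tf.constant([[5, 6, 7], [8, 9, 10]], dtype=tf.int32)
    id_len = tf.constant([2, 3])
    new_ids = FillPaddingPos(ids, id_len, padding_value=0)
    # new_ids == [[5, 6, 0], [8, 9, 10]]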
Example #28
 def QuantizeWeight(self, w):
   p = self.params
   w_min = tf.reduce_min(w)
   w_max = tf.reduce_max(w)
   # NOTE: We force a small, non-zero range because otherwise, zero weights
   # can cause downstream inference engines to blow up.
   w_min = tf.minimum(w_min, -p.quantize_weight_epsilon)
   w_max = tf.maximum(w_max, p.quantize_weight_epsilon)
   quant_w = self._MaybeFakeQuant(w, w_min, w_max, num_bits=p.bits)
   if self.do_eval:
     return quant_w
   else:
     # If quantizing during training, skip quantization if it produces
     # NANs. Sometimes early in the training process, things are unstable
     # and ranges can produce numerical instability that makes it
     # impossible to perform a fake_quant.
     quant_w_has_nans = tf.math.is_nan(quant_w)
     return tf.where(quant_w_has_nans, w, quant_w)
Example #29
    def _PaddedMeanFn(inp):
      """Apply padded mean using reduce_sum and dividing by # real points."""
      # Replace all padded features with 0 by masking the padded features out.
      mask = 1 - inp.padding
      features = inp.features * mask[..., tf.newaxis]
      features = tf.reduce_sum(features, axis=-2)
      num_real_points = tf.reduce_sum(mask, axis=-1, keep_dims=True)
      # Prevent the divisor of our padded mean from ever being 0, so that
      # the gradient flowing back through this op doesn't give us NaNs.
      num_real_points = tf.maximum(num_real_points, 1)
      features = features / num_real_points

      # Replace features of all padded points by zeros. If a batch of points are
      # all padded, then num_real_points will be zero. We set the features to be
      # zero, so that we don't get any downstream issue with NaNs.
      # Note that inf * 0 = NaN.
      all_padded = tf.equal(num_real_points, 0.)
      all_padded = tf.broadcast_to(all_padded, py_utils.GetShape(features))
      features = tf.where(all_padded, tf.zeros_like(features), features)
      return py_utils.CheckNumerics(features)
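The same padded-mean trick in plain TensorFlow (a toy sketch): mask, sum, then divide by the floored count of real points.

    import tensorflow as tf

    features = tf.constant([[[2.0, 4.0], [6.0, 8.0], [0.0, 0.0]]])
    padding = tf.constant([[0.0, 0.0, 1.0]])  # third point is padded
    mask = 1.0 - padding
    summed = tf.reduce_sum(features * mask[..., tf.newaxis], axis=-2)
    count = tf.maximum(tf.reduce_sum(mask, axis=-1, keepdims=True), 1.0)
    mean = summed / count
    # mean == [[4.0, 6.0]]: the padded third point contributes nothing.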
Example #30
def ComputeWer(hyps, refs):
    """Computes word errors in hypotheses relative to reference transcripts.

    Args:
      hyps: Hypotheses, represented as string tensors of shape [N].
      refs: References, represented as string tensors of shape [N].

    Returns:
      An int64 tensor, word_errs, of size [N, 2] where word_errs[i, 0]
      corresponds to the number of word errors in hyps[i] relative to refs[i];
      word_errs[i, 1] corresponds to the number of words in refs[i].
    """
    def _NormalizeWhitespace(s):
        return tf.strings.regex_replace(tf.strings.strip(s), r'\s+', ' ')

    hyps = _NormalizeWhitespace(hyps)
    refs = _NormalizeWhitespace(refs)

    hyps = py_utils.HasRank(hyps, 1)
    refs = py_utils.HasRank(refs, 1)
    hyps = py_utils.HasShape(hyps, tf.shape(refs))

    word_errors = tf.cast(
        tf.edit_distance(tf.string_split(hyps),
                         tf.string_split(refs),
                         normalize=False), tf.int64)

    # Count number of spaces in reference, and increment by 1 to get total number
    # of words.
    ref_words = tf.cast(
        tf.strings.length(tf.strings.regex_replace(refs, '[^ ]', '')) + 1,
        tf.int64)
    # Set number of words to 0 if the reference was empty.
    ref_words = tf.where(tf.equal(refs, ''),
                         tf.zeros_like(ref_words, tf.int64), ref_words)

    return tf.concat(
        [tf.expand_dims(word_errors, -1),
         tf.expand_dims(ref_words, -1)],
        axis=1)
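Hypothetical usage of ComputeWer (assuming the TF1-style tf.string_split it calls is available, as in Lingvo's tf compat layer):

    hyps = tf.constant(['the cat sat', 'hello'])
    refs = tf.constant(['the cat sat down', 'hello'])
    errs = ComputeWer(hyps, refs)
    # errs == [[1, 4], [0, 1]]: one word error against a four-word reference,
    # then an exact match.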