コード例 #1
0
ファイル: input_generator.py プロジェクト: tensorflow/lingvo
    def _ParseRecord(self, record):
        """Parses one serialized example into a `NestedMap` of model inputs.

        The example is decoded with fixed-length features sized by
        `p.max_sequence_length` (`input_ids`, `input_mask`) and
        `p.max_predictions_per_seq` (`masked_lm_*`). The original ids are
        rebuilt by scattering `masked_lm_ids` back into `input_ids` at
        `masked_lm_positions`, and the first occurrence of `p.eos_token_id`
        is then removed from every tensor of the result.

        Args:
          record: A serialized example accepted by
            `tf.io.parse_single_example`.

        Returns:
          A `py_utils.NestedMap` holding `ids`, `segment_ids`, `paddings` and
          `segment_pos`, plus `masked_ids` and `masked_pos` unless
          `p.remove_mask` is set.
        """
        p = self.params

        def _FixedLen(length, dtype):
            # 1-D fixed-length feature spec of the given length.
            return tf.io.FixedLenFeature([length], dtype)

        feature_spec = {
            'input_ids': _FixedLen(p.max_sequence_length, tf.int64),
            'input_mask': _FixedLen(p.max_sequence_length, tf.int64),
            'masked_lm_positions': _FixedLen(p.max_predictions_per_seq,
                                             tf.int64),
            'masked_lm_ids': _FixedLen(p.max_predictions_per_seq, tf.int64),
            'masked_lm_weights': _FixedLen(p.max_predictions_per_seq,
                                           tf.float32),
        }
        parsed = tf.io.parse_single_example(record, feature_spec)

        # The number of masked predictions is taken to be the sum of the
        # weights (assumes weights are 1.0 for real entries and 0.0 for
        # padding -- TODO confirm against the data writer).
        num_masked = tf.cast(tf.reduce_sum(parsed['masked_lm_weights']),
                             dtype=tf.int32)
        positions = tf.slice(parsed['masked_lm_positions'], [0], [num_masked])
        target_ids = tf.slice(parsed['masked_lm_ids'], [0], [num_masked])
        target_ids = tf.cast(target_ids, dtype=tf.int32)
        # Column vector of indices, as tensor_scatter_nd_update expects.
        scatter_indices = tf.reshape(positions, [-1, 1])

        ret = py_utils.NestedMap()
        ret.masked_ids = tf.cast(parsed['input_ids'], dtype=tf.int32)
        # Get back non-masked, original ids.
        ret.ids = tf.tensor_scatter_nd_update(tensor=ret.masked_ids,
                                              indices=scatter_indices,
                                              updates=target_ids)
        # 1.0 at every masked position, 0.0 elsewhere.
        ret.masked_pos = tf.tensor_scatter_nd_update(
            tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
            indices=scatter_indices,
            updates=tf.ones_like(target_ids, dtype=tf.float32))
        ret.segment_ids = tf.cast(parsed['input_mask'], dtype=tf.float32)

        eos_locations = tf.where(tf.math.equal(ret.ids, p.eos_token_id))
        first_eos_idx = eos_locations[0][0]

        def _DropFirstEos(x):
            # Cut out the element at `first_eos_idx` and append a single 0
            # so that the length of `x` stays the same.
            pad = tf.constant(0, shape=(1, ), dtype=x.dtype)
            head = x[:first_eos_idx]
            tail = x[first_eos_idx + 1:]
            return tf.concat([head, tail, pad], axis=0)

        ret = ret.Transform(_DropFirstEos)
        ret.paddings = 1.0 - ret.segment_ids
        position_range = tf.cast(tf.range(p.max_sequence_length),
                                 dtype=tf.float32)
        ret.segment_pos = tf.cast(ret.segment_ids * position_range,
                                  dtype=tf.int32)

        if p.remove_mask:
            del ret.masked_pos
            del ret.masked_ids
        return ret
コード例 #2
0
 def _Slice(tensor):
   """Extracts the slice of `tensor` at index `state0.t` along the time axis.

   The time axis is `self.params.axis`; it is removed from the result.
   """
   time_axis = self.params.axis
   dims = py_utils.GetShape(tensor)
   # Start offsets: zero everywhere except `state0.t` on the time axis,
   # e.g. for axis=1 this is [0, t, 0, 0, ...].
   start = tf.one_hot(time_axis, tf.rank(tensor), on_value=state0.t)
   # Extents: the full shape with the time axis narrowed to 1,
   # e.g. for axis=1 this is [dims[0], 1, dims[2], dims[3], ...].
   before_time = dims[0:time_axis]
   after_time = dims[time_axis + 1:]
   extent = tf.concat(
       [before_time, tf.constant([1], dtype=tf.int32), after_time], axis=0)
   # Take the single-step slice, then drop the now size-1 time axis.
   single_step = tf.slice(tensor, start, extent)
   return tf.squeeze(single_step, axis=time_axis)
コード例 #3
0
ファイル: layers_with_gpipe.py プロジェクト: linhx13/lingvo
 def GetDecoderEmbeddingsDefaultTheta(self, input_ids, t=None):
     """Embeds decoder input ids using this layer's own theta.

     Token embeddings are summed with position embeddings and passed
     through `tgt_dropout`.

     Args:
       input_ids: Ids to embed; the first dimension is used as the
         sequence length when `t` is None.
       t: Optional decoding step. When given, the position embedding is the
         single row `t` sliced from the table built for `p.max_seq_len`.

     Returns:
       The output of `tgt_dropout.FProp` on the summed embeddings.
     """
     p = self.params
     seq_len = tf.shape(input_ids)[0]
     # [seq_len, batch, model_dim]
     token_embs = self.tgt_token_emb.EmbLookup(self.theta.tgt_token_emb,
                                               input_ids)
     if t is not None:
         # Support decoding: keep only the row for step `t`.
         pos_table = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb,
                                            p.max_seq_len)
         pos_embs = tf.slice(pos_table, [t, 0],
                             [1, p.token_emb.embedding_dim])
     else:
         # [seq_len, 1, model_dim]
         pos_table = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, seq_len)
         pos_embs = tf.expand_dims(pos_table, 1)
     summed_embs = token_embs + pos_embs
     return self.tgt_dropout.FProp(self.theta.tgt_dropout, summed_embs)
コード例 #4
0
ファイル: layers_with_gpipe.py プロジェクト: xueyongfu/lingvo
 def GetDecoderEmbeddingsDefaultTheta(self, input_ids, task_ids=None, t=None):
   """Embeds decoder input ids using this layer's own theta.

   Token embeddings are summed with position embeddings and, when enabled,
   task embeddings; the sum is passed through `tgt_dropout`.

   Args:
     input_ids: Ids to embed. When `t` is None, the sequence length is read
       from dimension 0 if `p.batch_dim` is truthy, otherwise dimension 1.
     task_ids: Optional task ids; only used when `p.dec_task_emb` is set.
     t: Optional decoding step. When given, the position embedding is the
       single row `t` sliced from the table built for `p.max_seq_len`.

   Returns:
     The output of `tgt_dropout.FProp` on the summed embeddings.
   """
   p = self.params
   token_embs = self.tgt_token_emb.EmbLookup(self.theta.tgt_token_emb,
                                             input_ids)
   if t is not None:
     # Support decoding: keep only the row for step `t`.
     pos_table = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, p.max_seq_len)
     pos_embs = tf.slice(pos_table, [t, 0], [1, p.token_emb.embedding_dim])
   else:
     time_dim = 0 if p.batch_dim else 1
     seq_len = tf.shape(input_ids)[time_dim]
     pos_table = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, seq_len)
     # Insert a size-1 axis at `p.batch_dim` so the table broadcasts.
     pos_embs = tf.expand_dims(pos_table, p.batch_dim)
   combined_embs = token_embs + pos_embs
   use_task_emb = task_ids is not None and p.dec_task_emb
   if use_task_emb:
     task_embs = self.tgt_task_emb.EmbLookup(self.theta.tgt_task_emb,
                                             task_ids)
     combined_embs = combined_embs + task_embs
   return self.tgt_dropout.FProp(self.theta.tgt_dropout, combined_embs)
コード例 #5
0
  def _PreprocessForTraining(self, image):
    """Applies random training-time distortions to a single image.

    The image is randomly cropped, resized to `p.output_image_size`,
    randomly flipped left/right, colour-distorted and finally rescaled to
    `p.output_range`.

    Args:
      image: The input image, a shape [height, width, num_channels=3] Tensor.
        Must be of type `tf.float32`. Image values are assumed to be in [0, 1].

    Returns:
      3-D float Tensor of the distorted image, with values mapped from
      [0, 1] to `p.output_range`.
    """
    p = self.params
    assert image.dtype == tf.float32

    image_shape = tf.shape(image)
    # No objects of interest; use the whole image as input.
    placeholder_boxes = tf.zeros([1, 1, 4], dtype=tf.float32)
    crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
        image_shape,
        bounding_boxes=placeholder_boxes,
        area_range=(p.training_crop_min_area, 1.0),
        use_image_if_no_bounding_boxes=True)
    cropped = tf.slice(image, crop_begin, crop_size)
    # The dynamic slice loses the static channel dimension; restore it.
    cropped.set_shape([None, None, 3])

    # Resize to the target shape. Note this does not respect the original
    # aspect ratio and may distort the image.
    out_height, out_width = p.output_image_size
    resized = tf.image.resize(cropped, [out_height, out_width], antialias=True)
    resized.set_shape([out_height, out_width, 3])

    flipped = tf.image.random_flip_left_right(resized)
    distorted = _DistortBrightnessAndColor(flipped)

    # [0, 1] => output_range
    range_width = float(p.output_range[1] - p.output_range[0])
    return distorted * range_width + p.output_range[0]
コード例 #6
0
  def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      core_bs_states, other_states, num_hyps_per_beam,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
    """Extend beam search hyps for one step.

    The step runs the pre-step callback to obtain this step's results, feeds
    them together with the core states to `ops.beam_search_step`, reorders
    the client-managed states to follow the surviving hyps, and finally runs
    the post-step callback.

      | num_beams = Number of source sequences to be decoded.
      | num_hyps_per_beam = Number of hyps to keep per source sequence.
      | num_hyps = num_beams * num_hyps_per_beam
      | src_seq_len = Number of time steps in the source sequence.
      | src_batch = Number of examples in the source sequence.
      | tgt_seq_len = Maximum allowed time steps in the target sequence.
      | tgt_batch = num_hyps_per_beam * src_batch

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      core_bs_states: A tuple of core beam search states. This list is
        maintained by this helper class. It is unpacked below into seven
        elements: (best_scores, cumulative_scores, scores, hyps, prev_hyps,
        done_hyps, atten_probs).
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      num_hyps_per_beam: Num of hyps to keep per beam.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of following elements for the next beam search step,
      (next step, all_done, step_ids, core_bs_states, other_states)
    """
    p = self.params

    # The client callback yields this step's results (`log_probs`,
    # `atten_probs` and, optionally, `is_last_chunk` are read from it below)
    # together with the updated client-managed states.
    bs_results, other_states = pre_beam_search_step_callback(
        theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

    (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
     in_done_hyps, in_atten_probs) = core_bs_states

    # `is_last_chunk` is only forwarded when the model uses an eoc id;
    # otherwise an empty list is passed in its place.
    (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
     out_prev_hyps, out_done_hyps, out_atten_probs,
     all_done) = ops.beam_search_step(
         bs_results.log_probs,
         bs_results.atten_probs,
         best_scores,
         cumulative_scores,
         in_scores,
         in_hyps,
         in_prev_hyps,
         in_done_hyps,
         in_atten_probs,
         bs_results.is_last_chunk if self._model_uses_eoc_id else [],
         cur_step,
         eoc_id=p.target_eoc_id,
         eos_id=p.target_eos_id,
         beam_size=p.beam_size,
         num_hyps_per_beam=num_hyps_per_beam,
         valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
         merge_paths=p.merge_paths,
         allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
         ensure_full_beam=p.ensure_full_beam,
         force_eos_in_last_step=p.force_eos_in_last_step,
         local_eos_threshold=p.local_eos_threshold)

    # The ids for the next step are row `cur_step` of `out_hyps`, reshaped to
    # match `step_ids`; set_shape() restores the static shape information.
    new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
    new_step_ids.set_shape(step_ids.get_shape())

    # Row `cur_step` of `out_prev_hyps`, flattened to 1-D. It is used below as
    # gather indices to reorder states (presumably, for each new hyp, the
    # index of the hyp it extends -- confirm against the op documentation).
    old_hyp_ids = tf.reshape(
        tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

    if p.batch_major_compute:
      # Transform the indices into the key/value cache for fast decoding
      # (prefix_states in other_states): the num_hyps dimension of the cache
      # is laid out as num_beams by num_hyps_per_beam, which is different
      # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
      # Both transpose and recomputation are required to correct the indices.
      # NOTE: `old_hyp_ids_in_cache_order` only exists when this flag is set;
      # ReOrderHyps below reads it only under the same flag.
      num_beams = tf.shape(best_scores)[0]
      old_hyp_ids_in_cache_order = tf.reshape(
          tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
      old_hyp_ids_in_cache_order = (
          (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
          old_hyp_ids_in_cache_order // num_beams)

    new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                     out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

    def ReOrderHyps(x_in):
      """Reorders x_in based on prev hyp ids.

      Non-tensors and tensors whose static rank is unknown or 0 are returned
      unchanged.
      """
      if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
          x_in.shape.ndims > 0):
        if x_in.shape.ndims > 2 and not p.batch_major_state:
          # Use corrected indices only here for batch major compute as key/value
          # caches are the states being affected.
          # These tensors are gathered along axis 1 rather than axis 0.
          correct_old_hyp_ids = (
              old_hyp_ids_in_cache_order
              if p.batch_major_compute else old_hyp_ids)
          x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
        else:
          x_out = tf.gather(x_in, old_hyp_ids)
        # The gather does not change the shape; keep the static shape.
        x_out.set_shape(x_in.get_shape())
        return x_out
      else:
        return x_in

    new_other_states = other_states.Transform(ReOrderHyps)

    # Give the client a chance to update its reordered states for the new ids.
    final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                        new_step_ids,
                                                        new_other_states)

    return (cur_step + 1, all_done, new_step_ids, new_bs_states,
            final_other_states)
コード例 #7
0
ファイル: pruning.py プロジェクト: snsun/lingvo
    def _maybe_update_block_mask(self, weights, threshold):
        """Performs block-granular masking of the weights.

        Block pruning occurs only if the block_height or block_width is > 1
        and if the weight tensor, when squeezed, has ndims = 2. Otherwise,
        elementwise pruning occurs via `self._update_mask`.

        Args:
          weights: The weight tensor that needs to be masked.
          threshold: The current threshold value. The function will compute a
            new threshold and return the exponential moving average using the
            current value of threshold.

        Returns:
          new_threshold: The new value of the threshold based on weights, and
            sparsity at the current global_step.
          new_mask: On the block-pruning path, a tensor reshaped to the shape
            of `weights` containing 0 or 1 to indicate which of the values in
            weights falls below the threshold. On the elementwise path, the
            mask returned by `self._update_mask`.

        Raises:
          ValueError: if block pooling function is not AVG or MAX. This is
            only checked on the block-pruning path.
        """
        block_dims = self._get_block_dims(weights.op.name)
        squeezed_weights = tf.squeeze(weights)
        if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
            return self._update_mask(weights, threshold)

        # Resolve -1 entries on a private copy. The list returned by
        # `_get_block_dims` may be shared state; writing the resolved sizes
        # back into it would leak this weight's shape into later calls for
        # differently shaped weights.
        block_dims = list(block_dims)
        squeezed_shape = squeezed_weights.get_shape()
        for i in range(2):
            if block_dims[i] == -1:
                # -1 is replaced by the full size of that weight dimension.
                block_dims[i] = squeezed_shape[i]

        if self._block_pooling_function not in ['AVG', 'MAX']:
            raise ValueError(
                'Unknown pooling function for block sparsity: %s' %
                self._block_pooling_function)

        with tf.name_scope(weights.op.name + '_pruning_ops'):
            abs_weights = tf.abs(squeezed_weights)

            pool_window = block_dims
            pool_fn = pruning_utils.factorized_pool
            squeeze_axis = None
            if not self._spec.use_tpu:
                # tf.nn.pool is given a 4-D input here, so wrap the 2-D
                # weights with size-1 leading and trailing axes and remember
                # which axes to squeeze away afterwards.
                pool_fn = tf.nn.pool
                abs_weights = tf.reshape(abs_weights, [
                    1,
                    abs_weights.get_shape()[0],
                    abs_weights.get_shape()[1], 1
                ])
                squeeze_axis = [0, 3]

            # Non-overlapping pooling: the stride equals the window size, so
            # each pooled value summarizes one block.
            pooled_weights = pool_fn(abs_weights,
                                     window_shape=pool_window,
                                     pooling_type=self._block_pooling_function,
                                     strides=pool_window,
                                     padding='SAME',
                                     name=weights.op.name + '_pooled')

            if pooled_weights.get_shape().ndims != 2:
                pooled_weights = tf.squeeze(pooled_weights, axis=squeeze_axis)

            smoothed_threshold, new_mask = self._update_mask(
                pooled_weights, threshold)

            # Expand the per-block mask back to element granularity, then
            # trim it to the squeezed weight shape ('SAME' padding can make
            # the expanded mask larger than the weights).
            updated_mask = pruning_utils.expand_tensor(new_mask, block_dims)
            sliced_mask = tf.slice(updated_mask, [0, 0], [
                squeezed_weights.get_shape()[0],
                squeezed_weights.get_shape()[1]
            ])

        return smoothed_threshold, tf.reshape(sliced_mask, tf.shape(weights))