def _transform_images(self, params, features, labels=None):
        """Transforms images."""

        images = features['images']
        batch_size, _, _, c = images.get_shape().as_list()
        if params['conv0_space_to_depth_block_size'] != 0:
            # Transforms (space-to-depth) images for TPU performance.

            def _fused_transform(images, image_size):
                return spatial_transform.fused_transpose_and_space_to_depth(
                    images, image_size,
                    params['conv0_space_to_depth_block_size'],
                    params['transpose_input'])

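            # image_info[0, 3] and image_info[0, 4] presumably hold the
            # original image height and width; pick the (height, width)
            # orientation that matches the input image.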
            images = tf.cond(
                tf.less(features['image_info'][0, 3],
                        features['image_info'][0, 4]),
                lambda: _fused_transform(images, params['image_size']),
                lambda: _fused_transform(images, params['image_size'][::-1]))

        else:
            # Transposes images for TPU performance.
            image_area = params['image_size'][0] * params['image_size'][1]
            if params['transpose_input']:
                images = tf.transpose(images, [1, 2, 0, 3])  # NHWC -> HWNC
                # Flattens spatial dimensions so that the image tensor has a static
                # shape.
                images = tf.reshape(images, [image_area, batch_size, c])
            else:
                images = tf.reshape(images, [batch_size, image_area, c])

        if params['use_bfloat16']:
            images = tf.cast(images, dtype=tf.bfloat16)

        features['images'] = images

        if labels is not None:
            return features, labels
        else:
            # No real labels (e.g. during prediction); return a dummy tensor.
            return features, tf.zeros([batch_size])
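
A minimal standalone sketch (not part of the example above) showing what the
space-to-depth transform does to tensor shapes, assuming TensorFlow 2.x and a
hypothetical block_size standing in for params['conv0_space_to_depth_block_size']:

import tensorflow as tf

images = tf.zeros([8, 224, 224, 3])             # [batch, height, width, channels]
block_size = 2
out = tf.nn.space_to_depth(images, block_size)  # fuse each block_size x block_size
                                                # patch into the channel dimension
print(out.shape)                                # (8, 112, 112, 12): H, W shrink by
                                                # block_size; C grows by block_size**2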
Example #2
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """
  Sample a permutation of the factorization order, and create an
  attention mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of the longest permutation. Could be set to
      reuse_len. Should not be larger than reuse_len, or there will be data
      leakage.
    seq_len: int, sequence length.
  """

  # Generate permutation indices
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random_shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])
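  # Illustration (hypothetical run with seq_len=6, perm_size=2; the reshape
  # requires seq_len % perm_size == 0):
  #   tf.range            -> [0, 1, 2, 3, 4, 5]
  #   reshape + transpose -> [[0, 2, 4], [1, 3, 5]]
  #   random_shuffle rows -> e.g. [[1, 3, 5], [0, 2, 4]]
  #   transpose + reshape -> [1, 0, 3, 2, 5, 4]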

  # `perm_mask` and `target_mask`
  # non-functional tokens
  non_func_tokens = tf.logical_not(tf.logical_or(
      tf.equal(inputs, SEP_ID),
      tf.equal(inputs, CLS_ID)))

  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Set the permutation indices of non-masked (& non-functional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leaks
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # Create `target_mask`: non-functional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
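  # For target tokens, self_rev_index equals rev_index, so the `<=` test below
  # also blocks j == i (a target cannot attend to itself); for non-target
  # tokens it is rev_index + 1, which re-admits the self position.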

  # 1: cannot attend if i <= j (in factorization order) and j is masked or
  #    functional (masked_or_func_tokens)
  # 0: can attend if i > j or j is a non-masked, non-functional token
  perm_mask = tf.logical_and(
      self_rev_index[:, None] <= rev_index[None, :],
      masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0)

  # construct inputs_k: the content stream consumes the original input ids
  inputs_k = inputs

  # construct inputs_q: float mask flagging the query-stream (prediction)
  # positions
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q
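
A hypothetical smoke test (not from the original repo) for _local_perm. It
assumes the function's module imports TensorFlow 1.x (or tensorflow.compat.v1)
as tf, since tf.random_shuffle is a TF 1.x API, and that SEP_ID/CLS_ID are
defined in the same module; the ids and sequences below are placeholders:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

SEP_ID, CLS_ID = 4, 3  # assumed placeholder ids; use your tokenizer's values

seq_len, perm_size = 8, 4
inputs = tf.constant([10, 11, 12, SEP_ID, 13, 14, 15, CLS_ID], dtype=tf.int64)
targets = tf.constant([11, 12, SEP_ID, 13, 14, 15, CLS_ID, 10], dtype=tf.int64)
is_masked = tf.constant([False, True, False, False, True, False, False, False])

outs = _local_perm(inputs, targets, is_masked, perm_size, seq_len)
with tf.Session() as sess:
  perm_mask, new_targets, target_mask, inputs_k, inputs_q = sess.run(outs)

print(perm_mask.shape)  # (8, 8)
print(target_mask)      # 1.0 exactly at the masked, non-functional positions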