Code Example #1
    def get_hit_rate(self):
        """Computes the current hit rate, then resets the counters."""
        total = self.var[0]
        hits = self.var[1]
        # Guard against division by zero when no lookups have been recorded.
        hit_rate = (tf.cast(hits, tf.float32) / tf.maximum(
            tf.constant(1, dtype=tf.float32), tf.cast(total, tf.float32)))

        # Read the rate first, then zero the counters, in that order.
        with tf.control_dependencies([hit_rate]):
            update_var = self.var.assign_add([-total, -hits])
        with tf.control_dependencies([update_var]):
            return tf.identity(hit_rate)
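
The notable detail here is the read-then-reset ordering enforced by
tf.control_dependencies. A minimal standalone sketch of the same pattern
(the variable name, seed values, and TF1 session setup are illustrative,
not from the original class):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# counts[0] = total lookups, counts[1] = hits (seed values are illustrative).
counts = tf.get_variable(
    'counts', initializer=tf.constant([10, 7]), trainable=False)
rate = tf.cast(counts[1], tf.float32) / tf.maximum(
    1.0, tf.cast(counts[0], tf.float32))
with tf.control_dependencies([rate]):   # compute the rate first ...
    reset = counts.assign(tf.zeros_like(counts))
with tf.control_dependencies([reset]):  # ... then zero the counters
    rate = tf.identity(rate)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(rate))  # 0.7; `counts` has been reset to [0, 0]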
Code Example #2
def _convert_example(example, use_bfloat16):
  """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
  for key in list(example.keys()):
    val = example[key]
    if tf.keras.backend.is_sparse(val):
      val = tf.sparse.to_dense(val)
    if val.dtype == tf.int64:
      val = tf.cast(val, tf.int32)
    if use_bfloat16 and val.dtype == tf.float32:
      val = tf.cast(val, tf.bfloat16)

    example[key] = val
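
A quick sketch of calling _convert_example on a hand-built feature dict (the
feature names and values are made up for illustration; the function mutates
the dict in place and returns nothing):

import tensorflow.compat.v1 as tf

example = {
    'input_ids': tf.constant([1, 2, 3], dtype=tf.int64),
    'weights': tf.constant([0.5, 1.0, 1.0], dtype=tf.float32),
}
_convert_example(example, use_bfloat16=True)
# example['input_ids'].dtype -> tf.int32
# example['weights'].dtype   -> tf.bfloat16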
Code Example #3
def parse_tfexample_sequence(example_proto,
                             img_height=48,
                             img_width=48,
                             action_size=None,
                             episode_length=16):
    """Parse TFExamples saved by episode_to_transitions.

  Args:
    example_proto: tf.String tensor representing a serialized protobuf.
    img_height: Height of parsed image tensors.
    img_width: Width of parsed image tensors.
    action_size: Size of continuous actions. If None, actions are assumed to be
      integer-encoded discrete actions.
    episode_length: Intended length of each episode.
  Returns:
    NamedTuple of type SARSTransition containing unbatched Tensors.
  """
    if action_size is None:
        # Is discrete.
        action_feature_spec = tf.FixedLenFeature((episode_length, ), tf.int64)
    else:
        # Vector-encoded float feature.
        action_feature_spec = tf.FixedLenFeature((episode_length, action_size),
                                                 tf.float32)

    features = {
        'S/img': tf.FixedLenFeature((episode_length, ), tf.string),
        'A': action_feature_spec,
        'R': tf.FixedLenFeature((episode_length, ), tf.float32),
        'S_p1/img': tf.FixedLenFeature((episode_length, ), tf.string),
        'done': tf.FixedLenFeature((episode_length, ), tf.int64),
        't': tf.FixedLenFeature((episode_length, ), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    # Decode the jpeg-encoded images into numeric tensors.
    states = []
    for key in 'S/img', 'S_p1/img':
        state = tf.stack([
            tf.image.decode_jpeg(img, channels=3)
            for img in tf.unstack(parsed_features[key], num=episode_length)
        ])
        state.set_shape([episode_length, img_height, img_width, 3])
        states.append(tf.cast(state, tf.float32))

    action = parsed_features['A']
    reward = parsed_features['R']
    done = tf.cast(parsed_features['done'], tf.float32)
    step = tf.cast(parsed_features['t'], tf.int32)
    aux = {'step': step}

    return SARSTransition((states[0], step), action, reward,
                          (states[1], step + 1), done, aux)
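
A hedged sketch of wiring this parser into a tf.data input pipeline (the file
name and batch size are placeholders, and SARSTransition is assumed to be
defined alongside the parser):

dataset = tf.data.TFRecordDataset('episodes.tfrecord')  # placeholder path
dataset = dataset.map(
    lambda proto: parse_tfexample_sequence(
        proto, img_height=48, img_width=48, episode_length=16))
dataset = dataset.batch(32)  # adds a batch dimension to every field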
Code Example #4
File: model_utils.py  Project: samiraabnar/language
def hamming_loss(preds, targets, sign=False):
  """Implements hamming loss.

  Args:
    preds: Tensor of predicted values.
    targets: Tensor of target values.
    sign (bool): Set to True if targets={-1, 1} to take the sign of preds
      before calculating the loss.

  Returns:
    A tf.metrics tuple containing the proportion of incorrect predictions and an
    update op for the metric.
  """
  if sign:
    preds = tf.sign(preds)
  equal = tf.equal(preds, tf.cast(targets, preds.dtype))
  proportion_correct, update_op = tf.metrics.mean(tf.cast(equal, tf.float32))
  return 1 - proportion_correct, update_op
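
A minimal usage sketch; as with any tf.metrics pair, local variables must be
initialized and the update op run before the metric is read (values here are
illustrative):

import tensorflow.compat.v1 as tf

preds = tf.constant([0.9, -0.2, 0.4, 0.7])
targets = tf.constant([1, 1, -1, 1])
loss, update_op = hamming_loss(preds, targets, sign=True)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # tf.metrics state is local
    sess.run(update_op)
    print(sess.run(loss))  # 0.5: two of the four sign predictions disagree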
Code Example #5
    def label_anchors(self, gt_boxes, gt_labels):
        """Labels anchors with ground truth inputs.

    Args:
      gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
        For each row, it stores [y0, x0, y1, x1] for four corners of a box.
      gt_labels: A integer tensor with shape [N, 1] representing groundtruth
        classes.
    Returns:
      cls_targets_dict: ordered dictionary with keys
        [min_level, min_level+1, ..., max_level]. The values are tensor with
        shape [height_l, width_l, num_anchors * num_classes]. The
        height_l and width_l represent the dimension of class logits at l-th
        level.
      box_targets_dict: ordered dictionary with keys
        [min_level, min_level+1, ..., max_level]. The values are tensor with
        shape [height_l, width_l, num_anchors * 4]. The height_l and
        width_l represent the dimension of bounding box regression output at
        l-th level.
      num_positives: scalar tensor storing number of positives in an image.
    """
        gt_box_list = box_list.BoxList(gt_boxes)
        anchor_box_list = box_list.BoxList(self._anchors.boxes)

        # cls_weights, box_weights are not used
        cls_targets, _, box_targets, _, matches = self._target_assigner.assign(
            anchor_box_list, gt_box_list, gt_labels)

        # class labels start from 1 and the background class = -1
        cls_targets -= 1

        # create one-hot labels
        cls_targets_one_hot = tf.one_hot(tf.cast(cls_targets, dtype=tf.int32),
                                         self._num_classes)
        cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
                                         [-1, self._num_classes])

        cls_targets_dict = self._unpack_labels(cls_targets_one_hot)
        box_targets_dict = self._unpack_labels(box_targets)
        num_positives = tf.reduce_sum(
            tf.cast(tf.not_equal(matches.match_results, -1), tf.float32))

        return cls_targets_dict, box_targets_dict, num_positives
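
One subtlety worth noting: after the `cls_targets -= 1` shift, background
anchors carry the label -1, and tf.one_hot maps out-of-range indices to an
all-zero row, so background anchors receive an all-zero one-hot target. A
tiny standalone check:

labels = tf.constant([-1, 0, 2], dtype=tf.int32)
one_hot = tf.one_hot(labels, depth=3)
# -> [[0., 0., 0.],   # background (-1) becomes all zeros
#     [1., 0., 0.],
#     [0., 0., 1.]]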
Code Example #6
    def _transform_images(self, params, features, labels=None):
        """Transforms images."""

        images = features['images']
        batch_size, _, _, c = images.get_shape().as_list()
        if params['conv0_space_to_depth_block_size'] != 0:
            # Transforms (space-to-depth) images for TPU performance.

            def _fused_transform(images, image_size):
                return spatial_transform.fused_transpose_and_space_to_depth(
                    images, image_size,
                    params['conv0_space_to_depth_block_size'],
                    params['transpose_input'])

            images = tf.cond(
                tf.less(features['image_info'][0, 3],
                        features['image_info'][0, 4]),
                lambda: _fused_transform(images, params['image_size']),
                lambda: _fused_transform(images, params['image_size'][::-1]))

        else:
            # Transposes images for TPU performance.
            image_area = params['image_size'][0] * params['image_size'][1]
            if params['transpose_input']:
                images = tf.transpose(images, [1, 2, 0, 3])
                # Flattens spatial dimensions so that the image tensor has a static
                # shape.
                images = tf.reshape(images, [image_area, batch_size, c])
            else:
                images = tf.reshape(images, [batch_size, image_area, c])

        if params['use_bfloat16']:
            images = tf.cast(images, dtype=tf.bfloat16)

        features['images'] = images

        if labels is not None:
            return features, labels
        else:
            return features, tf.zeros([batch_size])
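
For reference, a hedged sketch of the params and features this method
expects. The keys are taken from the code above; the concrete values and
shapes are assumptions for illustration:

params = {
    'conv0_space_to_depth_block_size': 0,  # 0 disables the fused transform
    'transpose_input': True,
    'image_size': [640, 640],
    'use_bfloat16': False,
}
features = {
    'images': tf.zeros([8, 640, 640, 3]),  # [batch, height, width, channels]
    'image_info': tf.zeros([8, 5]),        # per-image scale/shape metadata
}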
Code Example #7
File: dataloader.py  Project: jhseu/tpu
        def _dataset_parser(value):
            """Parse data to a fixed dimension input image and learning targets."""
            with tf.name_scope('parser'):
                data = example_decoder.decode(value)
                source_id = data['source_id']
                image = data['image']
                boxes = data['groundtruth_boxes']
                classes = data['groundtruth_classes']
                classes = tf.reshape(tf.cast(classes, dtype=tf.float32),
                                     [-1, 1])

                # The image normalization is identical to Cloud TPU ResNet-50.
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
                image = _normalize_image(image)

                if params['input_rand_hflip']:
                    image, boxes = preprocessor.random_horizontal_flip(
                        image, boxes=boxes)
                image_original_shape = tf.shape(image)
                image, _ = preprocessor.resize_to_range(
                    image,
                    min_dimension=params['image_size'],
                    max_dimension=params['image_size'])
                image_scale = tf.to_float(
                    image_original_shape[0]) / tf.to_float(tf.shape(image)[0])
                image, boxes = preprocessor.scale_boxes_to_pixel_coordinates(
                    image, boxes, keypoints=None)

                image = tf.image.pad_to_bounding_box(image, 0, 0,
                                                     params['image_size'],
                                                     params['image_size'])
                (cls_targets, box_targets,
                 num_positives) = anchor_labeler.label_anchors(boxes, classes)

                source_id = tf.string_to_number(source_id, out_type=tf.float32)
                row = (image, cls_targets, box_targets, num_positives,
                       source_id, image_scale)
                return row
Code Example #8
        def _dataset_parser(value):
            """Parse data to a fixed dimension input image and learning targets.

      Args:
        value: A dictionary contains an image and groundtruth annotations.

      Returns:
        A list of the following elements in order:
        image: Image tensor that is preproessed to have normalized value and
          fixed dimension [image_size, image_size, 3]
        label: label tensor of the same spatial dimension as the image.
      """
            with tf.name_scope('parser'):
                data = example_decoder.decode(value)
                image = data['image']
                label = data['labels_class']
                label = tf.to_int32(label)
                input_processor = SegmentationInputProcessor(
                    image, params['image_size'], label)
                # The image normalization is identical to Cloud TPU ResNet.
                input_processor.normalize_image()
                if self._is_training and params['input_rand_hflip']:
                    input_processor.random_horizontal_flip()
                if self._is_training:
                    input_processor.set_training_random_scale_factors(
                        params['train_scale_min'], params['train_scale_max'])
                image = input_processor.resize_and_crop_image()

                # Set padding to background (class=0) during training.
                if self._is_training:
                    label = input_processor.resize_and_crop_label(0)
                else:
                    label = input_processor.resize_and_crop_label(
                        params['ignore_label'])
                if params['use_bfloat16']:
                    image = tf.cast(image, dtype=tf.bfloat16)
                return image, label
Code Example #9
def key_func(*args):
    return tf.cast(horizontal_image(*args), dtype=tf.int64)
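
The cast to int64 matters because tf.data.experimental.group_by_window
requires an int64 key. A key_func like this is typically used to batch
landscape and portrait images separately; the pairing below is an assumption,
and horizontal_image is presumed to return a scalar boolean:

batch_size = 64  # illustrative
dataset = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=key_func,  # 0 or 1 per element
        reduce_func=lambda key, ds: ds.batch(batch_size),
        window_size=batch_size))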
Code Example #10
        def _dataset_parser(value):
            """Parse data to a fixed dimension input image and learning targets.

      Args:
        value: A dictionary contains an image and groundtruth annotations.

      Returns:
        features: A dictionary that contains the image and auxiliary
          information. The following describes {key: value} pairs in the
          dictionary.
          image: An image tensor that is preprocessed to have normalized value
            and fixed dimension [image_size, image_size, 3]
          image_info: Image information that includes the original height and
            width, the scale of the processed image to the original image, and
            the scaled height and width.
          source_ids: Source image id. Default value -1 if the source id is
            empty in the groundtruth annotation.
        labels: (only for training) A dictionary that contains groundtruth
          labels. The following describes {key: value} pairs in the dictionary.
          score_targets_dict: An ordered dictionary with keys
            [min_level, min_level+1, ..., max_level]. The values are tensor with
            shape [height_l, width_l, num_anchors]. The height_l and width_l
            represent the dimension of objectiveness score at l-th level.
          box_targets_dict: An ordered dictionary with keys
            [min_level, min_level+1, ..., max_level]. The values are tensor with
            shape [height_l, width_l, num_anchors * 4]. The height_l and
            width_l represent the dimension of bounding box regression output at
            l-th level.
          gt_boxes: Groundtruth bounding box annotations. The box is represented
             in [y1, x1, y2, x2] format. The tennsor is padded with -1 to the
             fixed dimension [self._max_num_instances, 4].
          gt_classes: Groundtruth classes annotations. The tennsor is padded
            with -1 to the fixed dimension [self._max_num_instances].
          cropped_gt_masks: Groundtruth masks cropped by the bounding box and
            resized to a fixed size determined by params['gt_mask_size']
      """
            with tf.name_scope('parser'):
                data = example_decoder.decode(value)

                image = data['image']
                source_id = data['source_id']
                source_id = tf.where(tf.equal(source_id, tf.constant('')),
                                     '-1', source_id)
                source_id = tf.string_to_number(source_id)

                if self._mode == tf.estimator.ModeKeys.PREDICT:
                    input_processor = InstanceSegmentationInputProcessor(
                        image, image_size, params['short_side_image_size'],
                        params['long_side_max_image_size'])
                    input_processor.normalize_image()
                    input_processor.set_scale_factors_to_mlperf_reference_size()
                    image = input_processor.resize_and_crop_image()
                    if params['use_bfloat16']:
                        image = tf.cast(image, dtype=tf.bfloat16)

                    image_info = input_processor.get_image_info()
                    return {
                        'images': image,
                        'image_info': image_info,
                        'source_ids': source_id
                    }

                # The following part is for training.
                instance_masks = data['groundtruth_instance_masks']
                boxes = data['groundtruth_boxes']
                classes = data['groundtruth_classes']
                classes = tf.reshape(tf.cast(classes, dtype=tf.float32),
                                     [-1, 1])
                if not params['use_category']:
                    classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

                if (params['skip_crowd_during_training']
                        and self._mode == tf.estimator.ModeKeys.TRAIN):
                    indices = tf.where(
                        tf.logical_not(data['groundtruth_is_crowd']))
                    classes = tf.gather_nd(classes, indices)
                    boxes = tf.gather_nd(boxes, indices)
                    instance_masks = tf.gather_nd(instance_masks, indices)

                input_processor = InstanceSegmentationInputProcessor(
                    image, image_size, params['short_side_image_size'],
                    params['long_side_max_image_size'], boxes, classes,
                    instance_masks)
                input_processor.normalize_image()
                if params['input_rand_hflip']:
                    input_processor.random_horizontal_flip()

                input_processor.set_scale_factors_to_mlperf_reference_size()
                image = input_processor.resize_and_crop_image()

                boxes, classes = input_processor.resize_and_crop_boxes()
                cropped_gt_masks = input_processor.crop_gt_masks(
                    params['gt_mask_size'])

                image_info = input_processor.get_image_info()
                # Assign anchors.
                is_height_short_side = tf.less(image_info[3], image_info[4])
                score_targets, box_targets = tf.cond(
                    is_height_short_side,
                    lambda: anchor_labeler.label_anchors(boxes, classes),
                    lambda: height_long_side_anchor_labeler.label_anchors(boxes, classes))  # pylint: disable=line-too-long

                # Pad groundtruth data.
                boxes *= image_info[2]
                boxes = pad_to_fixed_size(boxes, -1,
                                          [self._max_num_instances, 4])
                classes = pad_to_fixed_size(classes, -1,
                                            [self._max_num_instances, 1])
                # Pads cropped_gt_masks.
                cropped_gt_masks = tf.reshape(
                    cropped_gt_masks, [-1, (params['gt_mask_size'] + 4)**2])
                cropped_gt_masks = pad_to_fixed_size(
                    cropped_gt_masks, -1,
                    [self._max_num_instances, (params['gt_mask_size'] + 4)**2])
                cropped_gt_masks = tf.reshape(cropped_gt_masks, [
                    self._max_num_instances, params['gt_mask_size'] + 4,
                    params['gt_mask_size'] + 4
                ])
                if params['use_bfloat16']:
                    image = tf.cast(image, dtype=tf.bfloat16)

                features = {}
                features['images'] = image
                features['image_info'] = image_info
                features['source_ids'] = source_id

                labels = {}
                for level in range(params['min_level'],
                                   params['max_level'] + 1):
                    labels['score_targets_%d' % level] = score_targets[level]
                    labels['box_targets_%d' % level] = box_targets[level]
                labels['gt_boxes'] = boxes
                labels['gt_classes'] = classes
                labels['cropped_gt_masks'] = cropped_gt_masks
                return features, labels
Code Example #11
        def _dataset_parser(value):
            """Parse data to a fixed dimension input image and learning targets.

      Args:
        value: A dictionary contains an image and groundtruth annotations.

      Returns:
        image: Image tensor that is preproessed to have normalized value and
          fixed dimension [image_size, image_size, 3]
        cls_targets_dict: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensor with
          shape [height_l, width_l, num_anchors]. The height_l and width_l
          represent the dimension of class logits at l-th level.
        box_targets_dict: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensor with
          shape [height_l, width_l, num_anchors * 4]. The height_l and
          width_l represent the dimension of bounding box regression output at
          l-th level.
        num_positives: Number of positive anchors in the image.
        source_id: Source image id. Default value -1 if the source id is empty
          in the groundtruth annotation.
        image_scale: Scale of the proccessed image to the original image.
        boxes: Groundtruth bounding box annotations. The box is represented in
          [y1, x1, y2, x2] format. The tennsor is padded with -1 to the fixed
          dimension [self._max_num_instances, 4].
        is_crowds: Groundtruth annotations to indicate if an annotation
          represents a group of instances by value {0, 1}. The tennsor is
          padded with 0 to the fixed dimension [self._max_num_instances].
        areas: Groundtruth areas annotations. The tennsor is padded with -1
          to the fixed dimension [self._max_num_instances].
        classes: Groundtruth classes annotations. The tennsor is padded with -1
          to the fixed dimension [self._max_num_instances].
      """
            with tf.name_scope('parser'):
                data = example_decoder.decode(value)
                source_id = data['source_id']
                image = data['image']
                boxes = data['groundtruth_boxes']
                classes = data['groundtruth_classes']
                classes = tf.reshape(tf.cast(classes, dtype=tf.float32),
                                     [-1, 1])
                areas = data['groundtruth_area']
                is_crowds = data['groundtruth_is_crowd']

                if params['skip_crowd_during_training'] and self._is_training:
                    indices = tf.where(
                        tf.logical_not(data['groundtruth_is_crowd']))
                    classes = tf.gather_nd(classes, indices)
                    boxes = tf.gather_nd(boxes, indices)

                # NOTE: The autoaugment method works best when used alongside the
                # standard horizontal flipping of images along with size jittering
                # and normalization.
                if params.get('autoaugment_policy',
                              None) and self._is_training:
                    image, boxes = autoaugment.distort_image_with_autoaugment(
                        image, boxes, params['autoaugment_policy'])

                input_processor = DetectionInputProcessor(
                    image, params['image_size'], boxes, classes)
                input_processor.normalize_image()
                if self._is_training and params['input_rand_hflip']:
                    input_processor.random_horizontal_flip()
                if self._is_training:
                    input_processor.set_training_random_scale_factors(
                        params['train_scale_min'], params['train_scale_max'])
                else:
                    input_processor.set_scale_factors_to_output_size()
                image = input_processor.resize_and_crop_image()
                boxes, classes = input_processor.resize_and_crop_boxes()

                # Assign anchors.
                (cls_targets, box_targets,
                 num_positives) = anchor_labeler.label_anchors(boxes, classes)

                source_id = tf.where(tf.equal(source_id, tf.constant('')),
                                     '-1', source_id)
                source_id = tf.string_to_number(source_id)

                # Pad groundtruth data for evaluation.
                image_scale = input_processor.image_scale_to_original
                boxes *= image_scale
                is_crowds = tf.cast(is_crowds, dtype=tf.float32)
                boxes = pad_to_fixed_size(boxes, -1,
                                          [self._max_num_instances, 4])
                is_crowds = pad_to_fixed_size(is_crowds, 0,
                                              [self._max_num_instances, 1])
                areas = pad_to_fixed_size(areas, -1,
                                          [self._max_num_instances, 1])
                classes = pad_to_fixed_size(classes, -1,
                                            [self._max_num_instances, 1])
                if params['use_bfloat16']:
                    image = tf.cast(image, dtype=tf.bfloat16)
                return (image, cls_targets, box_targets, num_positives,
                        source_id, image_scale, boxes, is_crowds, areas,
                        classes)
Code Example #12
  def parser(record):
    """function used to parse tfrecord."""

    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # Parse the serialized example.
    example = tf.parse_single_example(
        serialized=record,
        features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)

    perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                            axis=1)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after prepro
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for k, v in example.items():
      tf.logging.info("%s: %s", k, v)

    return example
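
A hedged sketch of the enclosing pipeline: parser is a closure over seq_len,
reuse_len, perm_size, num_predict, and use_bfloat16, and would be applied per
record roughly like this (file names and batch size are placeholders):

dataset = tf.data.TFRecordDataset(file_names)  # placeholder file list
dataset = dataset.map(parser, num_parallel_calls=32)
dataset = dataset.batch(batch_size, drop_remainder=True)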
Code Example #13
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """
  Sample a permutation of the factorization order, and create an
  attention mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of the longest permutation. Could be set to
      reuse_len. Should not be larger than reuse_len or there will be data
      leaks.
    seq_len: int, sequence length.
  """

  # Generate permutation indices
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random_shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # `perm_mask` and `target_mask`
  # non-functional tokens
  non_func_tokens = tf.logical_not(tf.logical_or(
      tf.equal(inputs, SEP_ID),
      tf.equal(inputs, CLS_ID)))

  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Set the permutation indices of non-masked (& non-functional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leaks
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # Create `target_mask`: non-functional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  perm_mask = tf.logical_and(
      self_rev_index[:, None] <= rev_index[None, :],
      masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0: 1], targets[: -1]],
                          axis=0)

  # construct inputs_k
  inputs_k = inputs

  # construct inputs_q
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q
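
A minimal sketch of calling _local_perm directly. The SEP_ID/CLS_ID constants
and all tensor values are illustrative, and seq_len must be divisible by
perm_size for the reshape inside the function to work:

SEP_ID, CLS_ID = 4, 3  # assumed special-token ids

seq_len, perm_size = 8, 4
inputs = tf.constant([11, 12, 13, SEP_ID, 14, 15, 16, CLS_ID], dtype=tf.int64)
targets = tf.constant([12, 13, SEP_ID, 14, 15, 16, CLS_ID, 11], dtype=tf.int64)
is_masked = tf.constant(
    [False, True, False, False, True, False, True, False])

perm_mask, new_targets, target_mask, inputs_k, inputs_q = _local_perm(
    inputs, targets, is_masked, perm_size, seq_len)
# perm_mask: [8, 8] float mask; target_mask flags the masked, non-functional
# positions that contribute to the loss.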