示例#1
0
    def _postprocess(self, outputs: typing.NestedTensorDict):
        """Filters the raw instance predictions and stores the results.

        A prediction slot is kept only when its thresholded mask is large
        enough, its arg-max class is neither class-0 nor the last class, and
        its arg-max class probability reaches the class threshold.

        Args:
          outputs: A dictionary of outputs. `outputs["instance_output"]` must
            hold "mask_id_prob" and "cls_prob". The dictionary is modified in
            place.

        The following fields are written to outputs["postprocessed"]:
          "classes": A (B, N) int32 tensor for the class ids. Zero for slots
            that were filtered out.
          "binary_masks": A (B, H, W, N) float tensor for the N binarized 0/1
            masks. Masks of filtered-out slots are set to zero.
          "confidence": A (B, N) float tensor for the confidence of "classes"
            (the max class probability; not zeroed for filtered-out slots).
          "mask_area": A (B, N) float tensor for the area of each thresholded
            mask, measured before the slot filtering. Used in inference /
            visualization.
        """
        outputs["postprocessed"] = {}
        postprocessed = outputs["postprocessed"]
        instance_output = outputs["instance_output"]

        # Masks: a pixel is assigned to a slot when that slot has the highest
        # probability at the pixel and that probability reaches the threshold.
        mask_prob = instance_output["mask_id_prob"]
        best_mask_prob = tf.reduce_max(mask_prob, axis=-1, keepdims=True)
        is_best_slot = tf.equal(best_mask_prob, mask_prob)
        is_confident = tf.greater_equal(best_mask_prob, self._mask_threshold)
        binary_masks = tf.cast(tf.math.logical_and(is_best_slot, is_confident),
                               tf.float32)
        mask_area = tf.reduce_sum(binary_masks, axis=(1, 2))  # (B, N)

        # Classification: best class and its probability for every slot.
        cls_prob = instance_output["cls_prob"]
        top_cls_prob = tf.reduce_max(cls_prob, axis=-1)  # (B, N)
        top_cls_id = tf.cast(tf.argmax(cls_prob, axis=-1),
                             tf.float32)  # (B, N)

        # Filtering: all four conditions must hold for a slot to survive.
        num_classes = utilities.resolve_shape(cls_prob)[2]
        keep_conditions = [
            # The mask is large enough.
            tf.greater_equal(mask_area, self._filter_area),
            # Class-0 is for non-object.
            tf.not_equal(top_cls_id, 0),
            # Class-(c-1), i.e. the last one, is for background.
            tf.not_equal(top_cls_id, num_classes - 1),
            # The class probability reaches the threshold.
            tf.greater_equal(top_cls_prob, self._class_threshold),
        ]
        keep = tf.reduce_all(tf.stack(keep_conditions, axis=-1), axis=-1)
        keep = tf.cast(keep, tf.float32)

        # Storing
        postprocessed["classes"] = tf.cast(top_cls_id * keep, tf.int32)
        batch, num_slots = utilities.resolve_shape(keep)
        keep_per_pixel = tf.reshape(keep, (batch, 1, 1, num_slots))
        postprocessed["binary_masks"] = binary_masks * keep_per_pixel
        postprocessed["confidence"] = top_cls_prob
        postprocessed["mask_area"] = mask_area
示例#2
0
def _erase(mask: tf.Tensor,
           feature: tf.Tensor,
           min_val: float = 0.,
           max_val: float = 256.) -> tf.Tensor:
    """Replaces the masked area of a feature map with uniform random noise.

    The mask may have a different spatial size from the feature map: it is
    replicated over the channels, resized to the feature resolution with
    `tf.image.resize` (default method) and binarized at 0.5.

    Args:
      mask: An (h, w) binary mask of the pixels to erase. Value 1 marks a
        pixel to erase.
      feature: The (H, W, C) feature map to erase from.
      min_val: The minimum value of the random noise.
      max_val: The maximum value of the random noise.

    Returns:
      The (H, W, C) feature map with the masked pixels replaced by noise, i.e.
      mask * noise + (1 - mask) * feature for the resized, binarized mask.
    """
    height, width, channels = utilities.resolve_shape(feature)

    # Bring the mask to the (H, W, C) layout of `feature`.
    mask_hwc = tf.expand_dims(tf.cast(mask, tf.float32), -1)
    mask_hwc = tf.tile(mask_hwc, (1, 1, channels))
    mask_hwc = tf.image.resize(mask_hwc, (height, width))
    should_erase = mask_hwc > 0.5

    # Noise is sampled for every position; only the masked ones are used.
    noise = tf.random.uniform((height, width, channels), min_val, max_val)
    noise = tf.cast(noise, feature.dtype)
    return tf.where(condition=should_erase, x=noise, y=feature)
示例#3
0
        def _rotate():
            """Rotates the sample by a random multiple of 90 degrees.

            The number of quarter turns is sampled from {1, 2, 3}. These
            entries of `data` are rotated:
              image,
              groundtruth_boxes (rbox),
              entity_id_mask,
            TODO(longshangbang): rotate vertices.

            Returns:
              A tuple of (rotated image, rotated boxes, rotated mask).
            """
            quarter_turns = tf.random.uniform([], 1, 4, dtype=tf.int32)
            height, width, _ = utilities.resolve_shape(data['image'])

            # Image
            image_out = tf.image.rot90(data['image'],
                                       k=quarter_turns,
                                       name='image_rot90k')

            # Box: one branch per possible rotation count; the sampled tensor
            # picks the branch at run time.
            rotate_boxes = functools.partial(
                utilities.rotate_rboxes90,
                rboxes=data['groundtruth_boxes'],
                image_width=width,
                image_height=height)
            box_branches = [
                lambda: rotate_boxes(rotation_count=1),
                lambda: rotate_boxes(rotation_count=2),
                lambda: rotate_boxes(rotation_count=3),
            ]
            # Branch indices start at 0 while the rotation counts start at 1.
            boxes_out = tf.switch_case(quarter_turns - 1,
                                       branch_fns=box_branches)

            # Mask
            mask_out = tf.image.rot90(data['entity_id_mask'],
                                      k=quarter_turns,
                                      name='mask_rot90k')
            return image_out, boxes_out, mask_out
示例#4
0
    def call(self,
             features: typing.TensorDict,
             training: bool = False) -> typing.NestedTensorDict:
        """Runs the forward pass of the model.

        Args:
          features: The input features: {"images": tf.Tensor}.
            Shape = [B, H, W, C]
          training: Whether it's training mode.

        Returns:
          A dictionary of outputs with this structure:
            {
              "max_deep_lab": {
                All the max deeplab outputs are here, including both backbone
                and decoder.
              }
              "segmentation_output": {
                "word_score": tf.Tensor, [B, h, w],
              }
              "instance_output": {
                "cls_logits": tf.Tensor, [B, N, C],
                "mask_id_logits": tf.Tensor, [B, H, W, N],
                "cls_prob":  tf.Tensor, [B, N, C],
                "mask_id_prob": tf.Tensor, [B, H, W, N],
              }
              "postprocessed": {
                "classes": A (B, N) tensor for the class ids. Zero for
                  non-firing slots.
                "binary_masks": A (B, H, W, N) tensor for the N binary masks.
                  Masks for void cls are set to zero.
                "confidence": A (B, N) float tensor for the confidence of
                  "classes".
                "mask_area": A (B, N) float tensor for the area of each mask.
              }
              "transformer_group_feature": (B, N, C) float tensor (normalized),
              "para_affinity": (B, N, N) float tensor.
            }

          Class-0 is for void. Class-(C-1) is for background. Class-1~(C-2) is
          for valid classes.
        """
        images = features["images"]

        # Backbone.
        backbone_output = self._backbone_fn(images, training)

        # Paragraph grouping. This must run before the decoder because
        # `_get_para_outputs` rewrites entries of `backbone_output` in place
        # (it splits the instance embedding from the paragraph embedding).
        group_features = self._get_para_outputs(backbone_output, training)
        para_affinity = tf.linalg.matmul(group_features,
                                         group_features,
                                         transpose_b=True)

        # Text detection head.
        decoder_output = self._decoder(backbone_output, training)

        output_dict = {}
        output_dict["max_deep_lab"] = decoder_output
        output_dict["transformer_group_feature"] = group_features
        output_dict["para_affinity"] = para_affinity

        # The remaining stages add their results to `output_dict` in place.
        input_shape = utilities.resolve_shape(images)
        self._get_semantic_outputs(output_dict, input_shape)
        self._get_instance_outputs(output_dict, input_shape)
        self._postprocess(output_dict)

        return output_dict
示例#5
0
def _instance_discrimination_loss(loss_dict: Dict[str, Any],
                                  labels: Dict[str, Any],
                                  outputs: Dict[str, Any],
                                  tau: float = gin.REQUIRED):
    """Computes the instance discrimination (ID) loss.

    Every pixel is classified against the mean-pooled, normalized embedding of
    each ground-truth instance; the per-pixel cross entropy is averaged over
    the pixels covered by a ground-truth mask. The result is written to
    `loss_dict["loss_inst_dist"]`; nothing is returned.

    Args:
      loss_dict: A dictionary for the loss. The values are loss scalars.
      labels: The label dictionary. Reads "masks" and "num_instance".
      outputs: The output dictionary. Reads
        outputs["max_deep_lab"]["pixel_space_normalized_feature"].
      tau: The temperature term in the loss.
    """
    # The normalized pixel feature, shape=(B, H/4, W/4, D)
    pixel_feature = outputs["max_deep_lab"]["pixel_space_normalized_feature"]
    batch, height, width = utilities.resolve_shape(pixel_feature)[:3]

    # The ground-truth masks, shape=(B, N, H, W) --> (B, N, H/4, W/4).
    # `tf.image.resize` works on channel-last inputs, hence the transposes.
    gt_masks = labels["masks"]
    gt_masks = tf.transpose(gt_masks, (0, 2, 3, 1))
    gt_masks = tf.image.resize(gt_masks, (height, width),
                               tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    gt_masks = tf.transpose(gt_masks, (0, 3, 1, 2))

    # The number of ground-truth instances (K), shape=(B,)
    num_instance = labels["num_instance"]
    num_slots = utilities.resolve_shape(gt_masks)[1]  # max number of slots

    # is_padded[b, i] = 1 if slot i in batch b is a padded slot, i.e. i >= K.
    slot_index = tf.cast(tf.expand_dims(tf.range(num_slots), 0),
                         tf.float32)  # (1, N)
    is_padded = tf.math.greater_equal(slot_index,
                                      tf.expand_dims(num_instance, 1))
    is_padded = tf.cast(is_padded, tf.float32)

    # Instance embeddings: mask-pooled pixel features, normalized. (B, N, D)
    pooled = tf.einsum("bhwd,bnhw->bnd", pixel_feature, gt_masks)
    inst_embedding = tf.math.l2_normalize(pooled, axis=-1)

    # Pixel-to-instance similarity logits, shape=(B, H, W, N).
    logits = tf.einsum("bhwd,bid->bhwi", pixel_feature, inst_embedding) / tau
    # Push the logits of padded slots down so that they get ~0 probability.
    padded_penalty = 100. * tf.reshape(is_padded, (batch, 1, 1, num_slots))
    logits = logits - padded_penalty

    # Per-pixel target: the index of the instance mask covering the pixel.
    slot_ids = tf.range(num_slots, dtype=tf.float32)
    target_id = tf.cast(tf.einsum("bnhw,n->bhw", gt_masks, slot_ids),
                        tf.int32)
    pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_id, logits=logits)  # (B, H, W)

    # Average over the pixels that are covered by some ground-truth mask.
    covered = tf.reduce_sum(gt_masks, axis=1)
    loss_sum = tf.reduce_sum(pixel_loss * covered, axis=[1, 2]) + EPSILON
    covered_area = tf.reduce_sum(covered, axis=[1, 2]) + EPSILON
    per_sample_loss = loss_sum / covered_area
    loss_dict["loss_inst_dist"] = tf.reduce_mean(per_sample_loss)
示例#6
0
 def _preprocess_labels(self, labels: typing.TensorDict):
     """Converts the integer instance mask to one-hot masks, in place.

     The number of one-hot channels is taken from the second dimension of
     labels["instance_labels"]["masks_sizes"]; the one-hot axis is inserted at
     position 1.

     Args:
       labels: The label dictionary. labels["instance_labels"]["masks"] is
         replaced by its float32 one-hot encoding, shape (B, N, H, W).
     """
     instance_labels = labels["instance_labels"]
     depth = utilities.resolve_shape(instance_labels["masks_sizes"])[1]
     one_hot_masks = tf.one_hot(instance_labels["masks"],
                                depth=depth,
                                axis=1,
                                dtype=tf.float32)  # (B, N, H, W)
     instance_labels["masks"] = one_hot_masks
示例#7
0
    def _coloring(self, masks: tf.Tensor) -> tf.Tensor:
        """Paints every segmentation mask with a random color.

        Used in visualization.

        Args:
          masks: A float binary tensor of shape (B, H, W, N), representing `B`
            samples, with `N` masks of size `H*W` each. Each of the `N` masks
            is assigned a random color; the same palette is shared by all
            samples of the batch.

        Returns:
          A (B, H, W, 3) float tensor for the coloring result: the sum of the
          colors of the masks active at each pixel. Values stay in [0., 1.]
          as long as the masks are binary and do not overlap.
        """
        batch, height, width, num_masks = utilities.resolve_shape(masks)
        # Colors are drawn from [0.5, 1) so that painted pixels are bright.
        colors = tf.random.uniform((1, num_masks, 3), 0.5, 1.)
        # (B, H*W, N) x (1, N, 3) --> (B, H*W, 3)
        flat_masks = tf.reshape(masks, (batch, -1, num_masks))
        flat_colored = tf.matmul(flat_masks, colors)
        return tf.reshape(flat_colored, (batch, height, width, 3))
示例#8
0
    def _crop_and_resize(self, data: TensorDict, unused_features: TensorDict,
                         unused_labels: TensorDict):
        """Performs random cropping, padding and resizing.

        Reads data['image'] and data['entity_id_mask'] and writes
        data['resized_image'] (uint8) and data['resized_masks'] (nearest
        neighbor resized, last axis squeezed). The image and the mask go
        through the same crop box and the same padding offsets.

        Args:
          data: The sample dictionary; modified in place.
          unused_features: Not used.
          unused_labels: Not used.
        """
        # TODO(longshangbang): resize & translate box as well
        # TODO(longshangbang): resize & translate vertices as well

        # Get cropping target.
        img_h, img_w = utilities.resolve_shape(data['image'])[:2]
        crop_box = self._get_crop_box(tf.cast(img_h, tf.float32),
                                      tf.cast(img_w, tf.float32))
        left, top, crop_w, crop_h, pad_w, pad_h = crop_box

        # Split the total padding between the two sides: randomly in
        # training, all on the right / bottom otherwise.
        if self._is_training:
            pad_left = tf.random.uniform([], 0, pad_w + 1, dtype=tf.int32)
            pad_top = tf.random.uniform([], 0, pad_h + 1, dtype=tf.int32)
        else:
            pad_left = pad_top = 0

        def _hwc_paddings():
            """Returns the [before, after] paddings for the H, W, C axes."""
            return [[pad_top, pad_h - pad_top],
                    [pad_left, pad_w - pad_left],
                    [0, 0]]

        # Image: crop, pad (the crop box may be larger than the image), then
        # resize.
        image = tf.image.crop_to_bounding_box(data['image'], top, left,
                                              crop_h, crop_w)
        image = tf.pad(image, _hwc_paddings(), constant_values=127)
        image = tf.image.resize(
            image, (self._output_dimension, self._output_dimension))
        data['resized_image'] = tf.cast(image, tf.uint8)

        # Masks: same crop and padding (padded with zeros); nearest neighbor
        # resizing keeps the entity ids intact.
        masks = tf.image.crop_to_bounding_box(data['entity_id_mask'], top,
                                              left, crop_h, crop_w)
        masks = tf.pad(masks, _hwc_paddings())
        masks = tf.image.resize(
            masks, (self._mask_dimension, self._mask_dimension),
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        data['resized_masks'] = tf.squeeze(masks, -1)
示例#9
0
def _dice_sim(pred: tf.Tensor, ground_truth: tf.Tensor) -> tf.Tensor:
    """Computes the pairwise Dice coefficient between two sets of masks.

    Args:
      pred: The predicted mask. [B, N, H, W], in [0, 1].
      ground_truth: The ground-truth mask. [B, N, H, W], in [0, 1] or {0, 1}.
        It is reshaped with the `N` of `pred`, so both inputs are expected to
        hold the same number of masks.

    Returns:
      A [B, N, N] similarity matrix: m[b, i, j] is the dice similarity between
      pred `i` and gt `j` in batch `b`.
    """
    batch, num_masks = utilities.resolve_shape(pred)[:2]

    # Flatten the spatial dimensions: gt becomes (B, HW, N), pred (B, N, HW).
    gt_flat = tf.transpose(ground_truth, (0, 2, 3, 1))
    gt_flat = tf.reshape(gt_flat, (batch, -1, num_masks))
    pred_flat = tf.reshape(pred, (batch, num_masks, -1))

    # 2 * <pred_i, gt_j> for every (i, j) pair.
    intersection = tf.matmul(pred_flat, gt_flat) * 2.

    # TODO(longshangbang): The official implementation does not square the
    # scores. Need to do experiment to determine which one is better.
    gt_energy = tf.math.reduce_sum(tf.math.square(gt_flat), 1,
                                   keepdims=True)  # (B, 1, N)
    pred_energy = tf.math.reduce_sum(tf.math.square(pred_flat), 2,
                                     keepdims=True)  # (B, N, 1)
    union = gt_energy + pred_energy  # broadcast to (B, N, N)

    # EPSILON keeps the ratio defined when both masks are empty.
    return (intersection + EPSILON) / (union + EPSILON)
示例#10
0
    def _get_para_outputs(self, outputs: typing.TensorDict,
                          training: bool) -> tf.Tensor:
        """Applies the paragraph head.

        The object embedding is first projected into a classification feature
        and a grouping feature, both of which are written back to `outputs`.
        The grouping feature is then encoded by the additional grouping branch
        (transformer layers), projected and L2-normalized.

        Args:
          outputs: Output dictionary from the backbone. Its
            "transformer_class_feature" entry is overwritten and a
            "transformer_group_feature" entry is added.
          training: Training / eval mode mark.

        Returns:
          The normalized paragraph embedding vector of shape (B, N, C).
        """
        # Project the object embeddings into classification feature and
        # grouping feature.
        object_embedding = outputs["transformer_class_feature"]  # B,N,C
        class_feature = self._class_embed_head(object_embedding, training)
        group_feature = self._para_embed_head(object_embedding, training)
        outputs["transformer_class_feature"] = class_feature
        outputs["transformer_group_feature"] = group_feature

        # Build the attention bias expected by the standard transformer
        # encoder. Every sequence gets the full length `num_slots` as its
        # valid length.
        feature_shape = utilities.resolve_shape(group_feature)
        batch = feature_shape[0]
        num_slots = int(feature_shape[1])
        valid_lengths = tf.constant(num_slots, shape=(batch, ))
        padding_mask = utilities.get_padding_mask_from_valid_lengths(
            valid_lengths, num_slots, tf.float32)
        attention_bias = utilities.get_transformer_attention_bias(padding_mask)

        # Feed the grouping features into the additional group encoding
        # branch, then project and normalize.
        encoded = self._para_head(group_feature, attention_bias, None,
                                  training)
        projected = self._para_proj(encoded)
        return tf.math.l2_normalize(projected, axis=-1)
示例#11
0
def _entity_mask_loss(loss_dict: Dict[str, tf.Tensor],
                      labels: Dict[str, tf.Tensor],
                      outputs: Dict[str, tf.Tensor],
                      alpha: float = gin.REQUIRED):
    """PQ loss for entity-mask training.

    This method adds the PQ loss term to loss_dict directly, as
    `loss_dict["loss_pq"]`. The match result is also stored in
    `outputs["matched_mask"]` (as a [B, N_pred, N_gt] float tensor). Nothing is
    returned.

    Args:
      loss_dict: A dictionary for the loss. The values are loss scalars.
      labels: A dict containing:
        `num_instance`: (B,)
        `masks`: (B, N, H, W)
        `classes`: (B, N)
      outputs: A dict containing:
        `cls_prob`: (B, N, C)
        `mask_id_prob`: (B, H, W, N)
        `cls_logits`: (B, N, C)
        `mask_id_logits`: (B, H, W, N)
      alpha: Weight for pos/neg balance.
    """
    # Classification score: (B, N, N)
    # in batch b, the probability of prediction i being class of gt j, i.e.:
    # score[b, i, j] = pred_cls[b, i, gt_cls[b, j]]
    gt_cls = labels["classes"]  # (B, N)
    pred_cls = outputs["cls_prob"]  # (B, N, C)
    b, n = utilities.resolve_shape(pred_cls)[:2]
    # indices[b, i, j] = gt_cls[b, j]
    indices = tf.tile(tf.expand_dims(gt_cls, 1), (1, n, 1))
    cls_score = tf.gather(pred_cls, tf.cast(indices, tf.int32), batch_dims=2)

    # Mask score (dice): (B, N, N)
    # mask_score[b, i, j]: dice-similarity for pred i and gt j in batch b.
    mask_score = _dice_sim(tf.transpose(outputs["mask_id_prob"], (0, 3, 1, 2)),
                           labels["masks"])

    # Get similarity matrix and matching.
    # padded_mask[b, 0, j] = 1 if gt slot j is padding, i.e.
    # j >= num_instance[b]; it broadcasts over the prediction axis.
    similarity = cls_score * mask_score
    padded_mask = tf.cast(tf.reshape(tf.range(n), (1, 1, n)), tf.float32)
    padded_mask = tf.cast(
        tf.math.greater_equal(padded_mask,
                              tf.reshape(labels["num_instance"], (b, 1, 1))),
        tf.float32)
    # Padded gt columns get the constant similarity -1. The constant value
    # has no effect: matches to padded columns are zeroed out right after the
    # matching.
    masked_similarity = similarity * (1. - padded_mask) + padded_mask * (-1.)
    # The similarity is negated before matching — presumably because the
    # matcher minimizes cost; confirm against `matchers_ops`.
    matched_mask = matchers_ops.hungarian_matching(-masked_similarity)
    matched_mask = tf.cast(matched_mask, tf.float32) * (1 - padded_mask)
    outputs["matched_mask"] = matched_mask

    # Pos loss: each factor is weighted by the (gradient-stopped) other one.
    # NOTE(review): `tf.math.log(cls_score)` is not clamped; a zero
    # probability would produce inf here. Confirm whether this can occur.
    loss_pos = (tf.stop_gradient(cls_score) * (-mask_score) +
                tf.stop_gradient(mask_score) * (-tf.math.log(cls_score)))
    loss_pos = tf.reduce_sum(loss_pos * matched_mask, axis=[1, 2])  # (B,)

    # Neg loss: unmatched predictions are pushed towards the void class.
    matched_pred = tf.cast(
        tf.reduce_sum(matched_mask, axis=2) > 0, tf.float32)  # (B, N)
    # 0 for void class
    log_loss = -tf.nn.log_softmax(outputs["cls_logits"])[:, :, 0]  # (B, N)
    loss_neg = tf.reduce_sum(log_loss * (1. - matched_pred), axis=-1)  # (B,)

    # Normalize by the number of prediction slots, then average over batch.
    loss_pq = (alpha * loss_pos + (1 - alpha) * loss_neg) / n
    loss_pq = tf.reduce_mean(loss_pq)
    loss_dict["loss_pq"] = loss_pq
示例#12
0
    def _get_instance_labels(self, data: TensorDict, features: TensorDict,
                             labels: NestedTensorDict):
        """Generate the labels for text entity detection.

        All three dictionaries are modified in place; nothing is returned.

        Args:
          data: The sample dictionary. Reads 'groundtruth_parent',
            'resized_masks', 'groundtruth_classes', 'groundtruth_content_type'
            and (unless the detection unit is PARAGRAPH) 'groundtruth_text'.
            Writes 'entity_id_mask'.
          features: The feature dictionary. features['images'] is overwritten:
            the area of the discarded instances is replaced with noise.
          labels: The label dictionary. labels['segmentation_output'] must
            already hold 'gt_word_score'; it is zeroed on the area of the
            discarded instances. Writes labels['instance_labels'] with the
            keys 'num_instance', 'masks', 'masks_sizes', 'classes' and
            'gt_weights', and labels['paragraph_labels'] with the keys
            'paragraph_ids' and 'has_para_ids'.

        Raises:
          ValueError: If `self._detection_unit` is not WORD, LINE or PARAGRAPH.
        """

        labels['instance_labels'] = {}
        # (1) Depending on `detection_unit`:
        #     Convert the word-id map to line-id map or use the word-id map directly
        # Word entity ids start from 1 in the map, so pad a -1 at the beginning of
        # the parent list to counter this offset.
        padded_parent = tf.concat(
            [tf.constant([-1]),
             tf.cast(data['groundtruth_parent'], tf.int32)], 0)
        if self._detection_unit == DetectionClass.WORD:
            entity_id_mask = data['resized_masks']
        elif self._detection_unit == DetectionClass.LINE:
            # The pixel value is entity_id + 1, shape = [H, W]; 0 for background.
            # correctness:
            # 0s in data['resized_masks'] --> padded_parent[0] == -1
            # i-th entity in plp.entities --> i+1 in data['resized_masks']
            #                             --> padded_parent[i+1]
            #                             --> data['groundtruth_parent'][i]
            #                             --> the parent of i-th entity
            entity_id_mask = tf.gather(padded_parent,
                                       data['resized_masks']) + 1
        elif self._detection_unit == DetectionClass.PARAGRAPH:
            # directly segmenting paragraphs; two hops here.
            entity_id_mask = tf.gather(padded_parent,
                                       data['resized_masks']) + 1
            entity_id_mask = tf.gather(padded_parent, entity_id_mask) + 1
        else:
            raise ValueError(f'No such detection unit: {self._detection_unit}')
        data['entity_id_mask'] = entity_id_mask

        # (2) Get individual masks for entities.
        # Only the entities whose class equals the detection unit are kept.
        entity_selection_mask = tf.equal(data['groundtruth_classes'],
                                         self._detection_unit)
        num_all_entity = utilities.resolve_shape(
            data['groundtruth_classes'])[0]
        # entity_ids is a 1-D tensor for IDs of all entities of a certain type.
        entity_ids = tf.boolean_mask(tf.range(num_all_entity, dtype=tf.int32),
                                     entity_selection_mask)  # (N,)
        # +1 to match the entity ids in entity_id_mask
        entity_ids = tf.reshape(entity_ids, (-1, 1, 1)) + 1
        # Broadcast comparison: (N, 1, 1) ids against the (1, H, W) id map.
        individual_masks = tf.expand_dims(entity_id_mask, 0)
        individual_masks = tf.equal(entity_ids,
                                    individual_masks)  # (N, H, W), bool
        # TODO(longshangbang): replace with real mask sizes computing.
        # Currently, we use full-resolution masks for individual_masks. In order to
        # compute mask sizes, we need to convert individual_masks to int/float type.
        # This will cause OOM because the mask is too large.
        # Until then `masks_sizes` is only a 0/1 "mask is non-empty" indicator.
        masks_sizes = tf.cast(tf.reduce_any(individual_masks, axis=[1, 2]),
                              tf.float32)
        # remove empty masks (usually caused by cropping)
        non_empty_masks_ids = tf.not_equal(masks_sizes, 0)
        valid_masks = tf.boolean_mask(individual_masks, non_empty_masks_ids)
        valid_entity_ids = tf.boolean_mask(entity_ids,
                                           non_empty_masks_ids)[:, 0, 0]

        # (3) Write num of instance
        # The stored count includes the background slot (+1) and is capped at
        # `self._max_num_instance` when that bound is non-negative.
        num_instance = tf.reduce_sum(tf.cast(non_empty_masks_ids, tf.float32))
        num_instance_and_bkg = num_instance + 1
        if self._max_num_instance >= 0:
            num_instance_and_bkg = tf.minimum(num_instance_and_bkg,
                                              self._max_num_instance)
        labels['instance_labels']['num_instance'] = num_instance_and_bkg

        # (4) Write instance masks
        num_entity_int = tf.cast(num_instance, tf.int32)
        max_num_entities = self._max_num_instance - 1  # Spare 1 for bkg.
        # Pad with empty masks up to `max_num_entities`; no padding is added
        # when there are already at least that many entities.
        pad_num = tf.maximum(max_num_entities - num_entity_int, 0)
        padded_valid_masks = tf.pad(valid_masks,
                                    [[0, pad_num], [0, 0], [0, 0]])

        # If there are more instances than allowed, randomly sample some.
        # `random_selection_mask` is a boolean array over the padded masks.
        # When bounded, exactly `max_num_entities` entries are True (picked
        # as the top-k of uniform random scores); if not bound, it's an array
        # with all True.
        if self._max_num_instance >= 0:
            padded_size = num_entity_int + pad_num
            random_selection = tf.random.uniform((padded_size, ),
                                                 dtype=tf.float32)
            selected_indices = tf.math.top_k(random_selection,
                                             k=max_num_entities)[1]
            random_selection_mask = tf.scatter_nd(
                indices=tf.expand_dims(selected_indices, axis=-1),
                updates=tf.ones((max_num_entities, ), dtype=tf.bool),
                shape=(padded_size, ))
        else:
            random_selection_mask = tf.ones((num_entity_int, ), dtype=tf.bool)
        random_discard_mask = tf.logical_not(random_selection_mask)

        kept_masks = tf.boolean_mask(padded_valid_masks, random_selection_mask)
        erased_masks = tf.boolean_mask(padded_valid_masks, random_discard_mask)
        # Union of all discarded masks, as a single (H, W) float map.
        erased_masks = tf.cast(tf.reduce_any(erased_masks, axis=0), tf.float32)
        # erase text instances that are omitted.
        features['images'] = _erase(erased_masks, features['images'], -1., 1.)
        labels['segmentation_output']['gt_word_score'] *= 1. - erased_masks
        # Slot 0 is the background: every pixel not covered by a kept mask.
        kept_masks_and_bkg = tf.concat(
            [
                tf.math.logical_not(
                    tf.reduce_any(kept_masks, axis=0, keepdims=True)),  # bkg
                kept_masks,
            ],
            0)
        # Collapse the stack of boolean masks into one integer slot-id map.
        labels['instance_labels']['masks'] = tf.argmax(kept_masks_and_bkg,
                                                       axis=0)

        # (5) Write mask size
        # TODO(longshangbang): replace with real masks sizes
        masks_sizes = tf.cast(tf.reduce_any(kept_masks_and_bkg, axis=[1, 2]),
                              tf.float32)
        labels['instance_labels']['masks_sizes'] = masks_sizes
        # (6) Write classes.
        # Class 1 for every text instance, class 2 for the background slot in
        # front; padding slots get class 0.
        # NOTE(review): `num_instance` is a float32 scalar used as a shape
        # here; `num_entity_int` holds the same count as int32 — confirm the
        # implicit conversion is intended.
        classes = tf.ones((num_instance, ), dtype=tf.int32)
        classes = tf.concat([tf.constant(2, tf.int32, (1, )), classes],
                            0)  # bkg
        if self._max_num_instance >= 0:
            classes = utilities.truncate_or_pad(classes,
                                                self._max_num_instance, 0)
        labels['instance_labels']['classes'] = classes

        # (7) gt-weights
        # Ids (1-based, as in entity_id_mask) of the entities that survived
        # the random selection; the padding part of the selection mask is
        # sliced away first.
        selected_ids = tf.boolean_mask(valid_entity_ids,
                                       random_selection_mask[:num_entity_int])

        if self._detection_unit != DetectionClass.PARAGRAPH:
            # Weight 1 for instances with a non-empty transcription, else 0.
            gt_text = tf.gather(data['groundtruth_text'], selected_ids - 1)
            gt_weights = tf.cast(tf.strings.length(gt_text) > 0, tf.float32)
        else:
            # Weight 1 for paragraphs whose content type is annotated.
            # The leading 8 aligns the list with the 1-based ids.
            text_types = tf.concat(
                [
                    tf.constant([8]),
                    tf.cast(data['groundtruth_content_type'], tf.int32),
                    # TODO(longshangbang): temp solution for tfes with no para labels
                    tf.constant(8, shape=(1000, )),
                ],
                0)
            para_types = tf.gather(text_types, selected_ids)

            gt_weights = tf.cast(tf.not_equal(para_types, NOT_ANNOTATED_ID),
                                 tf.float32)

        gt_weights = tf.concat([tf.constant(1., shape=(1, )), gt_weights],
                               0)  # bkg
        if self._max_num_instance >= 0:
            gt_weights = utilities.truncate_or_pad(gt_weights,
                                                   self._max_num_instance, 0)
        labels['instance_labels']['gt_weights'] = gt_weights

        # (8) get paragraph label
        # In this step, an array `{p_i}` is generated. `p_i` is an integer that
        # indicates the group of paragraph which i-th text belongs to. `p_i` == -1
        # if this instance is non-text or it has no paragraph labels.
        # word -> line -> paragraph
        if self._detection_unit == DetectionClass.WORD:
            num_hop = 2
        elif self._detection_unit == DetectionClass.LINE:
            num_hop = 1
        elif self._detection_unit == DetectionClass.PARAGRAPH:
            num_hop = 0
        else:
            raise ValueError(
                f'No such detection unit: {self._detection_unit}. '
                'Note that this error should have been raised in '
                'previous lines, not here!')
        para_ids = tf.identity(selected_ids)  # == id in plp + 1
        # Walk up the parent chain; ids stay 1-based after every hop.
        for _ in range(num_hop):
            para_ids = tf.gather(padded_parent, para_ids) + 1

        # Content type lookup table indexed by the 1-based paragraph id.
        text_types = tf.concat(
            [
                tf.constant([8]),
                tf.cast(data['groundtruth_content_type'], tf.int32),
                # TODO(longshangbang): tricks for tfes that have not para labels
                tf.constant(8, shape=(1000, )),
            ],
            0)
        para_types = tf.gather(text_types, para_ids)

        para_ids = para_ids - 1  # revert to id in plp.entities; -1 for no labels
        valid_para = tf.cast(tf.not_equal(para_types, NOT_ANNOTATED_ID),
                             tf.int32)
        # Keep the id where the paragraph is annotated, otherwise use -1.
        para_ids = valid_para * para_ids + (1 - valid_para) * (-1)
        para_ids = tf.concat([tf.constant([-1]), para_ids], 0)  # add bkg

        # 1.0 if at least one selected instance has an annotated paragraph.
        has_para_ids = tf.cast(tf.reduce_sum(valid_para) > 0, tf.float32)

        if self._max_num_instance >= 0:
            para_ids = utilities.truncate_or_pad(para_ids,
                                                 self._max_num_instance, 0, -1)
        labels['paragraph_labels'] = {
            'paragraph_ids': para_ids,
            'has_para_ids': has_para_ids
        }