def _postprocess(self, outputs: typing.NestedTensorDict):
    """Filters the raw instance predictions and stores them in `outputs`.

    Args:
      outputs: A dictionary of outputs. Reads
        outputs["instance_output"]["mask_id_prob"] and
        outputs["instance_output"]["cls_prob"].

    The following fields are written to outputs["postprocessed"]:
      "classes": A (B, N) integer tensor for the class ids. Filtered-out slots
        are set to zero.
      "binary_masks": A (B, H, W, N) tensor for the N binarized 0/1 masks.
        Masks of filtered-out slots are set to zero.
      "confidence": A (B, N) float tensor for the confidence of "classes".
      "mask_area": A (B, N) float tensor for the area of each mask.
    They are used in inference / visualization.
    """
    outputs["postprocessed"] = {}
    instance_output = outputs["instance_output"]

    # Masks: a pixel is assigned to slot k iff slot k has the highest
    # probability at that pixel AND that probability reaches the threshold.
    mask_prob = instance_output["mask_id_prob"]
    best_prob_per_pixel = tf.reduce_max(mask_prob, axis=-1, keepdims=True)
    is_best_slot = tf.equal(best_prob_per_pixel, mask_prob)
    is_confident = tf.greater_equal(best_prob_per_pixel, self._mask_threshold)
    binary_masks = tf.cast(
        tf.math.logical_and(is_best_slot, is_confident), tf.float32)
    mask_area = tf.reduce_sum(binary_masks, axis=(1, 2))  # (B, N)

    # Classification: top class and its probability for every slot.
    class_prob = instance_output["cls_prob"]
    confidence = tf.reduce_max(class_prob, axis=-1)  # (B, N)
    class_id = tf.cast(tf.argmax(class_prob, axis=-1), tf.float32)  # (B, N)

    # Filtering: a slot is kept only when all four conditions hold.
    num_classes = utilities.resolve_shape(class_prob)[2]
    keep_conditions = [
        # The mask is large enough.
        tf.greater_equal(mask_area, self._filter_area),
        # Class 0 is for non-object.
        tf.not_equal(class_id, 0),
        # The last class, (c - 1), is for background.
        tf.not_equal(class_id, num_classes - 1),
        # The class probability reaches the threshold.
        tf.greater_equal(confidence, self._class_threshold),
    ]
    keep = tf.reduce_all(tf.stack(keep_conditions, axis=-1), axis=-1)
    keep = tf.cast(keep, tf.float32)

    # Storing. Multiplying by `keep` zeroes the rejected slots.
    postprocessed = outputs["postprocessed"]
    postprocessed["classes"] = tf.cast(class_id * keep, tf.int32)
    batch_size, num_slots = utilities.resolve_shape(keep)
    postprocessed["binary_masks"] = (
        binary_masks * tf.reshape(keep, (batch_size, 1, 1, num_slots)))
    postprocessed["confidence"] = confidence
    postprocessed["mask_area"] = mask_area
def _erase(mask: tf.Tensor,
           feature: tf.Tensor,
           min_val: float = 0.,
           max_val: float = 256.) -> tf.Tensor:
    """Replaces the masked area of a feature map with uniform random noise.

    The mask may have a different spatial size from the feature maps; it is
    resized to the feature size and thresholded at 0.5 before being applied.

    Args:
      mask: an (h, w) binary mask for pixels to erase with. Value 1 represents
        pixels to erase.
      feature: the (H, W, C) feature maps to erase from.
      min_val: The minimum value of random noise.
      max_val: The maximum value of random noise.

    Returns:
      The (H, W, C) feature maps, with pixels in mask replaced with noises. It's
      equal to mask * noise + (1 - mask) * feature.
    """
    height, width, channels = utilities.resolve_shape(feature)

    # Bring the mask to the feature's (H, W, C) layout.
    mask_hwc = tf.expand_dims(tf.cast(mask, tf.float32), -1)
    mask_hwc = tf.tile(mask_hwc, (1, 1, channels))
    mask_hwc = tf.image.resize(mask_hwc, (height, width))
    should_erase = mask_hwc > 0.5

    # The noise is sampled as float and then cast to the feature's dtype.
    noise = tf.random.uniform((height, width, channels), min_val, max_val)
    noise = tf.cast(noise, feature.dtype)
    return tf.where(condition=should_erase, x=noise, y=feature)
def _rotate():
    """Rotates the sample by a random multiple of 90 degrees.

    Reads `data` from an outer scope (it is not a parameter of this function).
    These will be rotated: image, rbox, entity_id_mask,
    TODO(longshangbang): rotate vertices.

    Returns:
      The rotated tensors of the above fields, as the tuple
      (image, boxes, mask).
    """
    # Number of 90-degree turns, drawn from {1, 2, 3} (maxval is exclusive).
    num_turns = tf.random.uniform([], 1, 4, dtype=tf.int32)
    height, width, _ = utilities.resolve_shape(data['image'])

    # Image
    image_out = tf.image.rot90(data['image'], k=num_turns, name='image_rot90k')

    # Box
    def _boxes_after(turns):
        return utilities.rotate_rboxes90(
            rboxes=data['groundtruth_boxes'],
            image_width=width,
            image_height=height,
            rotation_count=turns)

    boxes_out = tf.switch_case(
        num_turns - 1,  # Turn counts start at 1, branch indices at 0.
        branch_fns=[
            lambda: _boxes_after(1),
            lambda: _boxes_after(2),
            lambda: _boxes_after(3),
        ])

    # Mask
    mask_out = tf.image.rot90(
        data['entity_id_mask'], k=num_turns, name='mask_rot90k')
    return image_out, boxes_out, mask_out
def call(self,
         features: typing.TensorDict,
         training: bool = False) -> typing.NestedTensorDict:
    """Forward pass of the model.

    Args:
      features: The input features: {"images": tf.Tensor}. Shape = [B, H, W, C]
      training: Whether it's training mode.

    Returns:
      A dictionary of output with this structure:
        {
          "max_deep_lab": {
            All the max deeplab outputs are here, including both backbone and
            decoder.
          }
          "segmentation_output": {
            "word_score": tf.Tensor, [B, h, w],
          }
          "instance_output": {
            "cls_logits": tf.Tensor, [B, N, C],
            "mask_id_logits": tf.Tensor, [B, H, W, N],
            "cls_prob": tf.Tensor, [B, N, C],
            "mask_id_prob": tf.Tensor, [B, H, W, N],
          }
          "postprocessed": {
            "classes": A (B, N) tensor for the class ids. Zero for non-firing
              slots.
            "binary_masks": A (B, H, W, N) tensor for the N binary masks. Masks
              for void cls are set to zero.
            "confidence": A (B, N) float tensor for the confidence of "classes".
            "mask_area": A (B, N) float tensor for the area of each mask.
          }
          "transformer_group_feature": (B, N, C) float tensor (normalized),
          "para_affinity": (B, N, N) float tensor.
        }

      Class-0 is for void. Class-(C-1) is for background. Class-1~(C-2) is for
      valid classes.
    """
    # Backbone.
    backbone_features = self._backbone_fn(features["images"], training)

    # Paragraph grouping: split instance / paragraph embeddings and compute the
    # pairwise affinity of the paragraph embeddings. This must run before the
    # decoder because `_get_para_outputs` rewrites entries of
    # `backbone_features` in place.
    para_embedding = self._get_para_outputs(backbone_features, training)
    para_affinity = tf.linalg.matmul(
        para_embedding, para_embedding, transpose_b=True)

    # Text detection head.
    decoded = self._decoder(backbone_features, training)

    result = {}
    result["max_deep_lab"] = decoded
    result["transformer_group_feature"] = para_embedding
    result["para_affinity"] = para_affinity

    # The remaining heads take `result` as their first argument and return
    # nothing that is used here.
    image_shape = utilities.resolve_shape(features["images"])
    self._get_semantic_outputs(result, image_shape)
    self._get_instance_outputs(result, image_shape)
    self._postprocess(result)
    return result
def _instance_discrimination_loss(loss_dict: Dict[str, Any],
                                  labels: Dict[str, Any],
                                  outputs: Dict[str, Any],
                                  tau: float = gin.REQUIRED):
    """Instance discrimination loss.

    This method adds the ID loss term to loss_dict directly, under the key
    "loss_inst_dist".

    Args:
      loss_dict: A dictionary for the loss. The values are loss scalars.
      labels: The label dictionary. Reads "masks" and "num_instance".
      outputs: The output dictionary. Reads
        outputs["max_deep_lab"]["pixel_space_normalized_feature"].
      tau: The temperature term in the loss
    """
    # The normalized pixel feature, shape=(B, H/4, W/4, D)
    pixel_feature = outputs["max_deep_lab"]["pixel_space_normalized_feature"]
    batch_size, height, width = utilities.resolve_shape(pixel_feature)[:3]

    # The ground-truth masks, shape=(B, N, H, W) --> (B, N, H/4, W/4).
    # tf.image.resize works on channels-last, hence the two transposes.
    gt_masks = labels["masks"]
    gt_masks = tf.image.resize(
        tf.transpose(gt_masks, (0, 2, 3, 1)), (height, width),
        tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    gt_masks = tf.transpose(gt_masks, (0, 3, 1, 2))

    # The number of ground-truth instance (K), shape=(B,)
    num_gt = labels["num_instance"]
    max_slots = utilities.resolve_shape(gt_masks)[1]  # max number of predictions

    # is_padded[b, i] = 1 if instance i in batch b is a padded slot.
    slot_index = tf.cast(
        tf.expand_dims(tf.range(max_slots), 0), tf.float32)  # (1, N)
    is_padded = tf.cast(
        tf.math.greater_equal(slot_index, tf.expand_dims(num_gt, 1)),
        tf.float32)

    # Per-instance prototype: mask-pooled pixel features, normalized. (B, N, D)
    prototypes = tf.math.l2_normalize(
        tf.einsum("bhwd,bnhw->bnd", pixel_feature, gt_masks), axis=-1)

    # Similarity of each pixel to each prototype; padded slots are pushed down
    # by a large constant. (B, H, W, N)
    logits = tf.einsum("bhwd,bid->bhwi", pixel_feature, prototypes) / tau
    logits = logits - 100. * tf.reshape(is_padded,
                                        (batch_size, 1, 1, max_slots))

    # Index of the instance covering each pixel, (B, H, W).
    target_slot = tf.cast(
        tf.einsum("bnhw,n->bhw", gt_masks,
                  tf.range(max_slots, dtype=tf.float32)), tf.int32)
    pixel_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_slot, logits=logits)  # B, H, W

    # Weight each pixel by the sum of the masks covering it, then average per
    # sample.
    pixel_weight = tf.reduce_sum(gt_masks, axis=1)
    weighted_sum = tf.reduce_sum(pixel_loss * pixel_weight, axis=[1, 2])
    weight_total = tf.reduce_sum(pixel_weight, axis=[1, 2])
    per_sample_loss = (weighted_sum + EPSILON) / (weight_total + EPSILON)
    loss_dict["loss_inst_dist"] = tf.reduce_mean(per_sample_loss)
def _preprocess_labels(self, labels: typing.TensorDict):
    """Converts the integer instance mask to one-hot masks, in place.

    Args:
      labels: The label dictionary. labels["instance_labels"]["masks"] is
        replaced by its one-hot encoding; the one-hot depth is taken from
        dimension 1 of labels["instance_labels"]["masks_sizes"].
    """
    instance_labels = labels["instance_labels"]
    one_hot_depth = utilities.resolve_shape(instance_labels["masks_sizes"])[1]
    # The one-hot axis is inserted at position 1: (B, H, W) -> (B, N, H, W).
    instance_labels["masks"] = tf.one_hot(
        instance_labels["masks"],
        depth=one_hot_depth,
        axis=1,
        dtype=tf.float32)
def _coloring(self, masks: tf.Tensor) -> tf.Tensor:
    """Coloring segmentation masks.

    Used in visualization.

    Args:
      masks: A float binary tensor of shape (B, H, W, N), representing `B`
        samples, with `N` masks of size `H*W` each. Each of the `N` masks will
        be assigned a random color.

    Returns:
      A (b, h, w, 3) float tensor for the coloring result. Each pixel holds the
      sum of the colors of the masks covering it; the colors are sampled from
      [0.5, 1.), so the value stays below 1 as long as masks do not overlap.
    """
    batch_size, height, width, num_masks = utilities.resolve_shape(masks)
    # A single set of colors, shared by every sample in the batch.
    colors = tf.random.uniform((1, num_masks, 3), 0.5, 1.)
    flat_masks = tf.reshape(masks, (batch_size, -1, num_masks))  # (B, HW, N)
    flat_colored = tf.matmul(flat_masks, colors)  # (B, HW, 3)
    return tf.reshape(flat_colored, (batch_size, height, width, 3))
def _crop_and_resize(self, data: TensorDict, unused_features: TensorDict,
                     unused_labels: TensorDict):
    """Perform random cropping and resizing.

    Writes data['resized_image'] (uint8) and data['resized_masks'] (last
    dimension squeezed out). The image and the masks share the same crop box
    and the same padding offsets.

    Args:
      data: The sample dictionary. Reads 'image' and 'entity_id_mask'.
      unused_features: Not used.
      unused_labels: Not used.
    """
    # TODO(longshangbang): resize & translate box as well
    # TODO(longshangbang): resize & translate vertices as well

    # Get cropping target.
    image_h, image_w = utilities.resolve_shape(data['image'])[:2]
    left, top, crop_w, crop_h, pad_w, pad_h = self._get_crop_box(
        tf.cast(image_h, tf.float32), tf.cast(image_w, tf.float32))

    # Where the crop sits inside the padded canvas: random while training,
    # top-left otherwise.
    pad_left = 0
    pad_top = 0
    if self._is_training:
        pad_left = tf.random.uniform([], 0, pad_w + 1, dtype=tf.int32)
        pad_top = tf.random.uniform([], 0, pad_h + 1, dtype=tf.int32)

    def _crop_then_pad(tensor, **pad_kwargs):
        """Crops `tensor` to the crop box, then pads it by (pad_h, pad_w)."""
        cropped = tf.image.crop_to_bounding_box(tensor, top, left, crop_h,
                                                crop_w)
        paddings = [[pad_top, pad_h - pad_top],
                    [pad_left, pad_w - pad_left],
                    [0, 0]]
        return tf.pad(cropped, paddings, **pad_kwargs)

    # Image: padded with 127, resized, stored as uint8.
    image = _crop_then_pad(data['image'], constant_values=127)
    image = tf.image.resize(
        image, (self._output_dimension, self._output_dimension))
    data['resized_image'] = tf.cast(image, tf.uint8)

    # Masks: padded with the tf.pad default (zero); nearest-neighbor resizing
    # keeps the ids intact.
    masks = _crop_then_pad(data['entity_id_mask'])
    masks = tf.image.resize(
        masks, (self._mask_dimension, self._mask_dimension),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    data['resized_masks'] = tf.squeeze(masks, -1)
def _dice_sim(pred: tf.Tensor, ground_truth: tf.Tensor) -> tf.Tensor:
    """Dice Coefficient for mask similarity.

    Args:
      pred: The predicted mask. [B, N, H, W], in [0, 1].
      ground_truth: The ground-truth mask. [B, N, H, W], in [0, 1] or {0, 1}.
        It is reshaped with the `N` read from `pred`, so it presumably must
        have the same number of masks as `pred` -- TODO confirm.

    Returns:
      A matrix of similarities: m[b, i, j] is the dice similarity between pred
      `i` and gt `j` in batch `b`.
    """
    batch_size, num_masks = utilities.resolve_shape(pred)[:2]
    gt_flat = tf.transpose(ground_truth, (0, 2, 3, 1))
    gt_flat = tf.reshape(gt_flat, (batch_size, -1, num_masks))  # B, HW, N
    pred_flat = tf.reshape(pred, (batch_size, num_masks, -1))  # B, N, HW

    # 2 * <pred_i, gt_j> for every (i, j) pair: (B, N, N).
    overlap = tf.matmul(pred_flat, gt_flat) * 2.

    # TODO(longshangbang): The official implementation does not square the scores.
    # Need to do experiment to determine which one is better.
    gt_sq_sum = tf.math.reduce_sum(
        tf.math.square(gt_flat), 1, keepdims=True)  # B, 1, N
    pred_sq_sum = tf.math.reduce_sum(
        tf.math.square(pred_flat), 2, keepdims=True)  # B, N, 1
    # The two sums broadcast against each other into (B, N, N).
    return (overlap + EPSILON) / (gt_sq_sum + pred_sq_sum + EPSILON)
def _get_para_outputs(self, outputs: typing.TensorDict,
                      training: bool) -> tf.Tensor:
    """Apply the paragraph head.

    This function first splits the features for instance classification and
    instance grouping. Then, the additional grouping branch (transformer layers)
    is applied to further encode the grouping features. Finally, a tensor of
    normalized grouping features is returned.

    Side effect: outputs["transformer_class_feature"] is overwritten with the
    classification feature, and outputs["transformer_group_feature"] is set to
    the grouping feature as it is before the grouping branch.

    Args:
      outputs: output dictionary from the backbone.
      training: training / eval mode mark.

    Returns:
      The normalized paragraph embedding vector of shape (B, N, C).
    """
    # Project the object embeddings into classification feature and grouping
    # feature.
    object_embedding = outputs["transformer_class_feature"]  # B,N,C
    class_feature = self._class_embed_head(object_embedding, training)
    group_feature = self._para_embed_head(object_embedding, training)
    outputs["transformer_class_feature"] = class_feature
    outputs["transformer_group_feature"] = group_feature

    # Feed the grouping features into additional group encoding branch.
    # First we need to build the attention_bias used by the standard
    # transformer encoder. Every sequence is given the full length N as its
    # valid length.
    feature_shape = utilities.resolve_shape(group_feature)
    batch_size = feature_shape[0]
    num_slots = int(feature_shape[1])
    valid_lengths = tf.constant(num_slots, shape=(batch_size,))
    padding_mask = utilities.get_padding_mask_from_valid_lengths(
        valid_lengths, num_slots, tf.float32)
    attention_bias = utilities.get_transformer_attention_bias(padding_mask)

    encoded = self._para_head(group_feature, attention_bias, None, training)
    projected = self._para_proj(encoded)
    return tf.math.l2_normalize(projected, axis=-1)
def _entity_mask_loss(loss_dict: Dict[str, tf.Tensor],
                      labels: Dict[str, tf.Tensor],
                      outputs: Dict[str, tf.Tensor],
                      alpha: float = gin.REQUIRED):
    """PQ loss for entity-mask training.

    This method adds the PQ loss term to loss_dict directly, under the key
    "loss_pq". The match result will also be stored in outputs["matched_mask"]
    (As a [B, N_pred, N_gt] float tensor).

    Args:
      loss_dict: A dictionary for the loss. The values are loss scalars.
      labels: A dict containing:
        `num_instance` - (B,)
        `masks` - (B, N, H, W)
        `classes` - (B, N)
      outputs: A dict containing:
        `cls_prob`: (B, N, C)
        `mask_id_prob`: (B, H, W, N)
        `cls_logits`: (B, N, C)
        `mask_id_logits`: (B, H, W, N)
      alpha: Weight for pos/neg balance.
    """
    # Classification score: (B, N, N)
    # in batch b, the probability of prediction i being class of gt j, i.e.:
    # score[b, i, j] = pred_cls[b, i, gt_cls[b, j]]
    gt_cls = labels["classes"]  # (B, N)
    pred_cls = outputs["cls_prob"]  # (B, N, C)
    b, n = utilities.resolve_shape(pred_cls)[:2]
    # indices[b, i, j] = gt_cls[b, j]
    indices = tf.tile(tf.expand_dims(gt_cls, 1), (1, n, 1))
    cls_score = tf.gather(pred_cls, tf.cast(indices, tf.int32), batch_dims=2)

    # Mask score (dice): (B, N, N)
    # mask_score[b, i, j]: dice-similarity for pred i and gt j in batch b.
    mask_score = _dice_sim(
        tf.transpose(outputs["mask_id_prob"], (0, 3, 1, 2)), labels["masks"])

    # Get similarity matrix and matching.
    # padded_mask[b, 0, j] = 1 if gt slot j is padding, i.e.
    # j >= num_instance[b]. Padded columns get similarity -1, which is lower
    # than any real score.
    similarity = cls_score * mask_score
    padded_mask = tf.cast(tf.reshape(tf.range(n), (1, 1, n)), tf.float32)
    padded_mask = tf.cast(
        tf.math.greater_equal(
            padded_mask, tf.reshape(labels["num_instance"], (b, 1, 1))),
        tf.float32)
    # The constant value for padding has no effect.
    masked_similarity = similarity * (1. - padded_mask) + padded_mask * (-1.)

    # The matcher receives the negated similarity (presumably a cost to
    # minimize -- confirm against matchers_ops). Matches that land on padded
    # gt slots are dropped afterwards.
    matched_mask = matchers_ops.hungarian_matching(-masked_similarity)
    matched_mask = tf.cast(matched_mask, tf.float32) * (1 - padded_mask)
    outputs["matched_mask"] = matched_mask

    # Pos loss. Each factor is used as a constant weight (stop_gradient) for
    # the other factor's loss term.
    # NOTE(review): tf.math.log(cls_score) is unguarded; a probability of
    # exactly 0 would yield inf (and nan once multiplied by a 0 in
    # matched_mask). Left unchanged to keep the loss values identical.
    loss_pos = (
        tf.stop_gradient(cls_score) * (-mask_score) +
        tf.stop_gradient(mask_score) * (-tf.math.log(cls_score)))
    loss_pos = tf.reduce_sum(loss_pos * matched_mask, axis=[1, 2])  # (B,)

    # Neg loss: unmatched predictions are pushed towards the void class.
    matched_pred = tf.cast(
        tf.reduce_sum(matched_mask, axis=2) > 0, tf.float32)  # (B, N)
    # 0 for void class
    log_loss = -tf.nn.log_softmax(outputs["cls_logits"])[:, :, 0]  # (B, N)
    loss_neg = tf.reduce_sum(log_loss * (1. - matched_pred), axis=-1)  # (B,)

    loss_pq = (alpha * loss_pos + (1 - alpha) * loss_neg) / n
    loss_pq = tf.reduce_mean(loss_pq)
    loss_dict["loss_pq"] = loss_pq
def _get_instance_labels(self, data: TensorDict, features: TensorDict,
                         labels: NestedTensorDict):
    """Generate the labels for text entity detection.

    Args:
      data: The sample dictionary. Reads 'groundtruth_parent', 'resized_masks',
        'groundtruth_classes', 'groundtruth_content_type' and, unless the
        detection unit is PARAGRAPH, 'groundtruth_text'. Writes
        'entity_id_mask'.
      features: The feature dictionary. features['images'] is overwritten: the
        area of the instances that are dropped by the random selection is
        replaced with noise.
      labels: The label dictionary. Writes labels['instance_labels'] (keys
        'num_instance', 'masks', 'masks_sizes', 'classes', 'gt_weights') and
        labels['paragraph_labels'] (keys 'paragraph_ids', 'has_para_ids').
        labels['segmentation_output']['gt_word_score'] must already exist; it
        is zeroed on the erased area.

    Raises:
      ValueError: if `self._detection_unit` is not WORD, LINE or PARAGRAPH.
    """

    def _padded_content_types():
        """Content types indexed by (entity id + 1), padded on both sides."""
        return tf.concat(
            [
                tf.constant([8]),
                tf.cast(data['groundtruth_content_type'], tf.int32),
                # TODO(longshangbang): temp solution for tfes with no para
                # labels
                tf.constant(8, shape=(1000,)),
            ], 0)

    labels['instance_labels'] = {}

    # (1) Depending on `detection_unit`:
    # Convert the word-id map to line-id map or use the word-id map directly
    # Word entity ids start from 1 in the map, so pad a -1 at the beginning of
    # the parent list to counter this offset.
    padded_parent = tf.concat(
        [tf.constant([-1]),
         tf.cast(data['groundtruth_parent'], tf.int32)], 0)
    if self._detection_unit == DetectionClass.WORD:
        entity_id_mask = data['resized_masks']
    elif self._detection_unit == DetectionClass.LINE:
        # The pixel value is entity_id + 1, shape = [H, W]; 0 for background.
        # correctness:
        # 0s in data['resized_masks'] --> padded_parent[0] == -1
        # i-th entity in plp.entities --> i+1 in data['resized_masks']
        #                             --> padded_parent[i+1]
        #                             --> data['groundtruth_parent'][i]
        #                             --> the parent of i-th entity
        entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
    elif self._detection_unit == DetectionClass.PARAGRAPH:
        # directly segmenting paragraphs; two hops here.
        entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
        entity_id_mask = tf.gather(padded_parent, entity_id_mask) + 1
    else:
        raise ValueError(f'No such detection unit: {self._detection_unit}')
    data['entity_id_mask'] = entity_id_mask

    # (2) Get individual masks for entities.
    entity_selection_mask = tf.equal(data['groundtruth_classes'],
                                     self._detection_unit)
    num_all_entity = utilities.resolve_shape(data['groundtruth_classes'])[0]
    # entity_ids is a 1-D tensor for IDs of all entities of a certain type.
    entity_ids = tf.boolean_mask(
        tf.range(num_all_entity, dtype=tf.int32),
        entity_selection_mask)  # (N,)
    # +1 to match the entity ids in entity_id_mask
    entity_ids = tf.reshape(entity_ids, (-1, 1, 1)) + 1
    individual_masks = tf.expand_dims(entity_id_mask, 0)
    individual_masks = tf.equal(entity_ids,
                                individual_masks)  # (N, H, W), bool
    # TODO(longshangbang): replace with real mask sizes computing.
    # Currently, we use full-resolution masks for individual_masks. In order to
    # compute mask sizes, we need to convert individual_masks to int/float type.
    # This will cause OOM because the mask is too large.
    masks_sizes = tf.cast(
        tf.reduce_any(individual_masks, axis=[1, 2]), tf.float32)
    # remove empty masks (usually caused by cropping)
    non_empty_masks_ids = tf.not_equal(masks_sizes, 0)
    valid_masks = tf.boolean_mask(individual_masks, non_empty_masks_ids)
    valid_entity_ids = tf.boolean_mask(entity_ids,
                                       non_empty_masks_ids)[:, 0, 0]

    # (3) Write num of instance
    # Kept as a float scalar; the integer version is `num_entity_int` below.
    num_instance = tf.reduce_sum(tf.cast(non_empty_masks_ids, tf.float32))
    num_instance_and_bkg = num_instance + 1
    if self._max_num_instance >= 0:
        num_instance_and_bkg = tf.minimum(num_instance_and_bkg,
                                          self._max_num_instance)
    labels['instance_labels']['num_instance'] = num_instance_and_bkg

    # (4) Write instance masks
    num_entity_int = tf.cast(num_instance, tf.int32)
    max_num_entities = self._max_num_instance - 1  # Spare 1 for bkg.
    pad_num = tf.maximum(max_num_entities - num_entity_int, 0)
    padded_valid_masks = tf.pad(valid_masks, [[0, pad_num], [0, 0], [0, 0]])

    # If there are more instances than allowed, randomly sample some.
    # `random_selection_mask` is a boolean array over the padded masks; when
    # bounded, exactly `self._max_num_instance - 1` entries are True (the
    # remaining slot is spared for background); if not bound, it's an array
    # with all True.
    if self._max_num_instance >= 0:
        padded_size = num_entity_int + pad_num
        random_selection = tf.random.uniform((padded_size,), dtype=tf.float32)
        selected_indices = tf.math.top_k(
            random_selection, k=max_num_entities)[1]
        random_selection_mask = tf.scatter_nd(
            indices=tf.expand_dims(selected_indices, axis=-1),
            updates=tf.ones((max_num_entities,), dtype=tf.bool),
            shape=(padded_size,))
    else:
        random_selection_mask = tf.ones((num_entity_int,), dtype=tf.bool)
    random_discard_mask = tf.logical_not(random_selection_mask)

    kept_masks = tf.boolean_mask(padded_valid_masks, random_selection_mask)
    erased_masks = tf.boolean_mask(padded_valid_masks, random_discard_mask)
    erased_masks = tf.cast(tf.reduce_any(erased_masks, axis=0), tf.float32)
    # erase text instances that are omitted.
    features['images'] = _erase(erased_masks, features['images'], -1., 1.)
    labels['segmentation_output']['gt_word_score'] *= 1. - erased_masks
    kept_masks_and_bkg = tf.concat(
        [
            tf.math.logical_not(
                tf.reduce_any(kept_masks, axis=0, keepdims=True)),  # bkg
            kept_masks,
        ], 0)
    # Integer mask: index 0 is background, index i >= 1 is the i-th kept mask.
    labels['instance_labels']['masks'] = tf.argmax(kept_masks_and_bkg, axis=0)

    # (5) Write mask size
    # TODO(longshangbang): replace with real masks sizes
    masks_sizes = tf.cast(
        tf.reduce_any(kept_masks_and_bkg, axis=[1, 2]), tf.float32)
    labels['instance_labels']['masks_sizes'] = masks_sizes

    # (6) Write classes.
    # The shape passed to tf.ones must be an integer, so the int32 count is
    # used here rather than the float `num_instance`.
    classes = tf.ones((num_entity_int,), dtype=tf.int32)
    classes = tf.concat([tf.constant(2, tf.int32, (1,)), classes], 0)  # bkg
    if self._max_num_instance >= 0:
        classes = utilities.truncate_or_pad(classes, self._max_num_instance, 0)
    labels['instance_labels']['classes'] = classes

    # (7) gt-weights
    # The slice drops the entries of the selection mask that belong to padding.
    selected_ids = tf.boolean_mask(valid_entity_ids,
                                   random_selection_mask[:num_entity_int])

    if self._detection_unit != DetectionClass.PARAGRAPH:
        # Instances with empty text get weight 0.
        gt_text = tf.gather(data['groundtruth_text'], selected_ids - 1)
        gt_weights = tf.cast(tf.strings.length(gt_text) > 0, tf.float32)
    else:
        # Paragraphs whose content type is NOT_ANNOTATED_ID get weight 0.
        text_types = _padded_content_types()
        para_types = tf.gather(text_types, selected_ids)
        gt_weights = tf.cast(
            tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.float32)

    gt_weights = tf.concat([tf.constant(1., shape=(1,)), gt_weights], 0)  # bkg
    if self._max_num_instance >= 0:
        gt_weights = utilities.truncate_or_pad(gt_weights,
                                               self._max_num_instance, 0)
    labels['instance_labels']['gt_weights'] = gt_weights

    # (8) get paragraph label
    # In this step, an array `{p_i}` is generated. `p_i` is an integer that
    # indicates the group of paragraph which i-th text belongs to. `p_i` == -1
    # if this instance is non-text or it has no paragraph labels.
    # word -> line -> paragraph
    if self._detection_unit == DetectionClass.WORD:
        num_hop = 2
    elif self._detection_unit == DetectionClass.LINE:
        num_hop = 1
    elif self._detection_unit == DetectionClass.PARAGRAPH:
        num_hop = 0
    else:
        raise ValueError(
            f'No such detection unit: {self._detection_unit}. '
            'Note that this error should have been raised in '
            'previous lines, not here!')
    para_ids = tf.identity(selected_ids)  # == id in plp + 1
    for _ in range(num_hop):
        para_ids = tf.gather(padded_parent, para_ids) + 1

    text_types = _padded_content_types()
    para_types = tf.gather(text_types, para_ids)

    para_ids = para_ids - 1  # revert to id in plp.entities; -1 for no labels
    valid_para = tf.cast(tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.int32)
    para_ids = valid_para * para_ids + (1 - valid_para) * (-1)
    para_ids = tf.concat([tf.constant([-1]), para_ids], 0)  # add bkg

    has_para_ids = tf.cast(tf.reduce_sum(valid_para) > 0, tf.float32)

    if self._max_num_instance >= 0:
        para_ids = utilities.truncate_or_pad(para_ids, self._max_num_instance,
                                             0, -1)
    labels['paragraph_labels'] = {
        'paragraph_ids': para_ids,
        'has_para_ids': has_para_ids
    }