def extract_images_and_targets(read_data):
  """Pulls the image, its key and the groundtruth targets out of `read_data`.

  `label_id_offset`, `merge_multiple_label_boxes` and `num_classes` are not
  parameters; they are resolved from the enclosing scope.

  Args:
    read_data: dictionary keyed by fields.InputDataFields entries. Must
      contain the image, groundtruth_boxes and groundtruth_classes entries;
      source_id, groundtruth_instance_masks and groundtruth_keypoints are
      optional.

  Returns:
    A tuple (image, key, location_gt, classes_gt, masks_gt, keypoints_gt).
    `key` is the empty string when no source_id entry exists, and `masks_gt`
    / `keypoints_gt` are None when their entries are absent.

  Raises:
    NotImplementedError: if `merge_multiple_label_boxes` is set and masks or
      keypoints are present in `read_data`.
  """
  input_fields = fields.InputDataFields

  image_tensor = read_data[input_fields.image]
  # Fall back to an empty key when the example carries no source id.
  source_key = (
      read_data[input_fields.source_id]
      if input_fields.source_id in read_data else '')

  boxes = read_data[input_fields.groundtruth_boxes]
  # Shift the integer labels by the offset taken from the enclosing scope.
  labels = tf.cast(
      read_data[input_fields.groundtruth_classes], tf.int32) - label_id_offset

  if not merge_multiple_label_boxes:
    labels = util_ops.padded_one_hot_encoding(
        indices=labels, depth=num_classes, left_pad=0)
  else:
    boxes, labels, _ = util_ops.merge_boxes_with_multiple_labels(
        boxes, labels, num_classes)

  instance_masks = read_data.get(input_fields.groundtruth_instance_masks)
  keypoints = read_data.get(input_fields.groundtruth_keypoints)

  # Merging multi-label boxes is rejected when masks or keypoints are present.
  has_non_box_targets = instance_masks is not None or keypoints is not None
  if merge_multiple_label_boxes and has_non_box_targets:
    raise NotImplementedError('Multi-label support is only for boxes.')

  return image_tensor, source_key, boxes, labels, instance_masks, keypoints
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False,
                         use_multiclass_scores=False,
                         use_bfloat16=False):
  """A single function that is responsible for all input data transformations.

  Data transformation functions are applied in the following order.
  1. If fields.InputDataFields.multiclass_scores is present in tensor_dict,
     it is reshaped to [num_boxes, num_classes].
  2. If fields.InputDataFields.groundtruth_boxes is present, tensor_dict is
     passed through util_ops.filter_groundtruth_with_nan_box_coordinates and
     util_ops.filter_unrecognized_classes.
  3. retain_original_image (optional): the image resized by image_resizer_fn
     and cast to uint8 is stored under
     fields.InputDataFields.original_image.
  4. If key fields.InputDataFields.image_additional_channels is present in
     tensor_dict, the additional channels will be merged into
     fields.InputDataFields.image.
  5. data_augmentation_fn (optional): applied on tensor_dict.
  6. model_preprocess_fn: applied only on image tensor in tensor_dict.
  7. image_resizer_fn: applied on original image and instance mask tensor in
     tensor_dict; only the resized masks are kept.
  8. one_hot_encoding: applied to classes tensor in tensor_dict.
  9. use_multiclass_scores (optional): the one-hot classes are replaced by
     the multiclass scores. The multiclass scores entry is removed from
     tensor_dict in either case.
  10. Groundtruth confidences are expanded to the shape of the class
      encoding (or set to the class encoding itself when none are given).
  11. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
      same they can be merged into a single box with an associated k-hot
      class label.
  12. If groundtruth boxes are present, their count is stored under
      fields.InputDataFields.num_groundtruth_boxes.

  Args:
    tensor_dict: dictionary containing input tensors keyed by
      fields.InputDataFields. The dictionary may be modified in place.
    model_preprocess_fn: model's preprocess function to apply on image
      tensor. This function must take in a 4-D float tensor and return a 4-D
      preprocess float tensor and a tensor containing the true image shape.
    image_resizer_fn: image resizer function to apply on groundtruth instance
      masks. This function must take a 3-D float tensor of an image and a 3-D
      tensor of instance masks and return a resized version of these along
      with the true shapes. It is also called with `None` as the masks
      argument when `retain_original_image` is set.
    num_classes: number of max classes to one-hot (or k-hot) encode the class
      labels.
    data_augmentation_fn: (optional) data augmentation function to apply on
      input `tensor_dict`.
    merge_multiple_boxes: (optional) whether to merge multiple groundtruth
      boxes and classes for a given image if the boxes are exactly the same.
    retain_original_image: (optional) whether to retain original image in the
      output dictionary.
    use_multiclass_scores: (optional) whether to use multiclass scores as
      class targets instead of one-hot encoding of `groundtruth_classes`.
      Requires fields.InputDataFields.multiclass_scores in `tensor_dict`.
    use_bfloat16: (optional) a bool, whether to use bfloat16 in training.
      When set, the preprocessed image and the resized instance masks are
      cast to bfloat16.

  Returns:
    A dictionary keyed by fields.InputDataFields containing the tensors
    obtained after applying all the transformations.
  """
  input_fields = fields.InputDataFields

  # Reshape flattened multiclass scores tensor into a 2D tensor of shape
  # [num_boxes, num_classes].
  if input_fields.multiclass_scores in tensor_dict:
    num_boxes = tf.shape(tensor_dict[input_fields.groundtruth_boxes])[0]
    tensor_dict[input_fields.multiclass_scores] = tf.reshape(
        tensor_dict[input_fields.multiclass_scores], [num_boxes, num_classes])

  if input_fields.groundtruth_boxes in tensor_dict:
    tensor_dict = util_ops.filter_groundtruth_with_nan_box_coordinates(
        tensor_dict)
    tensor_dict = util_ops.filter_unrecognized_classes(tensor_dict)

  if retain_original_image:
    # Captured before channel merging and augmentation; only the resized image
    # (first element returned by the resizer) is kept.
    tensor_dict[input_fields.original_image] = tf.cast(
        image_resizer_fn(tensor_dict[input_fields.image], None)[0], tf.uint8)

  if input_fields.image_additional_channels in tensor_dict:
    channels = tensor_dict[input_fields.image_additional_channels]
    tensor_dict[input_fields.image] = tf.concat(
        [tensor_dict[input_fields.image], channels], axis=2)

  # Apply data augmentation ops.
  if data_augmentation_fn is not None:
    tensor_dict = data_augmentation_fn(tensor_dict)

  # Apply model preprocessing ops and resize instance masks.
  image = tensor_dict[input_fields.image]
  # tf.cast replaces the deprecated tf.to_float. The preprocess function works
  # on a batch, so a leading dimension is added here and squeezed out below.
  preprocessed_resized_image, true_image_shape = model_preprocess_fn(
      tf.expand_dims(tf.cast(image, tf.float32), axis=0))
  if use_bfloat16:
    preprocessed_resized_image = tf.cast(
        preprocessed_resized_image, tf.bfloat16)
  tensor_dict[input_fields.image] = tf.squeeze(
      preprocessed_resized_image, axis=0)
  tensor_dict[input_fields.true_image_shape] = tf.squeeze(
      true_image_shape, axis=0)

  if input_fields.groundtruth_instance_masks in tensor_dict:
    masks = tensor_dict[input_fields.groundtruth_instance_masks]
    # The resizer gets the un-preprocessed, un-batched image; only the masks
    # it returns are used.
    _, resized_masks, _ = image_resizer_fn(image, masks)
    if use_bfloat16:
      resized_masks = tf.cast(resized_masks, tf.bfloat16)
    tensor_dict[input_fields.groundtruth_instance_masks] = resized_masks

  # Transform groundtruth classes to one hot encodings. Labels in tensor_dict
  # are shifted down by one before encoding.
  label_offset = 1
  zero_indexed_groundtruth_classes = (
      tensor_dict[input_fields.groundtruth_classes] - label_offset)
  tensor_dict[input_fields.groundtruth_classes] = tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes)

  if use_multiclass_scores:
    tensor_dict[input_fields.groundtruth_classes] = tensor_dict[
        input_fields.multiclass_scores]
  # The scores are dropped from the output whether or not they were used.
  tensor_dict.pop(input_fields.multiclass_scores, None)

  if input_fields.groundtruth_confidences in tensor_dict:
    groundtruth_confidences = tensor_dict[input_fields.groundtruth_confidences]
    # Map the confidences to the one-hot encoding of classes
    tensor_dict[input_fields.groundtruth_confidences] = (
        tf.reshape(groundtruth_confidences, [-1, 1]) *
        tensor_dict[input_fields.groundtruth_classes])
  else:
    # No confidences supplied: use ones per box for merging and expose the
    # class encoding itself as the confidences.
    groundtruth_confidences = tf.ones_like(
        zero_indexed_groundtruth_classes, dtype=tf.float32)
    tensor_dict[input_fields.groundtruth_confidences] = (
        tensor_dict[input_fields.groundtruth_classes])

  if merge_multiple_boxes:
    merged_boxes, merged_classes, merged_confidences, _ = (
        util_ops.merge_boxes_with_multiple_labels(
            tensor_dict[input_fields.groundtruth_boxes],
            zero_indexed_groundtruth_classes,
            groundtruth_confidences,
            num_classes))
    merged_classes = tf.cast(merged_classes, tf.float32)
    tensor_dict[input_fields.groundtruth_boxes] = merged_boxes
    tensor_dict[input_fields.groundtruth_classes] = merged_classes
    tensor_dict[input_fields.groundtruth_confidences] = merged_confidences

  if input_fields.groundtruth_boxes in tensor_dict:
    tensor_dict[input_fields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[input_fields.groundtruth_boxes])[0]

  return tensor_dict
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False):
  """A single function that is responsible for all input data transformations.

  Data transformation functions are applied in the following order.
  1. retain_original_image (optional): the unmodified image is stored under
     fields.InputDataFields.original_image.
  2. data_augmentation_fn (optional): applied on tensor_dict.
  3. model_preprocess_fn: applied only on image tensor in tensor_dict.
  4. image_resizer_fn: applied only on instance mask tensor in tensor_dict.
  5. one_hot_encoding: applied to classes tensor in tensor_dict.
  6. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
     same they can be merged into a single box with an associated k-hot class
     label.

  Args:
    tensor_dict: dictionary containing input tensors keyed by
      fields.InputDataFields. The dictionary may be modified in place.
    model_preprocess_fn: model's preprocess function to apply on image
      tensor. This function must take in a 4-D float tensor and return a 4-D
      preprocess float tensor and a tensor containing the true image shape.
    image_resizer_fn: image resizer function to apply on groundtruth instance
      masks. This function is given the un-batched float image together with
      the instance masks exactly as stored in `tensor_dict` (presumably a 3-D
      image and 3-D masks -- confirm against the resizer in use) and must
      return a resized version of these along with the true shapes.
    num_classes: number of max classes to one-hot (or k-hot) encode the class
      labels.
    data_augmentation_fn: (optional) data augmentation function to apply on
      input `tensor_dict`.
    merge_multiple_boxes: (optional) whether to merge multiple groundtruth
      boxes and classes for a given image if the boxes are exactly the same.
    retain_original_image: (optional) whether to retain original image in the
      output dictionary.

  Returns:
    A dictionary keyed by fields.InputDataFields containing the tensors
    obtained after applying all the transformations.
  """
  input_fields = fields.InputDataFields

  if retain_original_image:
    # Captured before augmentation and preprocessing.
    tensor_dict[input_fields.original_image] = tensor_dict[input_fields.image]

  # Apply data augmentation ops.
  if data_augmentation_fn is not None:
    tensor_dict = data_augmentation_fn(tensor_dict)

  # Apply model preprocessing ops and resize instance masks.
  # tf.cast replaces the deprecated tf.to_float.
  float_image = tf.cast(tensor_dict[input_fields.image], tf.float32)
  # Only the preprocess function works on a batch, so the leading dimension is
  # added for that call alone and squeezed out of its results.
  preprocessed_resized_image, true_image_shape = model_preprocess_fn(
      tf.expand_dims(float_image, axis=0))
  tensor_dict[input_fields.image] = tf.squeeze(
      preprocessed_resized_image, axis=0)
  tensor_dict[input_fields.true_image_shape] = tf.squeeze(
      true_image_shape, axis=0)

  if input_fields.groundtruth_instance_masks in tensor_dict:
    masks = tensor_dict[input_fields.groundtruth_instance_masks]
    # The masks carry no batch dimension, so the resizer must receive the
    # un-batched image as well; previously the batched image was passed here.
    _, resized_masks, _ = image_resizer_fn(float_image, masks)
    tensor_dict[input_fields.groundtruth_instance_masks] = resized_masks

  # Transform groundtruth classes to one hot encodings. Labels in tensor_dict
  # are shifted down by one before encoding.
  label_offset = 1
  zero_indexed_groundtruth_classes = (
      tensor_dict[input_fields.groundtruth_classes] - label_offset)
  tensor_dict[input_fields.groundtruth_classes] = tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes)

  if merge_multiple_boxes:
    merged_boxes, merged_classes, _ = util_ops.merge_boxes_with_multiple_labels(
        tensor_dict[input_fields.groundtruth_boxes],
        zero_indexed_groundtruth_classes,
        num_classes)
    tensor_dict[input_fields.groundtruth_boxes] = merged_boxes
    tensor_dict[input_fields.groundtruth_classes] = merged_classes

  return tensor_dict