def transform_and_pad_input_data_fn(tensor_dict):
    """Transform an evaluation input and pad the result to static shapes.

    Binds the evaluation settings onto `transform_input_data`, applies the
    bound function to `tensor_dict`, and passes the output through
    `pad_input_data_to_static_shapes`. Data augmentation is disabled here
    (`data_augmentation_fn=None`).

    Args:
        tensor_dict: Input forwarded unchanged to `transform_input_data`
            (presumably a dictionary of tensors -- confirm against caller).

    Returns:
        A `(features, labels)` tuple produced from the padded result by
        `_get_features_dict` and `_get_labels_dict` respectively.
    """
    num_classes = config_util.get_number_of_classes(model_config)
    resizer_config = config_util.get_image_resizer_config(model_config)
    resizer_fn = image_resizer_builder.build(resizer_config)
    # A falsy (e.g. empty) keypoint_type_weight is normalized to None.
    kpt_type_weight = eval_input_config.keypoint_type_weight or None

    transform_kwargs = dict(
        model_preprocess_fn=model_preprocess_fn,
        image_resizer_fn=resizer_fn,
        num_classes=num_classes,
        data_augmentation_fn=None,
        retain_original_image=eval_config.retain_original_images,
        retain_original_image_additional_channels=(
            eval_config.retain_original_image_additional_channels),
        keypoint_type_weight=kpt_type_weight,
    )
    transform_fn = functools.partial(transform_input_data, **transform_kwargs)
    transformed = transform_fn(tensor_dict)

    # NOTE(review): the class count is looked up again below instead of
    # reusing `num_classes`; kept as-is to preserve the existing behavior.
    max_boxes = eval_input_config.max_number_of_boxes
    padded_num_classes = config_util.get_number_of_classes(model_config)
    image_shape = config_util.get_spatial_image_size(resizer_config)
    max_context_features = config_util.get_max_num_context_features(
        model_config)
    context_length = config_util.get_context_feature_length(model_config)

    padded = pad_input_data_to_static_shapes(
        tensor_dict=transformed,
        max_num_boxes=max_boxes,
        num_classes=padded_num_classes,
        spatial_image_shape=image_shape,
        max_num_context_features=max_context_features,
        context_feature_length=context_length,
    )

    features = _get_features_dict(padded)
    labels = _get_labels_dict(padded)
    return (features, labels)
def transform_and_pad_input_data_fn(tensor_dict):
    """Augment, transform, and pad a training input to static shapes.

    Builds one preprocessing step per entry in
    `train_config.data_augmentation_options`, binds them (together with the
    remaining training settings) onto `transform_input_data`, applies the
    bound function to `tensor_dict`, and passes the output through
    `pad_input_data_to_static_shapes`.

    Args:
        tensor_dict: Input forwarded unchanged to `transform_input_data`
            (presumably a dictionary of tensors -- confirm against caller).

    Returns:
        A `(features, labels)` tuple. Features come from
        `_get_features_dict`, which is also given
        `train_input_config.include_source_id`; labels come from
        `_get_labels_dict`.
    """
    augmentation_steps = list(
        map(preprocessor_builder.build,
            train_config.data_augmentation_options))
    augmentation_fn = functools.partial(
        augment_input_data, data_augmentation_options=augmentation_steps)

    resizer_config = config_util.get_image_resizer_config(model_config)
    resizer_fn = image_resizer_builder.build(resizer_config)
    # A falsy (e.g. empty) keypoint_type_weight is normalized to None.
    kpt_type_weight = train_input_config.keypoint_type_weight or None

    # `num_classes` is not defined in this function; it is resolved from the
    # enclosing scope.
    transform_kwargs = dict(
        model_preprocess_fn=model_preprocess_fn,
        image_resizer_fn=resizer_fn,
        num_classes=num_classes,
        data_augmentation_fn=augmentation_fn,
        merge_multiple_boxes=train_config.merge_multiple_label_boxes,
        retain_original_image=train_config.retain_original_images,
        use_multiclass_scores=train_config.use_multiclass_scores,
        use_bfloat16=train_config.use_bfloat16,
        keypoint_type_weight=kpt_type_weight,
    )
    transform_fn = functools.partial(transform_input_data, **transform_kwargs)
    transformed = transform_fn(tensor_dict)

    max_boxes = train_input_config.max_number_of_boxes
    image_shape = config_util.get_spatial_image_size(resizer_config)
    max_context_features = config_util.get_max_num_context_features(
        model_config)
    context_length = config_util.get_context_feature_length(model_config)

    padded = pad_input_data_to_static_shapes(
        tensor_dict=transformed,
        max_num_boxes=max_boxes,
        num_classes=num_classes,
        spatial_image_shape=image_shape,
        max_num_context_features=max_context_features,
        context_feature_length=context_length,
    )

    with_source_id = train_input_config.include_source_id
    features = _get_features_dict(padded, with_source_id)
    labels = _get_labels_dict(padded)
    return (features, labels)
def testGetContextFeatureLengthFromModelConfig(self):
    """Context feature length set on a Faster R-CNN config is read back.

    Sets `faster_rcnn.context_config.context_feature_length` on a fresh
    `DetectionModel` proto and checks that
    `config_util.get_context_feature_length` returns the same value.
    """
    detection_model = model_pb2.DetectionModel()
    context_config = detection_model.faster_rcnn.context_config
    context_config.context_feature_length = 100

    actual_length = config_util.get_context_feature_length(detection_model)

    self.assertAllEqual(actual_length, 100)