Exemplo n.º 1
0
    def transform_and_pad_input_data_fn(tensor_dict):
        """Combines transform and pad operation for evaluation inputs.

        Builds a transform from the enclosing `model_config`, `eval_config` and
        `eval_input_config` (with no data augmentation), applies it to
        `tensor_dict`, and pads the result to static shapes.

        Args:
          tensor_dict: dictionary of input tensors forwarded to
            `transform_input_data` (presumably keyed by input field names --
            TODO confirm against caller).

        Returns:
          A `(features, labels)` tuple built by `_get_features_dict` and
          `_get_labels_dict` from the padded tensor dict.
        """
        # Computed once and shared by both the transform and the padding step,
        # so the two stages always agree on the class count.
        num_classes = config_util.get_number_of_classes(model_config)

        image_resizer_config = config_util.get_image_resizer_config(
            model_config)
        image_resizer_fn = image_resizer_builder.build(image_resizer_config)
        # `or None` turns a falsy (e.g. empty/unset) value into an explicit None.
        keypoint_type_weight = eval_input_config.keypoint_type_weight or None

        transform_data_fn = functools.partial(
            transform_input_data,
            model_preprocess_fn=model_preprocess_fn,
            image_resizer_fn=image_resizer_fn,
            num_classes=num_classes,
            # No data augmentation is applied on this path.
            data_augmentation_fn=None,
            retain_original_image=eval_config.retain_original_images,
            retain_original_image_additional_channels=(
                eval_config.retain_original_image_additional_channels),
            keypoint_type_weight=keypoint_type_weight)
        tensor_dict = pad_input_data_to_static_shapes(
            tensor_dict=transform_data_fn(tensor_dict),
            max_num_boxes=eval_input_config.max_number_of_boxes,
            num_classes=num_classes,
            spatial_image_shape=config_util.get_spatial_image_size(
                image_resizer_config),
            max_num_context_features=config_util.get_max_num_context_features(
                model_config),
            context_feature_length=config_util.get_context_feature_length(
                model_config))
        return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
Exemplo n.º 2
0
    def transform_and_pad_input_data_fn(tensor_dict):
        """Applies the training-time transform, then pads to static shapes.

        Every entry of `train_config.data_augmentation_options` is built into an
        augmentation step. The steps are bundled with the image resizer and
        `model_preprocess_fn` into a single transform. The transformed tensors
        are then padded using the limits in `train_input_config` and the model
        config.

        Args:
          tensor_dict: dictionary of input tensors forwarded to
            `transform_input_data`.

        Returns:
          A `(features, labels)` tuple. The features dict is built with
          `train_input_config.include_source_id`.
        """
        augmentation_steps = [
            preprocessor_builder.build(option)
            for option in train_config.data_augmentation_options
        ]
        augment_fn = functools.partial(
            augment_input_data, data_augmentation_options=augmentation_steps)

        resizer_config = config_util.get_image_resizer_config(model_config)
        resizer_fn = image_resizer_builder.build(resizer_config)
        # A falsy (e.g. empty/unset) weight setting collapses to None.
        keypoint_weights = train_input_config.keypoint_type_weight or None
        transform_fn = functools.partial(
            transform_input_data,
            model_preprocess_fn=model_preprocess_fn,
            image_resizer_fn=resizer_fn,
            num_classes=num_classes,
            data_augmentation_fn=augment_fn,
            merge_multiple_boxes=train_config.merge_multiple_label_boxes,
            retain_original_image=train_config.retain_original_images,
            use_multiclass_scores=train_config.use_multiclass_scores,
            use_bfloat16=train_config.use_bfloat16,
            keypoint_type_weight=keypoint_weights)

        transformed = transform_fn(tensor_dict)
        padded = pad_input_data_to_static_shapes(
            tensor_dict=transformed,
            max_num_boxes=train_input_config.max_number_of_boxes,
            num_classes=num_classes,
            spatial_image_shape=config_util.get_spatial_image_size(
                resizer_config),
            max_num_context_features=config_util.get_max_num_context_features(
                model_config),
            context_feature_length=config_util.get_context_feature_length(
                model_config))
        features = _get_features_dict(
            padded, train_input_config.include_source_id)
        labels = _get_labels_dict(padded)
        return features, labels
Exemplo n.º 3
0
 def testGetMaxNumContextFeaturesFromModelConfig(self):
   """max_num_context_features is read from the Faster R-CNN context config."""
   detection_model = model_pb2.DetectionModel()
   context_config = detection_model.faster_rcnn.context_config
   context_config.max_num_context_features = 10
   self.assertAllEqual(
       config_util.get_max_num_context_features(detection_model), 10)