def build_center_net_meta_arch():
  """Builds a CenterNet meta architecture with a ResNet-v2-101 backbone.

  The model is built in training mode, with summaries disabled, for 10 classes
  and with object detection as its only configured task.

  Returns:
    A center_net_meta_arch.CenterNetMetaArch instance.
  """
  feature_extractor = (
      center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor(
          'resnet_v2_101'))
  # NOTE: the keyword must be spelled `pad_to_max_dimension`. functools.partial
  # does not validate keyword names, so a misspelling is stored silently and
  # only fails (TypeError) when the resizer is eventually called.
  image_resizer_fn = functools.partial(
      preprocessor.resize_to_range,
      min_dimension=128,
      max_dimension=128,
      pad_to_max_dimension=True)
  object_detection_params = center_net_meta_arch.ObjectDetectionParams(
      classification_loss=losses.PenaltyReducedLogisticFocalLoss(
          alpha=1.0, beta=1.0),
      classification_loss_weight=1.0,
      localization_loss=losses.L1LocalizationLoss(),
      offset_loss_weight=1.0,
      scale_loss_weight=0.1,
      min_box_overlap_iou=1.0,
      max_box_predictions=5)
  return center_net_meta_arch.CenterNetMetaArch(
      is_training=True,
      add_summaries=False,
      num_classes=10,
      feature_extractor=feature_extractor,
      image_resizer_fn=image_resizer_fn,
      object_detection_params=object_detection_params)
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
  """Builds the DeepMAC meta architecture.

  Args:
    predict_full_resolution_masks: Forwarded unchanged to
      deepmac_meta_arch.DeepMACParams.
    use_dice_loss: If True, the mask classification loss is
      losses.WeightedDiceClassificationLoss(False); otherwise it is
      losses.WeightedSigmoidClassificationLoss().

  Returns:
    A deepmac_meta_arch.DeepMACMetaArch built in training mode for 6 classes,
    using a DummyFeatureExtractor with stride 4.
  """
  feature_extractor = DummyFeatureExtractor(
      channel_means=(1.0, 2.0, 3.0),
      channel_stds=(10., 20., 30.),
      bgr_ordering=False,
      num_feature_outputs=2,
      stride=4)
  # NOTE: the keyword must be spelled `pad_to_max_dimension`. functools.partial
  # does not validate keyword names, so a misspelling is stored silently and
  # only fails (TypeError) when the resizer is eventually called.
  image_resizer_fn = functools.partial(
      preprocessor.resize_to_range,
      min_dimension=128,
      max_dimension=128,
      pad_to_max_dimension=True)

  object_center_params = center_net_meta_arch.ObjectCenterParams(
      classification_loss=losses.WeightedSigmoidClassificationLoss(),
      object_center_loss_weight=1.0,
      min_box_overlap_iou=1.0,
      max_box_predictions=5,
      use_labeled_classes=False)

  if use_dice_loss:
    classification_loss = losses.WeightedDiceClassificationLoss(False)
  else:
    classification_loss = losses.WeightedSigmoidClassificationLoss()

  deepmac_params = deepmac_meta_arch.DeepMACParams(
      classification_loss=classification_loss,
      dim=8,
      task_loss_weight=1.0,
      pixel_embedding_dim=2,
      allowed_masked_classes_ids=[],
      mask_size=16,
      mask_num_subsamples=-1,
      use_xy=True,
      network_type='hourglass10',
      use_instance_embedding=True,
      num_init_channels=8,
      predict_full_resolution_masks=predict_full_resolution_masks,
      postprocess_crop_size=128,
      max_roi_jitter_ratio=0.0,
      roi_jitter_mode='random')

  object_detection_params = center_net_meta_arch.ObjectDetectionParams(
      localization_loss=losses.L1LocalizationLoss(),
      offset_loss_weight=1.0,
      scale_loss_weight=0.1)

  return deepmac_meta_arch.DeepMACMetaArch(
      is_training=True,
      add_summaries=False,
      num_classes=6,
      feature_extractor=feature_extractor,
      object_center_params=object_center_params,
      deepmac_params=deepmac_params,
      object_detection_params=object_detection_params,
      image_resizer_fn=image_resizer_fn)
def object_detection_proto_to_params(od_config):
  """Maps a CenterNet.ObjectDetection proto onto ObjectDetectionParams.

  Only the localization loss is taken from the built loss config. The scale,
  offset and task loss weights are copied straight from `od_config`.

  Args:
    od_config: A CenterNet.ObjectDetection proto message.

  Returns:
    A center_net_meta_arch.ObjectDetectionParams namedtuple.
  """
  loss_config = losses_pb2.Loss()
  # losses_builder.build throws if no classification loss is set, so a
  # placeholder weighted-sigmoid loss is filled in and its result discarded.
  # TODO(yuhuic): update the loss builder to take the classification loss
  # directly.
  placeholder_loss = losses_pb2.WeightedSigmoidClassificationLoss()
  loss_config.classification_loss.weighted_sigmoid.CopyFrom(placeholder_loss)
  loss_config.localization_loss.CopyFrom(od_config.localization_loss)

  # The builder returns a 7-tuple; only the second entry is needed here.
  (_, built_localization_loss, _, _, _, _, _) = losses_builder.build(
      loss_config)

  return center_net_meta_arch.ObjectDetectionParams(
      localization_loss=built_localization_loss,
      scale_loss_weight=od_config.scale_loss_weight,
      offset_loss_weight=od_config.offset_loss_weight,
      task_loss_weight=od_config.task_loss_weight)
def build_meta_arch(**override_params):
  """Builds the DeepMAC meta architecture.

  Args:
    **override_params: Keyword overrides applied on top of the default DeepMAC
      settings listed below. Two keys are consumed here to choose the mask
      classification loss:
        * `use_dice_loss`: if True, use losses.WeightedDiceClassificationLoss,
          otherwise losses.WeightedSigmoidClassificationLoss.
        * `dice_loss_prediction_probability`: passed to the dice loss as
          `is_prediction_probability` (ignored when the sigmoid loss is used).
      Every remaining key is forwarded unchanged to
      deepmac_meta_arch.DeepMACParams.

  Returns:
    A deepmac_meta_arch.DeepMACMetaArch built in training mode for 6 classes,
    using a DummyFeatureExtractor with stride 4.
  """
  # A fresh dict (and fresh list for allowed_masked_classes_ids) is built on
  # every call, so overrides never leak between calls.
  params = dict(
      predict_full_resolution_masks=False,
      use_instance_embedding=True,
      mask_num_subsamples=-1,
      network_type='hourglass10',
      use_xy=True,
      pixel_embedding_dim=2,
      dice_loss_prediction_probability=False,
      color_consistency_threshold=0.5,
      use_dice_loss=False,
      box_consistency_loss_normalize='normalize_auto',
      box_consistency_tightness=False,
      task_loss_weight=1.0,
      color_consistency_loss_weight=1.0,
      box_consistency_loss_weight=1.0,
      num_init_channels=8,
      dim=8,
      allowed_masked_classes_ids=[],
      mask_size=16,
      postprocess_crop_size=128,
      max_roi_jitter_ratio=0.0,
      roi_jitter_mode='random',
      color_consistency_dilation=2,
      color_consistency_warmup_steps=0,
      color_consistency_warmup_start=0)
  params.update(override_params)

  feature_extractor = DummyFeatureExtractor(
      channel_means=(1.0, 2.0, 3.0),
      channel_stds=(10., 20., 30.),
      bgr_ordering=False,
      num_feature_outputs=2,
      stride=4)
  # NOTE: the keyword must be spelled `pad_to_max_dimension`. functools.partial
  # does not validate keyword names, so a misspelling is stored silently and
  # only fails (TypeError) when the resizer is eventually called.
  image_resizer_fn = functools.partial(
      preprocessor.resize_to_range,
      min_dimension=128,
      max_dimension=128,
      pad_to_max_dimension=True)

  object_center_params = center_net_meta_arch.ObjectCenterParams(
      classification_loss=losses.WeightedSigmoidClassificationLoss(),
      object_center_loss_weight=1.0,
      min_box_overlap_iou=1.0,
      max_box_predictions=5,
      use_labeled_classes=False)

  # These two keys configure the loss object and are removed from `params`
  # before the rest is forwarded to DeepMACParams.
  use_dice_loss = params.pop('use_dice_loss')
  dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
  if use_dice_loss:
    classification_loss = losses.WeightedDiceClassificationLoss(
        squared_normalization=False,
        is_prediction_probability=dice_loss_prediction_prob)
  else:
    classification_loss = losses.WeightedSigmoidClassificationLoss()

  deepmac_params = deepmac_meta_arch.DeepMACParams(
      classification_loss=classification_loss,
      **params)

  object_detection_params = center_net_meta_arch.ObjectDetectionParams(
      localization_loss=losses.L1LocalizationLoss(),
      offset_loss_weight=1.0,
      scale_loss_weight=0.1)

  return deepmac_meta_arch.DeepMACMetaArch(
      is_training=True,
      add_summaries=False,
      num_classes=6,
      feature_extractor=feature_extractor,
      object_center_params=object_center_params,
      deepmac_params=deepmac_params,
      object_detection_params=object_detection_params,
      image_resizer_fn=image_resizer_fn)