def _build_classification_loss(loss_config):
    """Constructs the classification loss selected by `loss_config`.

    Args:
      loss_config: A losses_pb2.ClassificationLoss object.

    Returns:
      The `losses` object matching whichever `classification_loss` oneof
      field is set on the config.

    Raises:
      ValueError: If `loss_config` is not a losses_pb2.ClassificationLoss, or
        if `WhichOneof('classification_loss')` does not name one of the loss
        types handled below.
    """
    if not isinstance(loss_config, losses_pb2.ClassificationLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.ClassificationLoss.')

    selected = loss_config.WhichOneof('classification_loss')

    if selected == 'weighted_sigmoid':
        return losses.WeightedSigmoidClassificationLoss()

    if selected == 'weighted_sigmoid_focal':
        focal_cfg = loss_config.weighted_sigmoid_focal
        # `alpha` is only forwarded when explicitly set; otherwise None.
        focal_alpha = focal_cfg.alpha if focal_cfg.HasField('alpha') else None
        return losses.SigmoidFocalClassificationLoss(
            gamma=focal_cfg.gamma,
            alpha=focal_alpha)

    if selected == 'weighted_softmax':
        softmax_cfg = loss_config.weighted_softmax
        return losses.WeightedSoftmaxClassificationLoss(
            logit_scale=softmax_cfg.logit_scale)

    if selected == 'weighted_logits_softmax':
        logits_cfg = loss_config.weighted_logits_softmax
        return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
            logit_scale=logits_cfg.logit_scale)

    if selected == 'bootstrapped_sigmoid':
        bootstrap_cfg = loss_config.bootstrapped_sigmoid
        bootstrap_mode = 'hard' if bootstrap_cfg.hard_bootstrap else 'soft'
        return losses.BootstrappedSigmoidClassificationLoss(
            alpha=bootstrap_cfg.alpha,
            bootstrap_type=bootstrap_mode)

    if selected == 'penalty_reduced_logistic_focal_loss':
        penalty_cfg = loss_config.penalty_reduced_logistic_focal_loss
        return losses.PenaltyReducedLogisticFocalLoss(
            alpha=penalty_cfg.alpha,
            beta=penalty_cfg.beta)

    if selected == 'weighted_dice_classification_loss':
        dice_cfg = loss_config.weighted_dice_classification_loss
        return losses.WeightedDiceClassificationLoss(
            squared_normalization=dice_cfg.squared_normalization,
            is_prediction_probability=dice_cfg.is_prediction_probability)

    # Nothing matched: same error as the original trailing `else` branch.
    raise ValueError('Empty loss config.')
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
    """Builds the DeepMAC meta architecture.

    Args:
      predict_full_resolution_masks: Forwarded unchanged as the
        `predict_full_resolution_masks` field of `DeepMACParams`.
      use_dice_loss: If true, the DeepMAC classification loss is
        `losses.WeightedDiceClassificationLoss(False)`; otherwise it is
        `losses.WeightedSigmoidClassificationLoss()`.

    Returns:
      A `deepmac_meta_arch.DeepMACMetaArch` constructed with
      `is_training=True`, `add_summaries=False`, `num_classes=6`, a
      `DummyFeatureExtractor`, and the fixed parameters assembled below.
    """
    feature_extractor = DummyFeatureExtractor(
        channel_means=(1.0, 2.0, 3.0),
        channel_stds=(10., 20., 30.),
        bgr_ordering=False,
        num_feature_outputs=2,
        stride=4)

    # The keyword must be spelled `pad_to_max_dimension`: functools.partial
    # stores any keyword without validating it, so a misspelled name would
    # only surface as a TypeError when the resizer is eventually called.
    image_resizer_fn = functools.partial(
        preprocessor.resize_to_range,
        min_dimension=128,
        max_dimension=128,
        pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(False)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss,
        dim=8,
        task_loss_weight=1.0,
        pixel_embedding_dim=2,
        allowed_masked_classes_ids=[],
        mask_size=16,
        mask_num_subsamples=-1,
        use_xy=True,
        network_type='hourglass10',
        use_instance_embedding=True,
        num_init_channels=8,
        predict_full_resolution_masks=predict_full_resolution_masks,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random')

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)
def build_meta_arch(**override_params):
    """Builds the DeepMAC meta architecture.

    Args:
      **override_params: Keyword overrides merged on top of the default
        parameter dictionary defined below. Two keys are consumed here and
        not forwarded: `use_dice_loss` selects between
        `losses.WeightedDiceClassificationLoss` and
        `losses.WeightedSigmoidClassificationLoss`, and
        `dice_loss_prediction_probability` is passed to the dice loss as
        `is_prediction_probability`. Every remaining key is forwarded as a
        keyword argument to `deepmac_meta_arch.DeepMACParams`.

    Returns:
      A `deepmac_meta_arch.DeepMACMetaArch` constructed with
      `is_training=True`, `add_summaries=False`, `num_classes=6`, a
      `DummyFeatureExtractor`, and the parameters assembled below.
    """
    # A fresh dict is built on every call, so the list value is not shared
    # between invocations.
    params = dict(
        predict_full_resolution_masks=False,
        use_instance_embedding=True,
        mask_num_subsamples=-1,
        network_type='hourglass10',
        use_xy=True,
        pixel_embedding_dim=2,
        dice_loss_prediction_probability=False,
        color_consistency_threshold=0.5,
        use_dice_loss=False,
        box_consistency_loss_normalize='normalize_auto',
        box_consistency_tightness=False,
        task_loss_weight=1.0,
        color_consistency_loss_weight=1.0,
        box_consistency_loss_weight=1.0,
        num_init_channels=8,
        dim=8,
        allowed_masked_classes_ids=[],
        mask_size=16,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random',
        color_consistency_dilation=2,
        color_consistency_warmup_steps=0,
        color_consistency_warmup_start=0)
    params.update(override_params)

    feature_extractor = DummyFeatureExtractor(
        channel_means=(1.0, 2.0, 3.0),
        channel_stds=(10., 20., 30.),
        bgr_ordering=False,
        num_feature_outputs=2,
        stride=4)

    # The keyword must be spelled `pad_to_max_dimension`: functools.partial
    # stores any keyword without validating it, so a misspelled name would
    # only surface as a TypeError when the resizer is eventually called.
    image_resizer_fn = functools.partial(
        preprocessor.resize_to_range,
        min_dimension=128,
        max_dimension=128,
        pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # These two keys configure the loss only; pop them so that the remaining
    # entries can be splatted into DeepMACParams.
    use_dice_loss = params.pop('use_dice_loss')
    dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False,
            is_prediction_probability=dice_loss_prediction_prob)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss, **params)

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)