Example #1
0
def _build_classification_loss(loss_config):
    """Creates the classification loss described by a ClassificationLoss proto.

    The ``classification_loss`` oneof of ``loss_config`` picks which class
    from ``losses`` is instantiated; fields of the selected sub-message are
    forwarded to that class's constructor.

    Args:
      loss_config: A losses_pb2.ClassificationLoss object.

    Returns:
      The loss object built for the oneof field that is set.

    Raises:
      ValueError: If ``loss_config`` is not a losses_pb2.ClassificationLoss,
        or if the oneof does not select one of the supported losses.
    """
    if not isinstance(loss_config, losses_pb2.ClassificationLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.ClassificationLoss.')

    def make_sigmoid_focal(proto):
        focal_cfg = proto.weighted_sigmoid_focal
        # `alpha` is only forwarded when it is explicitly present in the
        # proto; otherwise the loss receives None.
        focal_alpha = focal_cfg.alpha if focal_cfg.HasField('alpha') else None
        return losses.SigmoidFocalClassificationLoss(gamma=focal_cfg.gamma,
                                                     alpha=focal_alpha)

    def make_bootstrapped_sigmoid(proto):
        boot_cfg = proto.bootstrapped_sigmoid
        return losses.BootstrappedSigmoidClassificationLoss(
            alpha=boot_cfg.alpha,
            bootstrap_type=('hard' if boot_cfg.hard_bootstrap else 'soft'))

    def make_penalty_reduced_focal(proto):
        prf_cfg = proto.penalty_reduced_logistic_focal_loss
        return losses.PenaltyReducedLogisticFocalLoss(alpha=prf_cfg.alpha,
                                                      beta=prf_cfg.beta)

    def make_weighted_dice(proto):
        dice_cfg = proto.weighted_dice_classification_loss
        return losses.WeightedDiceClassificationLoss(
            squared_normalization=dice_cfg.squared_normalization,
            is_prediction_probability=dice_cfg.is_prediction_probability)

    # Maps each supported oneof field name to a factory taking the full proto.
    factories = {
        'weighted_sigmoid':
            lambda _: losses.WeightedSigmoidClassificationLoss(),
        'weighted_sigmoid_focal': make_sigmoid_focal,
        'weighted_softmax':
            lambda proto: losses.WeightedSoftmaxClassificationLoss(
                logit_scale=proto.weighted_softmax.logit_scale),
        'weighted_logits_softmax':
            lambda proto: losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
                logit_scale=proto.weighted_logits_softmax.logit_scale),
        'bootstrapped_sigmoid': make_bootstrapped_sigmoid,
        'penalty_reduced_logistic_focal_loss': make_penalty_reduced_focal,
        'weighted_dice_classification_loss': make_weighted_dice,
    }

    selected = loss_config.WhichOneof('classification_loss')
    factory = factories.get(selected)
    if factory is None:
        # Covers both "no field set" (WhichOneof returns None) and any field
        # name without a factory above.
        raise ValueError('Empty loss config.')
    return factory(loss_config)
Example #2
0
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
    """Builds the DeepMAC meta architecture.

    Args:
      predict_full_resolution_masks: Forwarded unchanged to
        ``deepmac_meta_arch.DeepMACParams``.
      use_dice_loss: When True, the mask head uses
        ``losses.WeightedDiceClassificationLoss`` with
        ``squared_normalization=False``; otherwise it uses
        ``losses.WeightedSigmoidClassificationLoss``.

    Returns:
      A ``deepmac_meta_arch.DeepMACMetaArch`` built with ``is_training=True``,
      6 classes, a ``DummyFeatureExtractor`` and a resizer wrapping
      ``preprocessor.resize_to_range`` with min/max dimension 128.
    """

    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)
    # The keyword must be spelled `pad_to_max_dimension`. functools.partial
    # does not validate keyword names, so a typo here would only surface
    # when the resizer is actually called.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    if use_dice_loss:
        # Passed by keyword so the meaning of the flag is explicit.
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss,
        dim=8,
        task_loss_weight=1.0,
        pixel_embedding_dim=2,
        allowed_masked_classes_ids=[],
        mask_size=16,
        mask_num_subsamples=-1,
        use_xy=True,
        network_type='hourglass10',
        use_instance_embedding=True,
        num_init_channels=8,
        predict_full_resolution_masks=predict_full_resolution_masks,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random')

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)
Example #3
0
def build_meta_arch(**override_params):
    """Builds the DeepMAC meta architecture.

    Args:
      **override_params: Keyword overrides merged over the defaults below.
        ``use_dice_loss`` and ``dice_loss_prediction_probability`` are
        consumed here to choose the mask classification loss; every other
        key is forwarded to ``deepmac_meta_arch.DeepMACParams``.

    Returns:
      A ``deepmac_meta_arch.DeepMACMetaArch`` built with ``is_training=True``,
      6 classes, a ``DummyFeatureExtractor`` and a resizer wrapping
      ``preprocessor.resize_to_range`` with min/max dimension 128.
    """

    # Rebuilt on every call, so the list value below is never shared
    # between returned meta architectures.
    params = dict(predict_full_resolution_masks=False,
                  use_instance_embedding=True,
                  mask_num_subsamples=-1,
                  network_type='hourglass10',
                  use_xy=True,
                  pixel_embedding_dim=2,
                  dice_loss_prediction_probability=False,
                  color_consistency_threshold=0.5,
                  use_dice_loss=False,
                  box_consistency_loss_normalize='normalize_auto',
                  box_consistency_tightness=False,
                  task_loss_weight=1.0,
                  color_consistency_loss_weight=1.0,
                  box_consistency_loss_weight=1.0,
                  num_init_channels=8,
                  dim=8,
                  allowed_masked_classes_ids=[],
                  mask_size=16,
                  postprocess_crop_size=128,
                  max_roi_jitter_ratio=0.0,
                  roi_jitter_mode='random',
                  color_consistency_dilation=2,
                  color_consistency_warmup_steps=0,
                  color_consistency_warmup_start=0)

    params.update(override_params)

    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)
    # The keyword must be spelled `pad_to_max_dimension`. functools.partial
    # does not validate keyword names, so a typo here would only surface
    # when the resizer is actually called.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # These two keys configure the loss object only; they are removed so the
    # remaining params can be splatted into DeepMACParams.
    use_dice_loss = params.pop('use_dice_loss')
    dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False,
            is_prediction_probability=dice_loss_prediction_prob)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss, **params)

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)