コード例 #1
0
ファイル: losses_builder.py プロジェクト: nilskk/models
def _build_localization_loss(loss_config):
    """Creates the localization loss selected by a config proto.

    Args:
      loss_config: A losses_pb2.LocalizationLoss object.

    Returns:
      The loss object matching whichever `localization_loss` oneof field is
      set on the config.

    Raises:
      ValueError: If loss_config is not a losses_pb2.LocalizationLoss, or if
        the oneof selection matches none of the variants handled here
        (including the case where no field is set).
    """
    if not isinstance(loss_config, losses_pb2.LocalizationLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.LocalizationLoss.')

    selected = loss_config.WhichOneof('localization_loss')

    # Smooth L1 is the only variant configured from a field of the proto.
    if selected == 'weighted_smooth_l1':
        smooth_l1_delta = loss_config.weighted_smooth_l1.delta
        return losses.WeightedSmoothL1LocalizationLoss(smooth_l1_delta)

    # Every other variant is constructed without arguments. Class names are
    # resolved lazily so only the selected class is looked up on `losses`.
    no_arg_loss_names = {
        'weighted_l2': 'WeightedL2LocalizationLoss',
        'weighted_iou': 'WeightedIOULocalizationLoss',
        'l1_localization_loss': 'L1LocalizationLoss',
        'weighted_giou': 'WeightedGIOULocalizationLoss',
    }
    class_name = no_arg_loss_names.get(selected)
    if class_name is None:
        raise ValueError('Empty loss config.')
    return getattr(losses, class_name)()
コード例 #2
0
def build_center_net_meta_arch():
  """Builds a CenterNet meta architecture with fixed, hard-coded settings.

  The model is created in training mode with 10 classes, a 'resnet_v2_101'
  CenterNet feature extractor, and an image resizer bound to a 128/128
  min/max dimension with padding to the max dimension enabled.

  Returns:
    A center_net_meta_arch.CenterNetMetaArch instance.
  """
  feature_extractor = (
      center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor(
          'resnet_v2_101'))

  # functools.partial does not validate keyword names when it is created, so
  # a misspelled keyword only fails once the resizer is finally called. The
  # keyword must be spelled `pad_to_max_dimension` to reach resize_to_range.
  image_resizer_fn = functools.partial(
      preprocessor.resize_to_range,
      min_dimension=128,
      max_dimension=128,
      pad_to_max_dimension=True)

  object_detection_params = center_net_meta_arch.ObjectDetectionParams(
      classification_loss=losses.PenaltyReducedLogisticFocalLoss(
          alpha=1.0, beta=1.0),
      classification_loss_weight=1.0,
      localization_loss=losses.L1LocalizationLoss(),
      offset_loss_weight=1.0,
      scale_loss_weight=0.1,
      min_box_overlap_iou=1.0,
      max_box_predictions=5)

  return center_net_meta_arch.CenterNetMetaArch(
      is_training=True,
      add_summaries=False,
      num_classes=10,
      feature_extractor=feature_extractor,
      image_resizer_fn=image_resizer_fn,
      object_detection_params=object_detection_params)
コード例 #3
0
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
    """Builds the DeepMAC meta architecture.

    Args:
      predict_full_resolution_masks: Forwarded unchanged to
        deepmac_meta_arch.DeepMACParams.
      use_dice_loss: If truthy, the DeepMAC classification loss is
        losses.WeightedDiceClassificationLoss(False); otherwise it is
        losses.WeightedSigmoidClassificationLoss().

    Returns:
      A deepmac_meta_arch.DeepMACMetaArch built in training mode with 6
      classes and a DummyFeatureExtractor.
    """
    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)

    # functools.partial does not validate keyword names when it is created,
    # so a misspelled keyword only fails once the resizer is finally called.
    # The keyword must be spelled `pad_to_max_dimension`.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(False)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss,
        dim=8,
        task_loss_weight=1.0,
        pixel_embedding_dim=2,
        allowed_masked_classes_ids=[],
        mask_size=16,
        mask_num_subsamples=-1,
        use_xy=True,
        network_type='hourglass10',
        use_instance_embedding=True,
        num_init_channels=8,
        predict_full_resolution_masks=predict_full_resolution_masks,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random')

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)
コード例 #4
0
ファイル: losses_v2_test.py プロジェクト: VictorYu2015/models
    def test_returns_correct_loss(self):
        """Compares L1LocalizationLoss._compute_loss with hand-computed values."""
        predictions = [[0.1, 0.2], [0.7, 0.5]]
        targets = [[0.9, 1.0], [0.1, 0.4]]
        per_element_weights = [[1.0, 0.0], [1.0, 1.0]]

        # Each expected entry equals |prediction - target| * weight for the
        # numbers above; the zero-weighted entry is therefore 0.0.
        expected_loss = [[0.8, 0.0], [0.6, 0.1]]

        l1_loss = losses.L1LocalizationLoss()
        actual_loss = l1_loss._compute_loss(
            predictions, targets, per_element_weights)

        np.testing.assert_allclose(actual_loss, expected_loss, rtol=1e-6)
コード例 #5
0
def build_meta_arch(**override_params):
    """Builds the DeepMAC meta architecture.

    Args:
      **override_params: Keyword overrides applied on top of the default
        parameter dictionary defined below. Two keys are consumed here and
        not forwarded: `use_dice_loss` selects between
        losses.WeightedDiceClassificationLoss and
        losses.WeightedSigmoidClassificationLoss, and
        `dice_loss_prediction_probability` is passed to the dice loss as
        `is_prediction_probability`. All remaining keys are forwarded as
        keyword arguments to deepmac_meta_arch.DeepMACParams.

    Returns:
      A deepmac_meta_arch.DeepMACMetaArch built in training mode with 6
      classes and a DummyFeatureExtractor.
    """
    params = dict(predict_full_resolution_masks=False,
                  use_instance_embedding=True,
                  mask_num_subsamples=-1,
                  network_type='hourglass10',
                  use_xy=True,
                  pixel_embedding_dim=2,
                  dice_loss_prediction_probability=False,
                  color_consistency_threshold=0.5,
                  use_dice_loss=False,
                  box_consistency_loss_normalize='normalize_auto',
                  box_consistency_tightness=False,
                  task_loss_weight=1.0,
                  color_consistency_loss_weight=1.0,
                  box_consistency_loss_weight=1.0,
                  num_init_channels=8,
                  dim=8,
                  allowed_masked_classes_ids=[],
                  mask_size=16,
                  postprocess_crop_size=128,
                  max_roi_jitter_ratio=0.0,
                  roi_jitter_mode='random',
                  color_consistency_dilation=2,
                  color_consistency_warmup_steps=0,
                  color_consistency_warmup_start=0)

    params.update(override_params)

    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)

    # functools.partial does not validate keyword names when it is created,
    # so a misspelled keyword only fails once the resizer is finally called.
    # The keyword must be spelled `pad_to_max_dimension`.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # These two keys configure the loss only; pop them so they are not
    # forwarded to DeepMACParams through **params below.
    use_dice_loss = params.pop('use_dice_loss')
    dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False,
            is_prediction_probability=dice_loss_prediction_prob)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss, **params)

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)