def build_center_net_meta_arch():
  """Builds a CenterNet meta architecture with fixed, hard-coded settings.

  The model is assembled from a `resnet_v2_101` CenterNet feature extractor,
  an image resizer with both the minimum and maximum dimension bound to 128
  (padding to the maximum dimension enabled), and one set of object
  detection parameters.

  Returns:
    The `center_net_meta_arch.CenterNetMetaArch` instance, constructed with
    `is_training=True`, `add_summaries=False` and `num_classes=10`.
  """
  feature_extractor = (
      center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor(
          'resnet_v2_101'))

  # The keyword has to be spelled `pad_to_max_dimension`. functools.partial
  # does not validate keyword names when it is created, so a misspelling
  # would only surface as a TypeError once the resizer is actually invoked.
  image_resizer_fn = functools.partial(
      preprocessor.resize_to_range,
      min_dimension=128,
      max_dimension=128,
      pad_to_max_dimension=True)

  object_detection_params = center_net_meta_arch.ObjectDetectionParams(
      classification_loss=losses.PenaltyReducedLogisticFocalLoss(
          alpha=1.0, beta=1.0),
      classification_loss_weight=1.0,
      localization_loss=losses.L1LocalizationLoss(),
      offset_loss_weight=1.0,
      scale_loss_weight=0.1,
      min_box_overlap_iou=1.0,
      max_box_predictions=5)

  return center_net_meta_arch.CenterNetMetaArch(
      is_training=True,
      add_summaries=False,
      num_classes=10,
      feature_extractor=feature_extractor,
      image_resizer_fn=image_resizer_fn,
      object_detection_params=object_detection_params)
Exemple #2
0
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
    """Builds the DeepMAC meta architecture.

    Args:
        predict_full_resolution_masks: Value forwarded unchanged to
            `DeepMACParams.predict_full_resolution_masks`.
        use_dice_loss: If true, the DeepMAC classification loss is
            `losses.WeightedDiceClassificationLoss(False)`; otherwise it is
            `losses.WeightedSigmoidClassificationLoss()`.

    Returns:
        The `deepmac_meta_arch.DeepMACMetaArch` instance, constructed with
        `is_training=True`, `add_summaries=False` and `num_classes=6`.
    """
    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)

    # The keyword has to be spelled `pad_to_max_dimension`. functools.partial
    # does not validate keyword names when it is created, so a misspelling
    # would only surface as a TypeError once the resizer is actually invoked.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # Select the loss handed to DeepMACParams based on the caller's flag.
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(False)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss,
        dim=8,
        task_loss_weight=1.0,
        pixel_embedding_dim=2,
        allowed_masked_classes_ids=[],
        mask_size=16,
        mask_num_subsamples=-1,
        use_xy=True,
        network_type='hourglass10',
        use_instance_embedding=True,
        num_init_channels=8,
        predict_full_resolution_masks=predict_full_resolution_masks,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random')

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)
Exemple #3
0
def object_detection_proto_to_params(od_config):
    """Turns a CenterNet.ObjectDetection proto into a parameter namedtuple.

    Args:
        od_config: Object detection config providing `localization_loss`,
            `scale_loss_weight`, `offset_loss_weight` and `task_loss_weight`.

    Returns:
        A `center_net_meta_arch.ObjectDetectionParams` holding the built
        localization loss together with the three weights read from
        `od_config`.
    """
    # A dummy weighted-sigmoid classification loss is filled in so that the
    # loss builder does not throw an error.
    # TODO(yuhuic): update the loss builder to take the classification loss
    # directly.
    placeholder_classification = losses_pb2.WeightedSigmoidClassificationLoss()
    loss_config = losses_pb2.Loss()
    loss_config.classification_loss.weighted_sigmoid.CopyFrom(
        placeholder_classification)
    loss_config.localization_loss.CopyFrom(od_config.localization_loss)

    # Only the second element of the builder's result is needed here.
    built_losses = losses_builder.build(loss_config)
    _, localization_loss, _, _, _, _, _ = built_losses

    return center_net_meta_arch.ObjectDetectionParams(
        localization_loss=localization_loss,
        scale_loss_weight=od_config.scale_loss_weight,
        offset_loss_weight=od_config.offset_loss_weight,
        task_loss_weight=od_config.task_loss_weight)
Exemple #4
0
def build_meta_arch(**override_params):
    """Builds the DeepMAC meta architecture.

    Args:
        **override_params: Keyword overrides applied on top of the default
            settings defined in this function. `use_dice_loss` and
            `dice_loss_prediction_probability` are consumed here to choose
            and configure the classification loss; every other entry is
            forwarded as a keyword argument to
            `deepmac_meta_arch.DeepMACParams`.

    Returns:
        The `deepmac_meta_arch.DeepMACMetaArch` instance, constructed with
        `is_training=True`, `add_summaries=False` and `num_classes=6`.
    """
    params = dict(predict_full_resolution_masks=False,
                  use_instance_embedding=True,
                  mask_num_subsamples=-1,
                  network_type='hourglass10',
                  use_xy=True,
                  pixel_embedding_dim=2,
                  dice_loss_prediction_probability=False,
                  color_consistency_threshold=0.5,
                  use_dice_loss=False,
                  box_consistency_loss_normalize='normalize_auto',
                  box_consistency_tightness=False,
                  task_loss_weight=1.0,
                  color_consistency_loss_weight=1.0,
                  box_consistency_loss_weight=1.0,
                  num_init_channels=8,
                  dim=8,
                  allowed_masked_classes_ids=[],
                  mask_size=16,
                  postprocess_crop_size=128,
                  max_roi_jitter_ratio=0.0,
                  roi_jitter_mode='random',
                  color_consistency_dilation=2,
                  color_consistency_warmup_steps=0,
                  color_consistency_warmup_start=0)

    # Caller-supplied values take precedence over the defaults above.
    params.update(override_params)

    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)

    # The keyword has to be spelled `pad_to_max_dimension`. functools.partial
    # does not validate keyword names when it is created, so a misspelling
    # would only surface as a TypeError once the resizer is actually invoked.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # These two entries are removed from `params` because they only steer
    # the loss selection below and are not passed on to DeepMACParams.
    use_dice_loss = params.pop('use_dice_loss')
    dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False,
            is_prediction_probability=dice_loss_prediction_prob)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss, **params)

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)