Example #1
0
def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
    """Builds a small DeepMAC meta architecture for use in tests.

    Args:
        predict_full_resolution_masks: Forwarded unchanged to
            `deepmac_meta_arch.DeepMACParams`.
        use_dice_loss: If True, the DeepMAC mask head is trained with
            `losses.WeightedDiceClassificationLoss` (squared normalization
            disabled); otherwise with `losses.WeightedSigmoidClassificationLoss`.

    Returns:
        A `deepmac_meta_arch.DeepMACMetaArch` in training mode with 6 classes,
        a `DummyFeatureExtractor` (stride 4, two feature outputs) and an image
        resizer bound to `preprocessor.resize_to_range` with
        min_dimension = max_dimension = 128 and padding to the max dimension.
    """
    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)
    # The keyword must match resize_to_range's parameter name exactly:
    # functools.partial does not validate keywords, so a misspelling here
    # would only surface as a TypeError once the resizer is actually called.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # Loss used by the DeepMAC mask head (not by the object-center head).
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss,
        dim=8,
        task_loss_weight=1.0,
        pixel_embedding_dim=2,
        allowed_masked_classes_ids=[],
        mask_size=16,
        mask_num_subsamples=-1,
        use_xy=True,
        network_type='hourglass10',
        use_instance_embedding=True,
        num_init_channels=8,
        predict_full_resolution_masks=predict_full_resolution_masks,
        postprocess_crop_size=128,
        max_roi_jitter_ratio=0.0,
        roi_jitter_mode='random')

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)
Example #2
0
def object_center_proto_to_params(oc_config):
    """Translates a CenterNet ObjectCenter proto into `ObjectCenterParams`.

    Args:
        oc_config: The CenterNet.ObjectCenter proto message to convert.

    Returns:
        A `center_net_meta_arch.ObjectCenterParams` whose classification loss
        is built from `oc_config.classification_loss` via `losses_builder`,
        and whose remaining fields are copied from the identically named
        fields of `oc_config`.
    """
    loss_proto = losses_pb2.Loss()
    # losses_builder.build throws if no localization loss is configured, so an
    # empty weighted-L2 one is attached purely as a placeholder; only the
    # classification loss it builds is used below.
    # TODO(yuhuic): update the loss builder to take the localization loss
    # directly.
    loss_proto.localization_loss.weighted_l2.CopyFrom(
        losses_pb2.WeightedL2LocalizationLoss())
    loss_proto.classification_loss.CopyFrom(oc_config.classification_loss)

    built_losses = losses_builder.build(loss_proto)
    # The builder returns seven objects; the classification loss comes first.
    # Unpacking all seven fails loudly if that return shape ever changes.
    center_loss, _, _, _, _, _, _ = built_losses

    # Scalar settings that map one-to-one from the proto onto the namedtuple.
    copied_fields = ('object_center_loss_weight',
                     'heatmap_bias_init',
                     'min_box_overlap_iou',
                     'max_box_predictions',
                     'use_labeled_classes')
    field_values = {name: getattr(oc_config, name) for name in copied_fields}

    return center_net_meta_arch.ObjectCenterParams(
        classification_loss=center_loss, **field_values)
Example #3
0
def build_meta_arch(**override_params):
    """Builds a small DeepMAC meta architecture for use in tests.

    Args:
        **override_params: Keyword overrides for the defaults in `params`
            below. `use_dice_loss` and `dice_loss_prediction_probability`
            are consumed here to choose and configure the mask classification
            loss. Every other key is forwarded as a keyword argument to
            `deepmac_meta_arch.DeepMACParams`.

    Returns:
        A `deepmac_meta_arch.DeepMACMetaArch` in training mode with 6 classes,
        a `DummyFeatureExtractor` (stride 4, two feature outputs) and an image
        resizer bound to `preprocessor.resize_to_range` with
        min_dimension = max_dimension = 128 and padding to the max dimension.
    """
    params = dict(predict_full_resolution_masks=False,
                  use_instance_embedding=True,
                  mask_num_subsamples=-1,
                  network_type='hourglass10',
                  use_xy=True,
                  pixel_embedding_dim=2,
                  dice_loss_prediction_probability=False,
                  color_consistency_threshold=0.5,
                  use_dice_loss=False,
                  box_consistency_loss_normalize='normalize_auto',
                  box_consistency_tightness=False,
                  task_loss_weight=1.0,
                  color_consistency_loss_weight=1.0,
                  box_consistency_loss_weight=1.0,
                  num_init_channels=8,
                  dim=8,
                  allowed_masked_classes_ids=[],
                  mask_size=16,
                  postprocess_crop_size=128,
                  max_roi_jitter_ratio=0.0,
                  roi_jitter_mode='random',
                  color_consistency_dilation=2,
                  color_consistency_warmup_steps=0,
                  color_consistency_warmup_start=0)

    params.update(override_params)

    feature_extractor = DummyFeatureExtractor(channel_means=(1.0, 2.0, 3.0),
                                              channel_stds=(10., 20., 30.),
                                              bgr_ordering=False,
                                              num_feature_outputs=2,
                                              stride=4)
    # The keyword must match resize_to_range's parameter name exactly:
    # functools.partial does not validate keywords, so a misspelling here
    # would only surface as a TypeError once the resizer is actually called.
    image_resizer_fn = functools.partial(preprocessor.resize_to_range,
                                         min_dimension=128,
                                         max_dimension=128,
                                         pad_to_max_dimension=True)

    object_center_params = center_net_meta_arch.ObjectCenterParams(
        classification_loss=losses.WeightedSigmoidClassificationLoss(),
        object_center_loss_weight=1.0,
        min_box_overlap_iou=1.0,
        max_box_predictions=5,
        use_labeled_classes=False)

    # These two keys configure the loss object rather than DeepMACParams, so
    # they are popped before the remaining params are forwarded below.
    use_dice_loss = params.pop('use_dice_loss')
    dice_loss_prediction_prob = params.pop('dice_loss_prediction_probability')
    if use_dice_loss:
        classification_loss = losses.WeightedDiceClassificationLoss(
            squared_normalization=False,
            is_prediction_probability=dice_loss_prediction_prob)
    else:
        classification_loss = losses.WeightedSigmoidClassificationLoss()

    deepmac_params = deepmac_meta_arch.DeepMACParams(
        classification_loss=classification_loss, **params)

    object_detection_params = center_net_meta_arch.ObjectDetectionParams(
        localization_loss=losses.L1LocalizationLoss(),
        offset_loss_weight=1.0,
        scale_loss_weight=0.1)

    return deepmac_meta_arch.DeepMACMetaArch(
        is_training=True,
        add_summaries=False,
        num_classes=6,
        feature_extractor=feature_extractor,
        object_center_params=object_center_params,
        deepmac_params=deepmac_params,
        object_detection_params=object_detection_params,
        image_resizer_fn=image_resizer_fn)