Exemplo n.º 1
0
    def __init__(self, params):
        """Initializes the ShapeMask model from a configuration object.

        Creates the architecture generators (backbone, FPN, RetinaNet head and
        the shape-prior / coarse-mask / fine-mask heads), the detection and
        mask loss functions together with their weights, and the multilevel
        detection generator.

        Args:
          params: configuration object read through attribute access (e.g.
            `params.architecture`, `params.retinanet_loss`,
            `params.shapemask_loss`); it is also stored as `self._params`.
        """
        super(ShapeMaskModel, self).__init__(params)

        self._params = params

        # Network builders (kept in their original construction order).
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._retinanet_head_fn = factory.retinanet_head_generator(params)
        self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
        self._coarse_mask_fn = factory.coarsemask_head_generator(params)
        self._fine_mask_fn = factory.finemask_head_generator(params)

        arch = params.architecture
        self._outer_box_scale = arch.outer_box_scale

        # Detection losses: classification and box regression.
        retinanet_loss_cfg = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(
            retinanet_loss_cfg, arch.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(retinanet_loss_cfg)
        self._box_loss_weight = retinanet_loss_cfg.box_loss_weight

        # Mask losses and the weight attached to each of them.
        mask_loss_cfg = params.shapemask_loss
        self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
        self._shapemask_loss_fn = losses.ShapemaskLoss()
        self._shape_prior_loss_weight = mask_loss_cfg.shape_prior_loss_weight
        self._coarse_mask_loss_weight = mask_loss_cfg.coarse_mask_loss_weight
        self._fine_mask_loss_weight = mask_loss_cfg.fine_mask_loss_weight

        # Detection post-processing over the configured level range.
        self._generate_detections_fn = (
            postprocess_ops.MultilevelDetectionGenerator(
                arch.min_level, arch.max_level, params.postprocess))
Exemplo n.º 2
0
    def __init__(self, params):
        """Initializes the Mask R-CNN model from a configuration object.

        Args:
          params: configuration object read through attribute access. The
            mask head and the mask loss are only created when
            `params.architecture.include_mask` is truthy.
        """
        super(MaskrcnnModel, self).__init__(params)

        self._anchor_params = params.anchor
        self._include_mask = params.architecture.include_mask

        # Backbone, FPN, RPN head and the ROI proposal / sampling ops.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params.rpn_head)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = sampling_ops.ROISampler(params.roi_sampling)
        self._sample_masks_fn = sampling_ops.MaskSampler(params.mask_sampling)

        # Second-stage heads; the mask head is optional.
        frcnn_head_params = params.frcnn_head
        self._frcnn_head_fn = factory.fast_rcnn_head_generator(frcnn_head_params)
        if self._include_mask:
            mrcnn_head_params = params.mrcnn_head
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(
                mrcnn_head_params)

        # RPN and Fast R-CNN losses, plus the optional mask loss.
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # Detection post-processing.
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(params.postprocess))

        self._transpose_input = params.train.transpose_input
Exemplo n.º 3
0
    def __init__(self, params):
        """Initializes the RetinaNet model together with its Keras input layer.

        Args:
          params: configuration object read through attribute access; it is
            stored as `self._params` (used for eval metrics). Reads the
            `retinanet_head`, `retinanet_loss`, `postprocess`, `train` and
            `retinanet_parser` sections.

        Raises:
          ValueError: if `params.train.transpose_input` is truthy, because
            transposed input is not supported by this model.
        """
        super(RetinanetModel, self).__init__(params)

        # For eval metrics.
        self._params = params

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Loss function.
        self._cls_loss_fn = losses.RetinanetClassLoss(params.retinanet_loss)
        self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
        self._box_loss_weight = params.retinanet_loss.box_loss_weight
        # Not built here; presumably populated later by another method.
        self._keras_model = None

        # Predict function.
        self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
            params.postprocess)

        self._transpose_input = params.train.transpose_input
        # Explicit check instead of `assert`: an assert is stripped when
        # Python runs with -O, which would silently accept an unsupported
        # configuration.
        if self._transpose_input:
            raise ValueError('Transpose input is not supported.')

        # Input layer. `output_size` is copied into a list so the
        # concatenation works whether the config stores it as a list or a
        # tuple (tuple + list would raise TypeError).
        input_shape = (list(params.retinanet_parser.output_size) +
                       [params.retinanet_parser.num_channels])
        self._input_layer = tf.keras.layers.Input(
            shape=input_shape,
            name='',
            batch_size=self._params.train.batch_size,
            # `self._use_bfloat16` is presumably set by the base-class
            # constructor; it is not assigned in this block.
            dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
Exemplo n.º 4
0
    def __init__(self, params):
        """Initializes the attribute-aware Mask R-CNN model.

        In addition to the Mask R-CNN components this model creates a Fast
        R-CNN attribute loss.

        Args:
          params: configuration object read through attribute access; it is
            stored as `self._params`. The mask head and the mask loss are only
            created when `params.architecture.include_mask` is truthy.
        """
        super(AttributeMaskrcnnModel, self).__init__(params)

        self._params = params
        arch = params.architecture
        self._include_mask = arch.include_mask

        # Feature extractor, RPN head and ROI proposal / sampling ops.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = roi_sampler.ROISampler(params.roi_sampling)
        mask_samples = params.mask_sampling.num_mask_samples_per_image
        self._sample_masks_fn = target_ops.MaskSampler(
            arch.mask_target_size, mask_samples)

        # Second-stage heads; the mask head is optional.
        self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
        if self._include_mask:
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # RPN losses, Fast R-CNN class/attribute/box losses, optional mask loss.
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_attribute_loss_fn = attribute_loss.FastrcnnAttributeLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # Detection post-processing.
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(params.postprocess))
Exemplo n.º 5
0
    def __init__(self, params):
        """Initializes the segmentation model.

        Args:
          params: configuration object read through attribute access; reads
            the `segmentation_head`, `segmentation_loss` and `train` sections.
        """
        super(SegmentationModel, self).__init__(params)

        head_params = params.segmentation_head

        # Backbone, FPN and segmentation-head builders.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.segmentation_head_generator(head_params)
        self._num_classes = head_params.num_classes

        # Training objective.
        self._loss_fn = losses.SegmentationLoss(params.segmentation_loss)

        train_params = params.train
        self._l2_weight_decay = train_params.l2_weight_decay
        self._transpose_input = train_params.transpose_input
    def __init__(self, params):
        """Initializes the segmentation model (architecture-configured variant).

        Args:
          params: configuration object read through attribute access; reads
            the `architecture`, `segmentation_head` and `segmentation_loss`
            sections.
        """
        super(SegmentationModel, self).__init__(params)

        # Network builders.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.segmentation_head_generator(params)

        arch = params.architecture
        self._num_classes = arch.num_classes
        self._level = params.segmentation_head.level

        # Training objective.
        self._loss_fn = losses.SegmentationLoss(params.segmentation_loss)

        # Architecture toggles copied from the config.
        self._use_aspp = arch.use_aspp
        self._use_pyramid_fusion = arch.use_pyramid_fusion
Exemplo n.º 7
0
    def __init__(self, params):
        """Initializes the ViLD model.

        Args:
          params: configuration object read through attribute access; it is
            stored as `self._params`. The mask head and the mask loss are only
            created when `params.architecture.include_mask` is truthy.
        """
        super(ViLDModel, self).__init__(params)

        self._params = params
        arch = params.architecture

        self._include_mask = arch.include_mask
        self._losses = params.train.losses

        # Visual feature distillation settings. A config value equal to the
        # string 'None' is mapped to an actual None.
        distill_mode = arch.visual_feature_distill
        self._feat_distill = None if distill_mode == 'None' else distill_mode
        self._feat_distill_dim = arch.visual_feature_dim
        self._max_distill_rois = arch.max_num_rois
        self._feat_distill_weight = arch.feat_distill_weight
        self._normalize_feat_during_training = (
            arch.normalize_feat_during_training)

        # Feature extractor, RPN head and ROI proposal / sampling ops.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
        mask_samples = params.mask_sampling.num_mask_samples_per_image
        self._sample_masks_fn = target_ops.MaskSampler(
            arch.mask_target_size, mask_samples)

        # Second-stage heads; the mask head is optional.
        self._frcnn_head_fn = factory.vild_fast_rcnn_head_generator(params)
        if self._include_mask:
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # RPN losses, ViLD class loss, box loss and the optional mask loss.
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = vild_losses.FastrcnnClassLoss(
            params.frcnn_class_loss)
        box_loss_params = params.frcnn_box_loss
        agnostic_boxes = params.frcnn_head.class_agnostic_bbox_pred
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(
            box_loss_params, class_agnostic_bbox_pred=agnostic_boxes)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # Detection post-processing, aware of the distillation setting.
        postprocess_params = params.postprocess
        self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(
            postprocess_params,
            discard_background=postprocess_params.discard_background,
            visual_feature_distill=self._feat_distill)
Exemplo n.º 8
0
    def __init__(self, params):
        """Initializes the Cascade Mask R-CNN model.

        Args:
          params: configuration object read through attribute access; it is
            stored as `self._params`. The mask head and the mask loss are only
            created when `params.architecture.include_mask` is truthy.
        """
        super(CascadeMaskrcnnModel, self).__init__(params)

        self._params = params
        arch = params.architecture
        self._include_mask = arch.include_mask

        # Feature extractor, RPN head and ROI proposal / sampling ops.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
        mask_samples = params.mask_sampling.num_mask_samples_per_image
        self._sample_masks_fn = target_ops.MaskSampler(
            arch.mask_target_size, mask_samples)

        # Second-stage heads; the mask head is optional.
        self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
        if self._include_mask:
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # RPN and Fast R-CNN losses, plus the optional mask loss.
        frcnn_head_params = params.frcnn_head
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(
            params.frcnn_box_loss, frcnn_head_params.class_agnostic_bbox_pred)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # IoU thresholds for the extra FRCNN heads used in cascade mode; the
        # 'fg_iou_thresh' setting acts as the first threshold.
        sampling_cfg = params.roi_sampling
        self._cascade_iou_thresholds = sampling_cfg.cascade_iou_thresholds
        self._num_roi_samples = sampling_cfg.num_samples_per_image
        # One row of four regression-loss weights per FRCNN layer.
        # TODO(golnazg): makes this param configurable.
        self._cascade_layer_to_weights = [
            [10.0, 10.0, 5.0, 5.0],
            [20.0, 20.0, 10.0, 10.0],
            [30.0, 30.0, 15.0, 15.0],
        ]
        self._class_agnostic_bbox_pred = frcnn_head_params.class_agnostic_bbox_pred
        self._cascade_class_ensemble = frcnn_head_params.cascade_class_ensemble

        # Detection post-processing.
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(params.postprocess))
Exemplo n.º 9
0
    def __init__(self, params):
        """Initializes the RetinaNet model.

        Args:
          params: configuration object read through attribute access; reads
            the `retinanet_head`, `retinanet_loss`, `postprocess` and `train`
            sections.
        """
        super(RetinanetModel, self).__init__(params)

        # Backbone, FPN and RetinaNet-head builders.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Classification and box losses are built from one loss config.
        loss_params = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(loss_params)
        self._box_loss_fn = losses.RetinanetBoxLoss(loss_params)
        self._box_loss_weight = loss_params.box_loss_weight

        # Detection post-processing.
        self._generate_detections_fn = (
            postprocess_ops.MultilevelDetectionGenerator(params.postprocess))

        self._transpose_input = params.train.transpose_input
Exemplo n.º 10
0
    def __init__(self, params):
        """Initializes the RetinaNet model.

        Args:
          params: configuration object read through attribute access; reads
            the `retinanet_head`, `retinanet_loss`, `postprocess` and
            `architecture` sections.
        """
        super(RetinanetModel, self).__init__(params)

        # Backbone, FPN and RetinaNet-head builders.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Classification and box losses are built from one loss config.
        loss_params = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(loss_params)
        self._box_loss_fn = losses.RetinanetBoxLoss(loss_params)
        self._box_loss_weight = loss_params.box_loss_weight

        # One-stage detection post-processing.
        self._generate_detections_fn = (
            postprocess.GenerateOneStageDetections(params.postprocess))

        self._l2_weight_decay = params.architecture.l2_weight_decay
Exemplo n.º 11
0
    def __init__(self, params):
        """Initializes the RetinaNet model.

        Args:
          params: configuration object read through attribute access; it is
            stored as `self._params`. The lines shown read the
            `architecture`, `retinanet_loss` and `postprocess` sections.
        """
        super(RetinanetModel, self).__init__(params)

        self._params = params

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        # The whole config is passed here (not just a head sub-section).
        self._head_fn = factory.retinanet_head_generator(params)

        # Loss function.
        # The class loss additionally receives the number of classes.
        self._cls_loss_fn = losses.RetinanetClassLoss(
            params.retinanet_loss, params.architecture.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
        self._box_loss_weight = params.retinanet_loss.box_loss_weight
        # Momentum for the focal-loss normalizer, copied from the loss config;
        # where it is consumed is not visible here.
        self._focal_loss_normalizer_momentum = (
            params.retinanet_loss.normalizer_momentum)

        # Predict function.
        # Post-processing is configured with the min/max feature levels.
        self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
            params.architecture.min_level, params.architecture.max_level,
            params.postprocess)