示例#1
0
    def __init__(self, params):
        """Creates the components of the attribute-aware Mask R-CNN model.

        Args:
          params: model configuration. The whole object is handed to the
            `factory` generators (backbone, FPN, RPN head, Fast R-CNN head
            and, when masks are enabled, Mask R-CNN head). The sub-sections
            `roi_proposal`, `roi_sampling`, `mask_sampling`,
            `rpn_score_loss`, `rpn_box_loss`, `frcnn_box_loss` and
            `postprocess` configure the remaining pieces.
        """
        super(AttributeMaskrcnnModel, self).__init__(params)

        self._params = params
        self._include_mask = params.architecture.include_mask
        with_masks = self._include_mask

        # --- Network builders (creation order kept unchanged). ---
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        # ROI sampling comes from the `roi_sampler` module here.
        self._sample_rois_fn = roi_sampler.ROISampler(params.roi_sampling)
        mask_target_size = params.architecture.mask_target_size
        masks_per_image = params.mask_sampling.num_mask_samples_per_image
        self._sample_masks_fn = target_ops.MaskSampler(
            mask_target_size, masks_per_image)

        self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
        if with_masks:
            # Only built when masks are enabled; otherwise the attribute
            # is never set.
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # --- Losses. ---
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        # Extra loss for the attribute branch, from `attribute_loss`.
        self._frcnn_attribute_loss_fn = attribute_loss.FastrcnnAttributeLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if with_masks:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # --- Post-processing. ---
        postprocess_params = params.postprocess
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(postprocess_params))
示例#2
0
    def __init__(self, params):
        """Creates the components of the Mask R-CNN model.

        Args:
          params: model configuration. The backbone and FPN generators get
            the whole object, while the RPN, Fast R-CNN and Mask R-CNN head
            generators get only `params.rpn_head`, `params.frcnn_head` and
            `params.mrcnn_head` respectively. `params.anchor` and
            `params.train.transpose_input` are stored on the instance.
        """
        super(MaskrcnnModel, self).__init__(params)

        self._anchor_params = params.anchor
        self._include_mask = params.architecture.include_mask
        with_masks = self._include_mask

        # --- Network builders (creation order kept unchanged). ---
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        rpn_head_params = params.rpn_head
        self._rpn_head_fn = factory.rpn_head_generator(rpn_head_params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = sampling_ops.ROISampler(params.roi_sampling)
        self._sample_masks_fn = sampling_ops.MaskSampler(params.mask_sampling)

        frcnn_head_params = params.frcnn_head
        self._frcnn_head_fn = factory.fast_rcnn_head_generator(
            frcnn_head_params)
        if with_masks:
            # Only built when masks are enabled; otherwise the attribute
            # is never set.
            mrcnn_head_params = params.mrcnn_head
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(
                mrcnn_head_params)

        # --- Losses. ---
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if with_masks:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # --- Post-processing. ---
        postprocess_params = params.postprocess
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(postprocess_params))

        # Taken from the training config; presumably tells the model the
        # input images arrive transposed -- confirm against the input
        # pipeline.
        self._transpose_input = params.train.transpose_input
    def __init__(self, params):
        """Creates the components of the Cascade Mask R-CNN model.

        Args:
          params: model configuration. The whole object is handed to the
            `factory` generators. Cascade-specific settings are read from
            `params.roi_sampling` (`cascade_iou_thresholds`,
            `num_samples_per_image`) and `params.frcnn_head`
            (`class_agnostic_bbox_pred`, `cascade_class_ensemble`).
        """
        super(CascadeMaskrcnnModel, self).__init__(params)

        self._params = params
        self._include_mask = params.architecture.include_mask
        with_masks = self._include_mask

        # --- Network builders (creation order kept unchanged). ---
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
        mask_target_size = params.architecture.mask_target_size
        masks_per_image = params.mask_sampling.num_mask_samples_per_image
        self._sample_masks_fn = target_ops.MaskSampler(
            mask_target_size, masks_per_image)

        self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
        if with_masks:
            # Only built when masks are enabled; otherwise the attribute
            # is never set.
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # --- Losses. ---
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        frcnn_head_params = params.frcnn_head
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(
            params.frcnn_box_loss, frcnn_head_params.class_agnostic_bbox_pred)
        if with_masks:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        # --- Cascade settings. ---
        roi_sampling_params = params.roi_sampling
        # IoU thresholds for the additional FRCNN heads in cascade mode; the
        # foreground threshold 'fg_iou_thresh' acts as the first one.
        self._cascade_iou_thresholds = (
            roi_sampling_params.cascade_iou_thresholds)
        self._num_roi_samples = roi_sampling_params.num_samples_per_image
        # Regression-loss weights, one row per FRCNN layer: layer k (1-based)
        # uses k times the first layer's weights.
        # TODO(golnazg): make this param configurable.
        first_layer_weights = (10.0, 10.0, 5.0, 5.0)
        self._cascade_layer_to_weights = [
            [weight * layer for weight in first_layer_weights]
            for layer in (1, 2, 3)
        ]
        self._class_agnostic_bbox_pred = (
            frcnn_head_params.class_agnostic_bbox_pred)
        self._cascade_class_ensemble = frcnn_head_params.cascade_class_ensemble

        # --- Post-processing. ---
        postprocess_params = params.postprocess
        self._generate_detections_fn = (
            postprocess_ops.GenericDetectionGenerator(postprocess_params))