Example #1
0
    def __init__(self, params):
        """Builds the RetinaNet sub-network builders, losses and input layer.

        Args:
          params: model configuration object. It is forwarded to the base
            class and to `factory.backbone_generator` /
            `factory.multilevel_features_generator`. This method also reads
            `params.retinanet_head`, `params.retinanet_loss` (including
            `box_loss_weight`), `params.postprocess`,
            `params.train.transpose_input` and
            `params.retinanet_parser.output_size` / `.num_channels`.

        Raises:
          AssertionError: if `params.train.transpose_input` is truthy (only
            while assertions are enabled, i.e. not under `python -O`).
        """
        super(RetinanetModel, self).__init__(params)

        # For eval metrics.
        self._params = params

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Loss function.
        self._cls_loss_fn = losses.RetinanetClassLoss(params.retinanet_loss)
        self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
        self._box_loss_weight = params.retinanet_loss.box_loss_weight
        # The Keras model is not constructed in __init__.
        self._keras_model = None

        # Predict function.
        self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
            params.postprocess)

        self._transpose_input = params.train.transpose_input
        assert not self._transpose_input, 'Transpose input is not supported.'

        # Input layer: the parser's output size followed by the channel count.
        # list() lets `output_size` be configured as a tuple as well as a list;
        # concatenating a tuple with a list would raise TypeError.
        parser_params = params.retinanet_parser
        input_shape = (list(parser_params.output_size) +
                       [parser_params.num_channels])
        # NOTE(review): `self._use_bfloat16` is not assigned in this method;
        # presumably the base-class __init__ sets it -- confirm.
        self._input_layer = tf.keras.layers.Input(
            shape=input_shape,
            name='',
            dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
Example #2
0
    def __init__(self, params):
        """Creates the ShapeMask sub-network builders, losses and detector.

        Args:
          params: model configuration object. It is forwarded to the base
            class and to every `factory.*_generator` call. This method also
            reads `params.retinanet_loss` (including `box_loss_weight`),
            `params.architecture` (`num_classes`, `min_level`, `max_level`),
            `params.shapemask_loss` (the three mask loss weights) and
            `params.postprocess`.
        """
        super().__init__(params)

        self._params = params
        # The Keras model is not constructed in __init__.
        self._keras_model = None

        # Network builders, created in a fixed order.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._retinanet_head_fn = factory.retinanet_head_generator(params)
        self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
        self._coarse_mask_fn = factory.coarsemask_head_generator(params)
        self._fine_mask_fn = factory.finemask_head_generator(params)

        # Class / box losses and the box-loss weight.
        det_loss_cfg = params.retinanet_loss
        arch_cfg = params.architecture
        self._cls_loss_fn = losses.RetinanetClassLoss(
            det_loss_cfg, arch_cfg.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(det_loss_cfg)
        self._box_loss_weight = det_loss_cfg.box_loss_weight

        # Mask losses, followed by their per-stage weights.
        self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
        self._shapemask_loss_fn = losses.ShapemaskLoss()
        mask_loss_cfg = params.shapemask_loss
        (self._shape_prior_loss_weight,
         self._coarse_mask_loss_weight,
         self._fine_mask_loss_weight) = (
             mask_loss_cfg.shape_prior_loss_weight,
             mask_loss_cfg.coarse_mask_loss_weight,
             mask_loss_cfg.fine_mask_loss_weight)

        # Detection generator used at prediction time.
        self._generate_detections_fn = (
            postprocess_ops.MultilevelDetectionGenerator(
                arch_cfg.min_level, arch_cfg.max_level, params.postprocess))
Example #3
0
  def __init__(self, params):
    """Creates the RetinaNet sub-network builders, losses and detector.

    Args:
      params: model configuration object. It is forwarded to the base class
        and to the `factory.*_generator` calls. This method also reads
        `params.retinanet_loss` (including `box_loss_weight`),
        `params.architecture` (`num_classes`, `min_level`, `max_level`),
        `params.postprocess` and `params.train.transpose_input`.

    Raises:
      AssertionError: if `params.train.transpose_input` is truthy (only while
        assertions are enabled).
    """
    super().__init__(params)

    # The full config is retained for eval metrics.
    self._params = params

    # Network builders, created in a fixed order.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._head_fn = factory.retinanet_head_generator(params)

    # Class / box losses and the box-loss weight.
    loss_cfg = params.retinanet_loss
    arch_cfg = params.architecture
    self._cls_loss_fn = losses.RetinanetClassLoss(
        loss_cfg, arch_cfg.num_classes)
    self._box_loss_fn = losses.RetinanetBoxLoss(loss_cfg)
    self._box_loss_weight = loss_cfg.box_loss_weight
    # The Keras model is not constructed in __init__.
    self._keras_model = None

    # Detection generator used at prediction time.
    self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
        arch_cfg.min_level, arch_cfg.max_level, params.postprocess)

    transpose = params.train.transpose_input
    self._transpose_input = transpose
    assert not transpose, 'Transpose input is not supported.'
  def __init__(self, params):
    """Builds the Mask R-CNN sub-network builders, samplers and losses.

    Args:
      params: model configuration object. It is forwarded to the base class
        and to the `factory.*_generator` calls. This method also reads
        `params.architecture` (`include_mask`, `mask_target_size`),
        `params.roi_proposal`, `params.roi_sampling`, `params.mask_sampling`,
        `params.rpn_score_loss`, `params.rpn_box_loss`,
        `params.frcnn_box_loss`, `params.postprocess` and
        `params.train.transpose_input`.

    Raises:
      AssertionError: if `params.train.transpose_input` is truthy (only while
        assertions are enabled).
    """
    super(MaskrcnnModel, self).__init__(params)

    # For eval metrics.
    self._params = params
    # The Keras model is not constructed in __init__.
    self._keras_model = None

    self._include_mask = params.architecture.include_mask

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._rpn_head_fn = factory.rpn_head_generator(params)
    self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
    self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
    self._sample_masks_fn = target_ops.MaskSampler(
        params.architecture.mask_target_size,
        params.mask_sampling.num_mask_samples_per_image)

    self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
    # `_mrcnn_head_fn` exists only when masks are enabled.
    if self._include_mask:
      self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

    # Loss function.
    self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
    self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
    self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
    self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
    # `_mask_loss_fn` exists only when masks are enabled.
    if self._include_mask:
      self._mask_loss_fn = losses.MaskrcnnLoss()

    self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(
        params.postprocess)

    self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'
Example #5
0
    def __init__(self, params):
        """Builds the OLN-Mask sub-network builders, samplers and losses.

        Which RPN head, box head and mask head are built depends on the
        flags read from `params` (centerness, box scoring, mask scoring).

        Args:
          params: model configuration object. It is forwarded to the base
            class and to the `factory.*_generator` calls. This method also
            reads `params.architecture`, `params.rpn_head.has_centerness`,
            `params.frcnn_head.has_scoring`, `params.mrcnn_head.has_scoring`,
            `params.roi_proposal`, `params.roi_sampling`,
            `params.mask_sampling`, the loss sub-configs, `params.postprocess`
            and `params.train.transpose_input`.

        Raises:
          AssertionError: if `params.train.transpose_input` is truthy (only
            while assertions are enabled).
        """
        super(OlnMaskModel, self).__init__(params)

        self._params = params

        # Different heads and layers.
        self._include_rpn_class = params.architecture.include_rpn_class
        self._include_mask = params.architecture.include_mask
        self._include_frcnn_class = params.architecture.include_frcnn_class
        self._include_frcnn_box = params.architecture.include_frcnn_box
        self._include_centerness = params.rpn_head.has_centerness
        # Scoring heads are only used when the underlying head is enabled.
        self._include_box_score = (params.frcnn_head.has_scoring
                                   and self._include_frcnn_box)
        self._include_mask_score = (params.mrcnn_head.has_scoring
                                    and self._include_mask)

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        # Build exactly one RPN head. An unconditional
        # `rpn_head_generator(params)` call before this branch would only
        # construct a head that is immediately overwritten.
        if self._include_centerness:
            self._rpn_head_fn = factory.oln_rpn_head_generator(params)
        else:
            self._rpn_head_fn = factory.rpn_head_generator(params)
        self._generate_rois_fn = roi_ops.OlnROIGenerator(params.roi_proposal)
        self._sample_rois_fn = target_ops.ROIScoreSampler(params.roi_sampling)
        self._sample_masks_fn = target_ops.MaskSampler(
            params.architecture.mask_target_size,
            params.mask_sampling.num_mask_samples_per_image)

        if self._include_box_score:
            self._frcnn_head_fn = factory.oln_box_score_head_generator(params)
        else:
            self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)

        # `_mrcnn_head_fn` exists only when masks are enabled.
        if self._include_mask:
            if self._include_mask_score:
                self._mrcnn_head_fn = factory.oln_mask_score_head_generator(
                    params)
            else:
                self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # Loss function.
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        if self._include_centerness:
            self._rpn_iou_loss_fn = losses.OlnRpnIoULoss()
            self._rpn_center_loss_fn = losses.OlnRpnCenterLoss()
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if self._include_box_score:
            self._frcnn_box_score_loss_fn = losses.OlnBoxScoreLoss(
                params.frcnn_box_score_loss)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        self._generate_detections_fn = postprocess_ops.OlnDetectionGenerator(
            params.postprocess)

        self._transpose_input = params.train.transpose_input
        assert not self._transpose_input, 'Transpose input is not supported.'