コード例 #1
0
    def __init__(self, params):
        """Initializes the ShapeMask architecture, losses and post-processing.

        Args:
          params: model configuration object. The whole object is handed to the
            architecture factories; the `architecture`, `retinanet_loss`,
            `shapemask_loss` and `postprocess` sections are read directly.
        """
        super(ShapeMaskModel, self).__init__(params)
        self._params = params

        # Sub-network builders. Their creation order is preserved deliberately
        # in case a factory depends on it (e.g. for auto-generated names).
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._retinanet_head_fn = factory.retinanet_head_generator(params)
        self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
        self._coarse_mask_fn = factory.coarsemask_head_generator(params)
        self._fine_mask_fn = factory.finemask_head_generator(params)

        arch = params.architecture
        self._outer_box_scale = arch.outer_box_scale

        # Detection (RetinaNet) class/box losses.
        det_loss_cfg = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(
            det_loss_cfg, arch.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(det_loss_cfg)
        self._box_loss_weight = det_loss_cfg.box_loss_weight

        # Mask losses and the per-term weights read from `shapemask_loss`.
        self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
        self._shapemask_loss_fn = losses.ShapemaskLoss()
        mask_loss_cfg = params.shapemask_loss
        self._shape_prior_loss_weight = mask_loss_cfg.shape_prior_loss_weight
        self._coarse_mask_loss_weight = mask_loss_cfg.coarse_mask_loss_weight
        self._fine_mask_loss_weight = mask_loss_cfg.fine_mask_loss_weight

        # Detection generator configured with the feature-level range and the
        # `postprocess` section of the config.
        self._generate_detections_fn = (
            postprocess_ops.MultilevelDetectionGenerator(
                arch.min_level, arch.max_level, params.postprocess))
コード例 #2
0
    def __init__(self, params):
        """Builds RetinaNet components and the Keras input layer.

        Args:
          params: model configuration object. The whole object is passed to
            the backbone/FPN factories; the `retinanet_head`, `retinanet_loss`,
            `postprocess`, `train` and `retinanet_parser` sections are read
            directly.

        Raises:
          ValueError: if `params.train.transpose_input` is truthy; this model
            does not support transposed input.
        """
        super(RetinanetModel, self).__init__(params)

        # For eval metrics.
        self._params = params

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Loss function.
        self._cls_loss_fn = losses.RetinanetClassLoss(params.retinanet_loss)
        self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
        self._box_loss_weight = params.retinanet_loss.box_loss_weight
        # Placeholder; this method never assigns a model here.
        self._keras_model = None

        # Predict function.
        self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
            params.postprocess)

        self._transpose_input = params.train.transpose_input
        # Explicit check rather than `assert`: asserts are stripped when Python
        # runs with -O, which would silently accept an unsupported config.
        if self._transpose_input:
            raise ValueError('Transpose input is not supported.')

        # Input layer shape: the parser's output size (presumably
        # [height, width] -- confirm against the parser config) followed by
        # the channel count. `list(...)` lets a tuple-valued output_size work
        # too; for a list it is identical to the original concatenation.
        parser_cfg = params.retinanet_parser
        input_shape = list(parser_cfg.output_size) + [parser_cfg.num_channels]
        self._input_layer = tf.keras.layers.Input(
            shape=input_shape,
            name='',
            batch_size=self._params.train.batch_size,
            # `_use_bfloat16` is not set in this method; presumably the base
            # class initializer sets it -- confirm.
            dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
コード例 #3
0
ファイル: retinanet_model.py プロジェクト: xjwangziyan/tpu
    def __init__(self, params):
        """Sets up RetinaNet sub-networks, losses and the detection generator.

        Args:
          params: model configuration object. The whole object is passed to
            the backbone/FPN factories; the `retinanet_head`, `retinanet_loss`,
            `postprocess` and `train` sections are read directly.
        """
        super(RetinanetModel, self).__init__(params)

        # Sub-network builders; creation order is preserved deliberately.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Class and box losses share the `retinanet_loss` config section.
        loss_cfg = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(loss_cfg)
        self._box_loss_fn = losses.RetinanetBoxLoss(loss_cfg)
        self._box_loss_weight = loss_cfg.box_loss_weight

        # Detection generator built from the `postprocess` config section.
        detection_generator_cls = postprocess_ops.MultilevelDetectionGenerator
        self._generate_detections_fn = detection_generator_cls(
            params.postprocess)

        train_cfg = params.train
        self._transpose_input = train_cfg.transpose_input
コード例 #4
0
ファイル: retinanet_model.py プロジェクト: qing0991/tpu
    def __init__(self, params):
        """Sets up RetinaNet sub-networks, losses and the detection op.

        Args:
          params: model configuration object. The whole object is passed to
            the backbone/FPN factories; the `retinanet_head`, `retinanet_loss`,
            `postprocess` and `architecture` sections are read directly.
        """
        super(RetinanetModel, self).__init__(params)

        # Sub-network builders; creation order is preserved deliberately.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params.retinanet_head)

        # Class and box losses share the `retinanet_loss` config section.
        loss_cfg = params.retinanet_loss
        self._cls_loss_fn = losses.RetinanetClassLoss(loss_cfg)
        self._box_loss_fn = losses.RetinanetBoxLoss(loss_cfg)
        self._box_loss_weight = loss_cfg.box_loss_weight

        # Detection op built from the `postprocess` config section.
        make_detections = postprocess.GenerateOneStageDetections
        self._generate_detections_fn = make_detections(params.postprocess)

        # L2 weight-decay coefficient from the architecture config; it is not
        # used within this method.
        arch = params.architecture
        self._l2_weight_decay = arch.l2_weight_decay
コード例 #5
0
ファイル: retinanet_model.py プロジェクト: vishalbelsare/tpu
    def __init__(self, params):
        """Initializes RetinaNet architecture, losses and post-processing.

        Args:
          params: model configuration object. The whole object is handed to
            the architecture factories; the `architecture`, `retinanet_loss`
            and `postprocess` sections are read directly.
        """
        super(RetinanetModel, self).__init__(params)
        self._params = params

        # Sub-network builders; creation order is preserved deliberately.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params)

        arch = params.architecture
        loss_cfg = params.retinanet_loss

        # Class/box losses and their config-driven settings.
        self._cls_loss_fn = losses.RetinanetClassLoss(
            loss_cfg, arch.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(loss_cfg)
        self._box_loss_weight = loss_cfg.box_loss_weight
        # Read from `retinanet_loss.normalizer_momentum`; consumed outside
        # this method.
        self._focal_loss_normalizer_momentum = loss_cfg.normalizer_momentum

        # Detection generator configured with the feature-level range and the
        # `postprocess` section of the config.
        self._generate_detections_fn = (
            postprocess_ops.MultilevelDetectionGenerator(
                arch.min_level, arch.max_level, params.postprocess))