예제 #1
0
    def test_get_boxes_with_five_classes_share_box_across_classes(self):
        """Check output shapes with ``share_box_across_classes=True``.

        Builds a Mask R-CNN Keras box predictor for 5 classes, calls it with
        ``prediction_stage=2`` on a random (2, 7, 7, 3) float32 input, and
        asserts the shapes of the box encodings and of the class predictions
        with background.
        """
        builder_kwargs = dict(
            is_training=False,
            num_classes=5,
            fc_hyperparams=self._build_hyperparams(),
            freeze_batchnorm=False,
            use_dropout=False,
            dropout_keep_prob=0.5,
            box_code_size=4,
            share_box_across_classes=True,
        )
        predictor = box_predictor_builder.build_mask_rcnn_keras_box_predictor(
            **builder_kwargs)

        def graph_fn(image_features):
            # Call the predictor with prediction_stage=2 and return the two
            # outputs whose shapes are asserted below.
            outputs = predictor([image_features], prediction_stage=2)
            encodings = outputs[box_predictor.BOX_ENCODINGS]
            class_scores = outputs[
                box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
            return (encodings, class_scores)

        features = np.random.rand(2, 7, 7, 3).astype(np.float32)
        box_encodings, class_predictions_with_background = self.execute(
            graph_fn, [features])
        self.assertAllEqual(box_encodings.shape, [2, 1, 1, 4])
        self.assertAllEqual(class_predictions_with_background.shape, [2, 1, 6])
 def graph_fn(image_features):
     """Build a mask-predicting box predictor and return its mask output.

     Uses ``self`` from the enclosing scope to create the FC and CONV
     hyperparameters. The predictor is built with
     ``predict_instance_masks=True`` and called with ``prediction_stage=3``.

     Args:
       image_features: passed to the predictor as a one-element list.

     Returns:
       A 1-tuple holding the ``MASK_PREDICTIONS`` entry of the predictions.
     """
     build = box_predictor_builder.build_mask_rcnn_keras_box_predictor
     fc_params = self._build_hyperparams()
     conv_params = self._build_hyperparams(
         op_type=hyperparams_pb2.Hyperparams.CONV)
     predictor = build(
         is_training=False,
         num_classes=5,
         fc_hyperparams=fc_params,
         freeze_batchnorm=False,
         use_dropout=False,
         dropout_keep_prob=0.5,
         box_code_size=4,
         conv_hyperparams=conv_params,
         predict_instance_masks=True,
     )
     outputs = predictor([image_features], prediction_stage=3)
     masks = outputs[box_predictor.MASK_PREDICTIONS]
     return (masks, )
 def graph_fn(image_features):
     """Build a shared-box predictor and return box and class outputs.

     Uses ``self`` from the enclosing scope to create the FC hyperparameters.
     The predictor is built with ``share_box_across_classes=True`` and called
     with ``prediction_stage=2``.

     Args:
       image_features: passed to the predictor as a one-element list.

     Returns:
       A 2-tuple of the ``BOX_ENCODINGS`` and
       ``CLASS_PREDICTIONS_WITH_BACKGROUND`` entries of the predictions.
     """
     build = box_predictor_builder.build_mask_rcnn_keras_box_predictor
     predictor = build(
         is_training=False,
         num_classes=5,
         fc_hyperparams=self._build_hyperparams(),
         freeze_batchnorm=False,
         use_dropout=False,
         dropout_keep_prob=0.5,
         box_code_size=4,
         share_box_across_classes=True,
     )
     outputs = predictor([image_features], prediction_stage=2)
     encodings = outputs[box_predictor.BOX_ENCODINGS]
     class_scores = outputs[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
     return (encodings, class_scores)
 def test_do_not_return_instance_masks_without_request(self):
     """Check that only box and class outputs come back by default.

     The predictor is built without ``predict_instance_masks`` and called
     with ``prediction_stage=2``. The returned predictions must then hold
     exactly two entries: ``BOX_ENCODINGS`` and
     ``CLASS_PREDICTIONS_WITH_BACKGROUND``.
     """
     image_features = tf.random.uniform([2, 7, 7, 3], dtype=tf.float32)
     mask_box_predictor = (
         box_predictor_builder.build_mask_rcnn_keras_box_predictor(
             is_training=False,
             num_classes=5,
             fc_hyperparams=self._build_hyperparams(),
             freeze_batchnorm=False,
             use_dropout=False,
             dropout_keep_prob=0.5,
             box_code_size=4))
     box_predictions = mask_box_predictor([image_features],
                                          prediction_stage=2)
     self.assertEqual(len(box_predictions), 2)
     # assertIn names the missing key and the container on failure, whereas
     # assertTrue(key in container) only reports "False is not true".
     self.assertIn(box_predictor.BOX_ENCODINGS, box_predictions)
     self.assertIn(
         box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND, box_predictions)
예제 #5
0
 def _get_mask_rcnn_box_predictor(self):
     """Build a Mask R-CNN Keras box predictor from this object's settings.

     The CONV and FC hyperparameter objects are both created from
     ``self.activation``, ``self.initializer`` and ``self.batch_norm_params``.
     Dropout is switched on only when ``self._dropout_keep_prob`` is not
     None.

     Returns:
       Whatever ``build_mask_rcnn_keras_box_predictor`` returns.
     """
     conv_hyperparams = KerasLayerHyperparamsFromData(
         activation=self.activation,
         initializer=self.initializer,
         batch_norm_params=self.batch_norm_params,
     )
     fc_hyperparams = KerasLayerHyperparamsFromData(
         activation=self.activation,
         initializer=self.initializer,
         batch_norm_params=self.batch_norm_params,
     )
     build = box_predictor_builder.build_mask_rcnn_keras_box_predictor
     return build(
         is_training=None,
         num_classes=self.num_classes,
         add_background_class=True,
         fc_hyperparams=fc_hyperparams,
         freeze_batchnorm=self.freeze_batchnorm,
         use_dropout=self._dropout_keep_prob is not None,
         dropout_keep_prob=self._dropout_keep_prob,
         box_code_size=4,
         conv_hyperparams=conv_hyperparams,
     )
예제 #6
0
    def test_get_instance_masks(self):
        """Check the mask output shape with ``predict_instance_masks=True``.

        Builds a Mask R-CNN Keras box predictor for 5 classes with CONV
        hyperparameters, calls it with ``prediction_stage=3`` on a random
        (2, 7, 7, 3) float32 input, and asserts the shape of the mask
        predictions.
        """
        builder_kwargs = dict(
            is_training=False,
            num_classes=5,
            fc_hyperparams=self._build_hyperparams(),
            freeze_batchnorm=False,
            use_dropout=False,
            dropout_keep_prob=0.5,
            box_code_size=4,
            conv_hyperparams=self._build_hyperparams(
                op_type=hyperparams_pb2.Hyperparams.CONV),
            predict_instance_masks=True,
        )
        predictor = box_predictor_builder.build_mask_rcnn_keras_box_predictor(
            **builder_kwargs)

        def graph_fn(image_features):
            # Call the predictor with prediction_stage=3 and return only the
            # mask output, wrapped in a 1-tuple.
            outputs = predictor([image_features], prediction_stage=3)
            masks = outputs[box_predictor.MASK_PREDICTIONS]
            return (masks, )

        features = np.random.rand(2, 7, 7, 3).astype(np.float32)
        mask_predictions = self.execute(graph_fn, [features])
        self.assertAllEqual(mask_predictions.shape, [2, 1, 5, 14, 14])