コード例 #1
0
  def test_model_initializing(self, init_checkpoint_modules):
    """Round-trips a freshly built model through a checkpoint and initializes.

    Builds a Panoptic Mask R-CNN task/model, saves the model's checkpoint
    items to a temporary directory, points the task config's init-checkpoint
    fields at that directory according to `init_checkpoint_modules`, and then
    calls `task.initialize(model)`.

    Args:
      init_checkpoint_modules: list of module names to restore. The names
        checked here are 'all', 'backbone', 'segmentation_backbone' and
        'segmentation_decoder'.
    """
    modules = init_checkpoint_modules

    # A separate segmentation backbone is configured only when it is one of
    # the modules to restore; the decoder can only be shared if the backbone
    # is shared as well.
    share_backbone = 'segmentation_backbone' not in modules
    share_decoder = share_backbone and 'segmentation_decoder' not in modules

    segmentation_model = segmentation_cfg.SemanticSegmentationModel(
        decoder=decoder_cfg.Decoder(type='fpn'))
    model_config = cfg.PanopticMaskRCNN(
        num_classes=2,
        input_size=[640, 640, 3],
        segmentation_model=segmentation_model,
        shared_backbone=share_backbone,
        shared_decoder=share_decoder)
    task = panoptic_maskrcnn.PanopticMaskRCNNTask(
        cfg.PanopticMaskRCNNTask(model=model_config))
    model = task.build_model()

    # Save the untrained model so initialize() has a checkpoint to read.
    checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
    save_dir = self.create_tempdir().full_path
    checkpoint.save(os.path.join(save_dir, 'ckpt'))

    restore_detection = modules == ['all'] or 'backbone' in modules
    restore_segmentation = any(
        name in modules
        for name in ('segmentation_backbone', 'segmentation_decoder'))

    # The task keeps its config privately; the test mutates it in place.
    task_config = task._task_config
    if restore_detection:
      task_config.init_checkpoint = save_dir
    if restore_segmentation:
      task_config.segmentation_init_checkpoint = save_dir
    task_config.init_checkpoint_modules = modules

    task.initialize(model)
コード例 #2
0
ファイル: factory_test.py プロジェクト: kmady/models
 def test_builder(self, backbone_type, input_size,
                  segmentation_backbone_type, segmentation_decoder_type):
     """Builds a Panoptic Mask R-CNN model through the factory function.

     `shared_backbone` / `shared_decoder` are set in the model config exactly
     when the corresponding segmentation type argument is None.

     Args:
       backbone_type: backbone type name for the detection model.
       input_size: sequence whose first two entries are used as the spatial
         dimensions of the input spec (the channel count is fixed at 3).
       segmentation_backbone_type: backbone type name for the segmentation
         model, or None.
       segmentation_decoder_type: decoder type name for the segmentation
         model, or None.
     """
     # `np.math` was merely an alias of the stdlib `math` module; it was
     # deprecated in NumPy 1.25 and removed in NumPy 2.0, so use `math`.
     import math  # pylint: disable=g-import-not-at-top

     num_classes = 2
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
     segmentation_output_stride = 16
     # Head level is log2 of the output stride (16 -> level 4).
     level = int(math.log2(segmentation_output_stride))
     segmentation_model = semantic_segmentation.SemanticSegmentationModel(
         num_classes=2,
         backbone=backbones.Backbone(type=segmentation_backbone_type),
         decoder=decoders.Decoder(type=segmentation_decoder_type),
         head=semantic_segmentation.SegmentationHead(level=level))
     model_config = panoptic_maskrcnn_cfg.PanopticMaskRCNN(
         num_classes=num_classes,
         segmentation_model=segmentation_model,
         backbone=backbones.Backbone(type=backbone_type),
         shared_backbone=segmentation_backbone_type is None,
         shared_decoder=segmentation_decoder_type is None)
     l2_regularizer = tf.keras.regularizers.l2(5e-5)
     _ = factory.build_panoptic_maskrcnn(input_specs=input_specs,
                                         model_config=model_config,
                                         l2_regularizer=l2_regularizer)