def test_model_initializing(self, init_checkpoint_modules):
    """Saves a freshly built Panoptic Mask R-CNN model and re-initializes it.

    A model is built, written to a temporary checkpoint, and then
    `task.initialize` is called with the task config pointed at that
    checkpoint according to `init_checkpoint_modules`.

    Args:
      init_checkpoint_modules: Collection of module names to restore. The
        values inspected here are ['all'], 'backbone',
        'segmentation_backbone' and 'segmentation_decoder'.
    """
    # The segmentation backbone is shared unless it is restored on its own;
    # the decoder is shared only if it is not restored on its own AND the
    # backbone is shared as well.
    uses_shared_backbone = (
        'segmentation_backbone' not in init_checkpoint_modules)
    uses_shared_decoder = (
        uses_shared_backbone
        and 'segmentation_decoder' not in init_checkpoint_modules)

    segmentation_model_config = segmentation_cfg.SemanticSegmentationModel(
        decoder=decoder_cfg.Decoder(type='fpn'))
    model_config = cfg.PanopticMaskRCNN(
        num_classes=2,
        input_size=[640, 640, 3],
        segmentation_model=segmentation_model_config,
        shared_backbone=uses_shared_backbone,
        shared_decoder=uses_shared_decoder)
    task = panoptic_maskrcnn.PanopticMaskRCNNTask(
        cfg.PanopticMaskRCNNTask(model=model_config))
    model = task.build_model()

    # Write the untrained model to disk so it can serve as the init checkpoint.
    checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
    checkpoint_dir = self.create_tempdir().full_path
    checkpoint.save(os.path.join(checkpoint_dir, 'ckpt'))

    restore_main = (init_checkpoint_modules == ['all']
                    or 'backbone' in init_checkpoint_modules)
    if restore_main:
        task._task_config.init_checkpoint = checkpoint_dir

    restore_segmentation = any(
        module_name in init_checkpoint_modules
        for module_name in ('segmentation_backbone', 'segmentation_decoder'))
    if restore_segmentation:
        task._task_config.segmentation_init_checkpoint = checkpoint_dir

    task._task_config.init_checkpoint_modules = init_checkpoint_modules
    task.initialize(model)
def test_builder(self, backbone_type, input_size, segmentation_backbone_type,
                 segmentation_decoder_type):
    """Checks that `factory.build_panoptic_maskrcnn` builds without error.

    Args:
      backbone_type: Backbone type name used for the main model.
      input_size: Sequence whose first two entries are used as the input
        height and width; the channel count is fixed at 3.
      segmentation_backbone_type: Backbone type for the segmentation model.
        When None, the config requests a shared backbone.
      segmentation_decoder_type: Decoder type for the segmentation model.
        When None, the config requests a shared decoder.
    """
    num_classes = 2
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], 3])

    # The segmentation head level is log2 of the output stride (16 -> 4).
    # `np.math` was only a deprecated alias of the stdlib `math` module and
    # was removed in NumPy 2.0, so use `np.log2` instead.
    segmentation_output_stride = 16
    level = int(np.log2(segmentation_output_stride))

    segmentation_model = semantic_segmentation.SemanticSegmentationModel(
        num_classes=2,
        backbone=backbones.Backbone(type=segmentation_backbone_type),
        decoder=decoders.Decoder(type=segmentation_decoder_type),
        head=semantic_segmentation.SegmentationHead(level=level))
    model_config = panoptic_maskrcnn_cfg.PanopticMaskRCNN(
        num_classes=num_classes,
        segmentation_model=segmentation_model,
        backbone=backbones.Backbone(type=backbone_type),
        shared_backbone=segmentation_backbone_type is None,
        shared_decoder=segmentation_decoder_type is None)
    l2_regularizer = tf.keras.regularizers.l2(5e-5)

    # Building the model is the check; the result itself is not inspected.
    _ = factory.build_panoptic_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)