def test(self, output_dir=None, model_to_test=None):
    if output_dir is not None:
        self.cfg.OUTPUT_DIR = output_dir
    model = build_detection_model(self.cfg)
    device = torch.device(self.cfg.MODEL.DEVICE)
    model.to(device)

    arguments = {}
    arguments["iteration"] = 0

    output_dir = self.cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        self.cfg, model, None, None, output_dir, save_to_disk
    )

    # An explicit checkpoint passed as model_to_test overrides the configured weight.
    if model_to_test is not None:
        self.cfg.MODEL.WEIGHT = model_to_test
    # Absolute paths and catalog entries are used as-is; bare file names are resolved
    # relative to the pretrained_feature_extractors directory.
    if self.cfg.MODEL.WEIGHT.startswith('/') or 'catalog' in self.cfg.MODEL.WEIGHT:
        model_path = self.cfg.MODEL.WEIGHT
    else:
        model_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__),
                         os.path.pardir, os.path.pardir,
                         os.path.pardir, os.path.pardir,
                         'Data', 'pretrained_feature_extractors',
                         self.cfg.MODEL.WEIGHT))
    # use_latest=False loads exactly model_path instead of the last checkpoint in OUTPUT_DIR.
    extra_checkpoint_data = checkpointer.load(model_path, use_latest=False)

    checkpointer.optimizer = make_optimizer(self.cfg, checkpointer.model)
    checkpointer.scheduler = make_lr_scheduler(self.cfg, checkpointer.optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = self.cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(checkpointer.model,
                                      checkpointer.optimizer,
                                      opt_level=amp_opt_level)

    if self.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    synchronize()
    # The result can be used for additional logging, e.g. for TensorBoard
    _ = inference(
        model,
        # The method changes the segmentation mask format in a data loader,
        # so every time a new data loader is created:
        make_data_loader(self.cfg,
                         is_train=False,
                         is_distributed=(get_world_size() > 1),
                         is_target_task=self.is_target_task),
        dataset_name="[Test]",
        iou_types=("bbox",),
        box_only=False if self.cfg.MODEL.RETINANET_ON else self.cfg.MODEL.RPN_ONLY,
        device=self.cfg.MODEL.DEVICE,
        expected_results=self.cfg.TEST.EXPECTED_RESULTS,
        expected_results_sigma_tol=self.cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        output_folder=None,
        is_target_task=self.is_target_task,
    )
    synchronize()

    logger = logging.getLogger("maskrcnn_benchmark")
    logger.handlers = []
def train(self, output_dir=None, fine_tune_last_layers=False, fine_tune_rpn=False):
    if output_dir is not None:
        self.cfg.OUTPUT_DIR = output_dir
    model = build_detection_model(self.cfg)
    device = torch.device(self.cfg.MODEL.DEVICE)
    model.to(device)

    arguments = {}
    arguments["iteration"] = 0

    output_dir = self.cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(self.cfg, model, None, None,
                                         output_dir, save_to_disk)

    if self.cfg.MODEL.WEIGHT.startswith('/') or 'catalog' in self.cfg.MODEL.WEIGHT:
        model_path = self.cfg.MODEL.WEIGHT
    else:
        model_path = os.path.abspath(
            os.path.join(os.path.dirname(__file__),
                         os.path.pardir, os.path.pardir,
                         os.path.pardir, os.path.pardir,
                         'Data', 'pretrained_feature_extractors',
                         self.cfg.MODEL.WEIGHT))
    extra_checkpoint_data = checkpointer.load(model_path)

    # If the number of target classes differs from the pretrained head, replace the
    # box (and, if present, mask) prediction layers with freshly initialized ones.
    if self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1 != self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES:
        checkpointer.model.roi_heads.box.predictor.cls_score = torch.nn.Linear(
            in_features=checkpointer.model.roi_heads.box.predictor.cls_score.in_features,
            out_features=self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1,
            bias=True)
        checkpointer.model.roi_heads.box.predictor.bbox_pred = torch.nn.Linear(
            in_features=checkpointer.model.roi_heads.box.predictor.cls_score.in_features,
            out_features=(self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1) * 4,
            bias=True)
        if hasattr(checkpointer.model.roi_heads, 'mask'):
            checkpointer.model.roi_heads.mask.predictor.mask_fcn_logits = torch.nn.Conv2d(
                in_channels=checkpointer.model.roi_heads.mask.predictor.mask_fcn_logits.in_channels,
                out_channels=self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1,
                kernel_size=(1, 1),
                stride=(1, 1))
        checkpointer.model.to(device)

    if fine_tune_last_layers:
        # Re-initialize the prediction layers that will be fine-tuned.
        checkpointer.model.roi_heads.box.predictor.cls_score = torch.nn.Linear(
            in_features=checkpointer.model.roi_heads.box.predictor.cls_score.in_features,
            out_features=self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1,
            bias=True)
        checkpointer.model.roi_heads.box.predictor.bbox_pred = torch.nn.Linear(
            in_features=checkpointer.model.roi_heads.box.predictor.cls_score.in_features,
            out_features=(self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1) * 4,
            bias=True)
        if hasattr(checkpointer.model.roi_heads, 'mask'):
            checkpointer.model.roi_heads.mask.predictor.mask_fcn_logits = torch.nn.Conv2d(
                in_channels=checkpointer.model.roi_heads.mask.predictor.mask_fcn_logits.in_channels,
                out_channels=self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1,
                kernel_size=(1, 1),
                stride=(1, 1))

        # Freeze backbone layers
        for elem in checkpointer.model.backbone.parameters():
            elem.requires_grad = False

        if not fine_tune_rpn:
            # Freeze RPN layers
            for elem in checkpointer.model.rpn.parameters():
                elem.requires_grad = False
        else:
            # Keep the shared RPN conv frozen and re-initialize the RPN prediction layers.
            for elem in checkpointer.model.rpn.head.conv.parameters():
                elem.requires_grad = False
            checkpointer.model.rpn.head.cls_logits = torch.nn.Conv2d(
                in_channels=checkpointer.model.rpn.head.cls_logits.in_channels,
                out_channels=checkpointer.model.rpn.head.cls_logits.out_channels,
                kernel_size=(1, 1),
                stride=(1, 1))
            checkpointer.model.rpn.head.bbox_pred = torch.nn.Conv2d(
                in_channels=checkpointer.model.rpn.head.bbox_pred.in_channels,
                out_channels=checkpointer.model.rpn.head.bbox_pred.out_channels,
                kernel_size=(1, 1),
                stride=(1, 1))

        # Freeze roi_heads layers with the exception of the predictor ones
        for elem in checkpointer.model.roi_heads.box.feature_extractor.parameters():
            elem.requires_grad = False
        for elem in checkpointer.model.roi_heads.box.predictor.parameters():
            elem.requires_grad = True
        if hasattr(checkpointer.model.roi_heads, 'mask'):
            for elem in checkpointer.model.roi_heads.mask.predictor.parameters():
                elem.requires_grad = False
            for elem in checkpointer.model.roi_heads.mask.predictor.mask_fcn_logits.parameters():
                elem.requires_grad = True
        checkpointer.model.to(device)

    checkpointer.optimizer = make_optimizer(self.cfg, checkpointer.model)
    checkpointer.scheduler = make_lr_scheduler(self.cfg, checkpointer.optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = self.cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(checkpointer.model,
                                      checkpointer.optimizer,
                                      opt_level=amp_opt_level)

    if self.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.local_rank],
            output_device=self.local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    data_loader = make_data_loader(self.cfg,
                                   is_train=True,
                                   is_distributed=self.distributed,
                                   start_iter=arguments["iteration"],
                                   is_target_task=self.is_target_task)

    test_period = self.cfg.SOLVER.TEST_PERIOD
    if test_period > 0:
        data_loader_val = make_data_loader(self.cfg,
                                           is_train=False,
                                           is_distributed=self.distributed,
                                           is_target_task=self.is_target_task)
    else:
        data_loader_val = None

    checkpoint_period = self.cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(self.cfg,
             model,
             data_loader,
             data_loader_val,
             checkpointer.optimizer,
             checkpointer.scheduler,
             checkpointer,
             device,
             checkpoint_period,
             test_period,
             arguments,
             is_target_task=self.is_target_task)

    logger = logging.getLogger("maskrcnn_benchmark")
    logger.handlers = []
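
# ---------------------------------------------------------------------------
# Illustrative sketch only (not from the original file): a minimal, self-contained
# toy example of the head-replacement and freezing pattern that train() applies to
# the detection model, shown with plain torch modules instead of the
# maskrcnn_benchmark model. The layer sizes and names below are assumptions made
# purely for illustration.
import torch


def _example_head_surgery_and_freeze(num_classes=5):
    # A stand-in for the pretrained backbone / feature extractor, plus a head.
    body = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
    head = torch.nn.Linear(32, 10)
    model = torch.nn.Sequential(body, head)

    # Replace the head with one sized to the new task (+1 for the background class),
    # keeping the input width of the pretrained layer, as train() does for cls_score.
    model[1] = torch.nn.Linear(in_features=model[1].in_features,
                               out_features=num_classes + 1,
                               bias=True)

    # Freeze everything except the new head, mirroring the requires_grad logic above.
    for p in model[0].parameters():
        p.requires_grad = False

    # Only the still-trainable parameters are handed to the optimizer.
    optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad),
                                lr=0.01)
    return model, optimizer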