예제 #1
0
    def test(self, output_dir=None, model_to_test=None):
        """Evaluate a detection model built from ``self.cfg`` on the test split.

        Builds a fresh model, loads weights into it, wraps it with apex ``amp``
        (and ``DistributedDataParallel`` when ``self.distributed`` is set) and
        runs ``inference`` on a newly created (non-training) data loader.

        Args:
            output_dir: if given, overrides ``self.cfg.OUTPUT_DIR`` (the config
                object is mutated in place) before the checkpointer is created.
            model_to_test: if given, overrides ``self.cfg.MODEL.WEIGHT`` (the
                config object is mutated in place) and selects the weights that
                are loaded.

        Returns:
            None. The inference result is discarded; before returning, the
            handlers of the ``"maskrcnn_benchmark"`` logger are cleared.
        """
        if output_dir is not None:
            self.cfg.OUTPUT_DIR = output_dir
        model = build_detection_model(self.cfg)
        device = torch.device(self.cfg.MODEL.DEVICE)
        model.to(device)

        output_dir = self.cfg.OUTPUT_DIR

        # Only rank 0 is allowed to write checkpoints to disk.
        save_to_disk = get_rank() == 0
        checkpointer = DetectronCheckpointer(
            self.cfg, model, None, None, output_dir, save_to_disk
        )

        if model_to_test is not None:
            self.cfg.MODEL.WEIGHT = model_to_test

        # Absolute paths and catalog names are used as-is; anything else is
        # resolved under Data/pretrained_feature_extractors, four directory
        # levels above this file.
        weight = self.cfg.MODEL.WEIGHT
        if weight.startswith('/') or 'catalog' in weight:
            model_path = weight
        else:
            model_path = os.path.abspath(os.path.join(
                os.path.dirname(__file__), os.path.pardir, os.path.pardir,
                os.path.pardir, os.path.pardir,
                'Data', 'pretrained_feature_extractors', weight))

        # use_latest=False presumably makes the checkpointer load model_path
        # itself rather than a "latest" checkpoint found in output_dir --
        # confirm against DetectronCheckpointer.load.
        checkpointer.load(model_path, use_latest=False)

        # amp.initialize below is given an optimizer, so one (plus its
        # scheduler) is built even though no training step happens here.
        checkpointer.optimizer = make_optimizer(self.cfg, checkpointer.model)
        checkpointer.scheduler = make_lr_scheduler(self.cfg, checkpointer.optimizer)

        # Mixed precision (amp opt level O1) only when DTYPE is "float16".
        use_mixed_precision = self.cfg.DTYPE == "float16"
        amp_opt_level = 'O1' if use_mixed_precision else 'O0'
        model, _ = amp.initialize(checkpointer.model, checkpointer.optimizer,
                                  opt_level=amp_opt_level)

        if self.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[self.local_rank], output_device=self.local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )
        synchronize()
        _ = inference(  # The result can be used for additional logging, e. g. for TensorBoard
            model,
            # The method changes the segmentation mask format in a data loader,
            # so every time a new data loader is created:
            make_data_loader(self.cfg, is_train=False, is_distributed=(get_world_size() > 1), is_target_task=self.is_target_task),
            dataset_name="[Test]",
            iou_types=("bbox",),
            # All settings come from the instance config (self.cfg), i.e. the
            # same config used to build and load the model above.
            box_only=False if self.cfg.MODEL.RETINANET_ON else self.cfg.MODEL.RPN_ONLY,
            device=self.cfg.MODEL.DEVICE,
            expected_results=self.cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=self.cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=None,
            is_target_task=self.is_target_task,
        )
        synchronize()

        # Drop handlers so repeated calls do not accumulate duplicate log output.
        logger = logging.getLogger("maskrcnn_benchmark")
        logger.handlers = []
예제 #2
0
    def train(self,
              output_dir=None,
              fine_tune_last_layers=False,
              fine_tune_rpn=False):
        """Train (or fine-tune) a detection model built from ``self.cfg``.

        Builds a model, loads weights, optionally re-creates and/or freezes
        layers, sets up apex ``amp`` (and ``DistributedDataParallel`` when
        ``self.distributed`` is set), builds the data loaders and runs
        ``do_train``.

        The ROI output layers are re-created (freshly initialized) when
        ``MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1`` differs from
        ``MODEL.ROI_BOX_HEAD.NUM_CLASSES``, or when ``fine_tune_last_layers``
        is set. In either case they are re-created exactly once. The extra
        output is presumably the background class.

        Args:
            output_dir: if given, overrides ``self.cfg.OUTPUT_DIR`` (the config
                object is mutated in place) before the checkpointer is created.
            fine_tune_last_layers: if True, re-create the ROI output layers and
                freeze the backbone, the box feature extractor and the RPN (see
                ``fine_tune_rpn``), leaving the output layers trainable.
            fine_tune_rpn: only used together with ``fine_tune_last_layers``.
                If True, the RPN output convolutions are re-created and left
                trainable instead of the whole RPN being frozen.
        """
        if output_dir is not None:
            self.cfg.OUTPUT_DIR = output_dir
        model = build_detection_model(self.cfg)
        device = torch.device(self.cfg.MODEL.DEVICE)
        model.to(device)

        arguments = {}
        arguments["iteration"] = 0

        output_dir = self.cfg.OUTPUT_DIR

        # Only rank 0 is allowed to write checkpoints to disk.
        save_to_disk = get_rank() == 0
        checkpointer = DetectronCheckpointer(self.cfg, model, None, None,
                                             output_dir, save_to_disk)

        # Absolute paths and catalog names are used as-is; anything else is
        # resolved under Data/pretrained_feature_extractors, four directory
        # levels above this file.
        weight = self.cfg.MODEL.WEIGHT
        if weight.startswith('/') or 'catalog' in weight:
            model_path = weight
        else:
            model_path = os.path.abspath(
                os.path.join(os.path.dirname(__file__), os.path.pardir,
                             os.path.pardir, os.path.pardir, os.path.pardir,
                             'Data', 'pretrained_feature_extractors',
                             weight))

        # NOTE(review): unlike test(), this uses the checkpointer's default
        # use_latest -- confirm whether resuming from a checkpoint in
        # output_dir is intended here.
        checkpointer.load(model_path)

        num_classes = self.cfg.MINIBOOTSTRAP.DETECTOR.NUM_CLASSES + 1
        if (fine_tune_last_layers
                or num_classes != self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES):
            self._reset_roi_predictors(checkpointer.model, num_classes)
            if fine_tune_last_layers:
                self._freeze_for_last_layers_fine_tuning(checkpointer.model,
                                                         fine_tune_rpn)
            # Newly created layers are not yet on the model's device.
            checkpointer.model.to(device)

        # Built after freezing so the optimizer sees the final requires_grad flags.
        checkpointer.optimizer = make_optimizer(self.cfg, checkpointer.model)
        checkpointer.scheduler = make_lr_scheduler(self.cfg,
                                                   checkpointer.optimizer)

        # Mixed precision (amp opt level O1) only when DTYPE is "float16".
        use_mixed_precision = self.cfg.DTYPE == "float16"
        amp_opt_level = 'O1' if use_mixed_precision else 'O0'
        model, optimizer = amp.initialize(checkpointer.model,
                                          checkpointer.optimizer,
                                          opt_level=amp_opt_level)

        if self.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

        data_loader = make_data_loader(self.cfg,
                                       is_train=True,
                                       is_distributed=self.distributed,
                                       start_iter=arguments["iteration"],
                                       is_target_task=self.is_target_task)

        # A validation loader is only needed when periodic testing is enabled.
        test_period = self.cfg.SOLVER.TEST_PERIOD
        if test_period > 0:
            data_loader_val = make_data_loader(
                self.cfg,
                is_train=False,
                is_distributed=self.distributed,
                is_target_task=self.is_target_task)
        else:
            data_loader_val = None

        checkpoint_period = self.cfg.SOLVER.CHECKPOINT_PERIOD
        do_train(self.cfg,
                 model,
                 data_loader,
                 data_loader_val,
                 checkpointer.optimizer,
                 checkpointer.scheduler,
                 checkpointer,
                 device,
                 checkpoint_period,
                 test_period,
                 arguments,
                 is_target_task=self.is_target_task)

        # Drop handlers so repeated calls do not accumulate duplicate log output.
        logger = logging.getLogger("maskrcnn_benchmark")
        logger.handlers = []

    @staticmethod
    def _reset_roi_predictors(model, num_classes):
        """Replace the ROI output layers with fresh ones sized for ``num_classes``.

        Re-creates ``cls_score`` (``num_classes`` outputs) and ``bbox_pred``
        (four outputs per class) of the box predictor. When the model has a
        mask head, it also re-creates ``mask_fcn_logits`` (``num_classes``
        output channels, 1x1 kernel). The new modules are not moved to any
        device; the caller does that.
        """
        box_predictor = model.roi_heads.box.predictor
        in_features = box_predictor.cls_score.in_features
        box_predictor.cls_score = torch.nn.Linear(
            in_features=in_features,
            out_features=num_classes,
            bias=True)
        box_predictor.bbox_pred = torch.nn.Linear(
            in_features=in_features,
            out_features=num_classes * 4,
            bias=True)
        if hasattr(model.roi_heads, 'mask'):
            mask_predictor = model.roi_heads.mask.predictor
            mask_predictor.mask_fcn_logits = torch.nn.Conv2d(
                in_channels=mask_predictor.mask_fcn_logits.in_channels,
                out_channels=num_classes,
                kernel_size=(1, 1),
                stride=(1, 1))

    @staticmethod
    def _freeze_for_last_layers_fine_tuning(model, fine_tune_rpn):
        """Set requires_grad so that only the output layers are trained.

        - backbone: frozen.
        - RPN: fully frozen, unless ``fine_tune_rpn`` is set. In that case
          ``head.conv`` is frozen and ``head.cls_logits`` / ``head.bbox_pred``
          are re-created with the same shapes (fresh initialization).
        - box head: feature extractor frozen, predictor trainable.
        - mask head (if present): predictor frozen except ``mask_fcn_logits``.
        """
        for param in model.backbone.parameters():
            param.requires_grad = False

        if not fine_tune_rpn:
            for param in model.rpn.parameters():
                param.requires_grad = False
        else:
            rpn_head = model.rpn.head
            for param in rpn_head.conv.parameters():
                param.requires_grad = False
            rpn_head.cls_logits = torch.nn.Conv2d(
                in_channels=rpn_head.cls_logits.in_channels,
                out_channels=rpn_head.cls_logits.out_channels,
                kernel_size=(1, 1),
                stride=(1, 1))
            rpn_head.bbox_pred = torch.nn.Conv2d(
                in_channels=rpn_head.bbox_pred.in_channels,
                out_channels=rpn_head.bbox_pred.out_channels,
                kernel_size=(1, 1),
                stride=(1, 1))

        # Freeze roi_heads layers with the exception of the predictor ones.
        for param in model.roi_heads.box.feature_extractor.parameters():
            param.requires_grad = False
        for param in model.roi_heads.box.predictor.parameters():
            param.requires_grad = True
        if hasattr(model.roi_heads, 'mask'):
            # NOTE(review): the mask feature extractor's requires_grad flags
            # are not touched here, unlike the box one -- confirm intended.
            mask_predictor = model.roi_heads.mask.predictor
            for param in mask_predictor.parameters():
                param.requires_grad = False
            for param in mask_predictor.mask_fcn_logits.parameters():
                param.requires_grad = True