Example #1

# NOTE: this is an excerpt; the imports below are a best-effort
# reconstruction. Package-internal helpers (ModelTag, EMAState,
# setup_after_launch, and the checkpoint-conversion utilities) are defined
# elsewhere in d2go and omitted here.
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import pytorch_lightning as pl
import torch
from detectron2.config import CfgNode
from detectron2.modeling import build_model
from detectron2.solver import (
    build_lr_scheduler as d2_build_lr_scheduler,
    build_optimizer as d2_build_optimizer,
)
from d2go.runner import Detectron2GoRunner
from pytorch_lightning.utilities import rank_zero_info

class DefaultTask(pl.LightningModule):
    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.cfg = cfg
        self.model = build_model(cfg)
        self.storage = None
        # Evaluators for the validation datasets, keyed by model tag
        # (default, EMA), in the order of DATASETS.TEST.
        self.dataset_evaluators = {ModelTag.DEFAULT: []}
        self.save_hyperparameters()
        self.eval_res = None

        self.ema_state: Optional[EMAState] = None
        if cfg.MODEL_EMA.ENABLED:
            self.ema_state = EMAState(
                decay=cfg.MODEL_EMA.DECAY,
                device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE,
            )
            self.model_ema = deepcopy(self.model)
            self.dataset_evaluators[ModelTag.EMA] = []

    def setup(self, stage: str):
        setup_after_launch(self.cfg, self.cfg.OUTPUT_DIR, runner=None)

    @classmethod
    def get_default_cfg(cls):
        return Detectron2GoRunner.get_default_cfg()

    def training_step(self, batch, batch_idx):
        loss_dict = self.forward(batch)
        losses = sum(loss_dict.values())
        self.storage.step()

        self.log_dict(loss_dict, prog_bar=True)
        return losses

    def test_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)

    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)

    def _evaluation_step(self, batch, batch_idx: int, dataloader_idx: int) -> None:
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.forward(batch)
        self.dataset_evaluators[ModelTag.DEFAULT][dataloader_idx].process(
            batch, outputs
        )

        if self.ema_state:
            ema_outputs = self.model_ema(batch)
            self.dataset_evaluators[ModelTag.EMA][dataloader_idx].process(
                batch, ema_outputs
            )

    def _log_dataset_evaluation_results(self) -> None:
        nested_res = {}
        for tag, evaluators in self.dataset_evaluators.items():
            res = {}
            for idx, evaluator in enumerate(evaluators):
                dataset_name = self.cfg.DATASETS.TEST[idx]
                res[dataset_name] = evaluator.evaluate()
            nested_res[tag.value] = res

        self.eval_res = nested_res
        flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res)
        self.log_dict(flattened)

    def test_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()

    def validation_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()

    def _evaluation_epoch_end(self) -> None:
        self._log_dataset_evaluation_results()
        self._reset_dataset_evaluators()

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List]:
        optim = d2_build_optimizer(self.cfg, self.model)
        lr_scheduler = d2_build_lr_scheduler(self.cfg, optim)

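        # "interval": "step" tells Lightning to advance the scheduler after
        # every optimizer step rather than once per epoch, matching
        # detectron2's iteration-based LR schedules.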
        return [optim], [{"scheduler": lr_scheduler, "interval": "step"}]

    def train_dataloader(self):
        return Detectron2GoRunner.build_detection_train_loader(self.cfg)

    def _reset_dataset_evaluators(self):
        """Reset the per-dataset evaluators before the next evaluation
        (run every EVAL_PERIOD steps)."""
        assert (
            not self.trainer.distributed_backend
            or self.trainer.distributed_backend.lower()
            in [
                "ddp",
                "ddp_cpu",
            ]
        ), "Only DDP and DDP_CPU distributed backends are supported"

        def _get_inference_dir_name(
            base_dir, inference_type, dataset_name, model_tag: ModelTag
        ):
            next_eval_iter = self.trainer.global_step + self.cfg.TEST.EVAL_PERIOD
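            # Interpretation (not documented in the source): the directory is
            # named after the iteration of the *next* scheduled evaluation;
            # the step-0 case is shifted by one so the pre-training sanity
            # check does not collide with the first real evaluation.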
            if self.trainer.global_step == 0:
                next_eval_iter -= 1
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(next_eval_iter),
                dataset_name,
            )

        for tag, dataset_evaluators in self.dataset_evaluators.items():
            dataset_evaluators.clear()
            assert self.cfg.OUTPUT_DIR, "Expect OUTPUT_DIR to be specified in the config"
            for dataset_name in self.cfg.DATASETS.TEST:
                # setup evaluator for each dataset
                output_folder = _get_inference_dir_name(
                    self.cfg.OUTPUT_DIR, "inference", dataset_name, tag
                )
                evaluator = Detectron2GoRunner.get_evaluator(
                    self.cfg, dataset_name, output_folder=output_folder
                )
                evaluator.reset()
                dataset_evaluators.append(evaluator)
                # TODO: add visualization evaluator

    def _evaluation_dataloader(self):
        # TODO: Support subsample n images
        assert len(self.cfg.DATASETS.TEST), "DATASETS.TEST must not be empty"

        dataloaders = []
        for dataset_name in self.cfg.DATASETS.TEST:
            dataloaders.append(
                Detectron2GoRunner.build_detection_test_loader(self.cfg, dataset_name)
            )

        self._reset_dataset_evaluators()
        return dataloaders

    def test_dataloader(self):
        return self._evaluation_dataloader()

    def val_dataloader(self):
        return self._evaluation_dataloader()

    def forward(self, input):
        return self.model(input)

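    # EMA lifecycle, as implemented by the hooks below: the state is captured
    # from the model once the pre-training routine ends, updated after every
    # training batch, and applied to `model_ema` at the start of each
    # evaluation epoch.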
    def on_pretrain_routine_end(self) -> None:
        if self.cfg.MODEL_EMA.ENABLED:
            if self.ema_state and self.ema_state.has_inited():
                # ema_state could have been loaded from checkpoint
                return
            self.ema_state = EMAState.from_model(
                self.model,
                decay=self.cfg.MODEL_EMA.DECAY,
                device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE,
            )

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
        if self.ema_state:
            self.ema_state.update(self.model)

    def on_test_epoch_start(self):
        self._on_evaluation_epoch_start()

    def on_validation_epoch_start(self):
        self._on_evaluation_epoch_start()

    def _on_evaluation_epoch_start(self):
        if self.ema_state:
            self.ema_state.apply_to(self.model_ema)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if self.ema_state:
            checkpoint["model_ema"] = self.ema_state.state_dict()

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """
        Called before model state is restored. Explicitly handles old model
        states so we can resume training from D2Go checkpoints transparently.

        Args:
            checkpointed_state: The raw checkpoint state as returned by torch.load
                or equivalent.
        """
        # If this is a non-Lightning checkpoint, we need to convert it.
        if not _is_lightning_checkpoint(checkpointed_state) and not _is_d2go_checkpoint(
            checkpointed_state
        ):
            raise ValueError(
                f"Invalid checkpoint state with keys: {checkpointed_state.keys()}"
            )
        if not _is_lightning_checkpoint(checkpointed_state):
            _convert_to_lightning(checkpointed_state)

        if self.ema_state:
            if "model_ema" not in checkpointed_state:
                rank_zero_info(
                    "EMA is enabled but EMA state is not found in given checkpoint"
                )
            else:
                self.ema_state = EMAState()
                self.ema_state.load_state_dict(checkpointed_state["model_ema"])
                if not self.ema_state.device:
                    # EMA state device not given, move to module device
                    self.ema_state.to(self.device)
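
A minimal usage sketch (not part of the original example): driving the task
with a Lightning Trainer. The config path is hypothetical, and `self.storage`
is assumed to be initialized (e.g., with a detectron2 EventStorage) before
training steps run.

cfg = DefaultTask.get_default_cfg()
cfg.merge_from_file("my_config.yaml")  # hypothetical config file
task = DefaultTask(cfg)
trainer = pl.Trainer(max_steps=cfg.SOLVER.MAX_ITER)
trainer.fit(task)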
Example #2

# Imports as in Example #1, plus several d2go-internal helpers (dataset
# registration, quantization, logging) omitted from this excerpt.
class DefaultTask(pl.LightningModule):
    def __init__(self, cfg: CfgNode):
        super().__init__()
        self.register(cfg)
        self.cfg = cfg
        self.model = self._build_model()
        self.storage = None
        # Evaluators for the validation datasets, keyed by model tag
        # (default, EMA), in the order of DATASETS.TEST.
        self.dataset_evaluators = {ModelTag.DEFAULT: []}
        self.save_hyperparameters()
        self.eval_res = None

        self.ema_state: Optional[EMAState] = None
        if cfg.MODEL_EMA.ENABLED:
            self.ema_state = EMAState(
                decay=cfg.MODEL_EMA.DECAY,
                device=cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE,
            )
            self.model_ema = deepcopy(self.model)
            self.dataset_evaluators[ModelTag.EMA] = []

    def _build_model(self):
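        # A bare `build_model` here resolves to the module-level detectron2
        # builder, not the `build_model` classmethod defined later in this
        # class.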
        model = build_model(self.cfg)

        if self.cfg.MODEL.FROZEN_LAYER_REG_EXP:
            set_requires_grad(model,
                              self.cfg.MODEL.FROZEN_LAYER_REG_EXP,
                              value=False)

        return model

    @classmethod
    def from_config(cls, cfg: CfgNode, eval_only=False):
        """Builds the Lightning module, including the model, from the config.
        To load weights from a pretrained checkpoint, specify the checkpoint
        path in `MODEL.WEIGHTS`.

        Args:
            cfg: D2go config node.
            eval_only: True if module should be in eval mode.
        """
        if eval_only and not cfg.MODEL.WEIGHTS:
            logger.warning("MODEL.WEIGHTS is missing for eval-only mode.")

        if cfg.MODEL.WEIGHTS:
            # only load model weights from checkpoint
            logger.info(
                f"Load model weights from checkpoint: {cfg.MODEL.WEIGHTS}.")
            task = cls.load_from_checkpoint(cfg.MODEL.WEIGHTS, cfg=cfg)
        else:
            task = cls(cfg)

        if cfg.MODEL_EMA.ENABLED and cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY:
            assert task.ema_state, "EMA state is not loaded from checkpoint."
            task.ema_state.apply_to(task.model)

        if eval_only:
            task.eval()
        return task

    def training_step(self, batch, batch_idx):
        loss_dict = self.forward(batch)
        losses = sum(loss_dict.values())
        self.storage.step()

        self.log_dict(loss_dict, prog_bar=True)
        return losses

    def test_step(self,
                  batch,
                  batch_idx: int,
                  dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)

    def validation_step(self,
                        batch,
                        batch_idx: int,
                        dataloader_idx: int = 0) -> None:
        self._evaluation_step(batch, batch_idx, dataloader_idx)

    def _evaluation_step(self, batch, batch_idx: int,
                         dataloader_idx: int) -> None:
        if not isinstance(batch, List):
            batch = [batch]
        outputs = self.forward(batch)
        self.dataset_evaluators[ModelTag.DEFAULT][dataloader_idx].process(
            batch, outputs)

        if self.ema_state:
            ema_outputs = self.model_ema(batch)
            self.dataset_evaluators[ModelTag.EMA][dataloader_idx].process(
                batch, ema_outputs)

    def _log_dataset_evaluation_results(self) -> None:
        nested_res = {}
        for tag, evaluators in self.dataset_evaluators.items():
            res = {}
            for idx, evaluator in enumerate(evaluators):
                dataset_name = self.cfg.DATASETS.TEST[idx]
                res[dataset_name] = evaluator.evaluate()
            nested_res[tag.value] = res

        self.eval_res = nested_res
        flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res)
        self.log_dict(flattened)

    def test_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()

    def validation_epoch_end(self, _outputs) -> None:
        self._evaluation_epoch_end()

    def _evaluation_epoch_end(self) -> None:
        self._log_dataset_evaluation_results()
        self._reset_dataset_evaluators()

    def configure_optimizers(
        self, ) -> Tuple[List[torch.optim.Optimizer], List]:
        optim = d2_build_optimizer(self.cfg, self.model)
        lr_scheduler = d2_build_lr_scheduler(self.cfg, optim)

        return [optim], [{"scheduler": lr_scheduler, "interval": "step"}]

    def train_dataloader(self):
        return self.build_detection_train_loader(self.cfg)

    def _reset_dataset_evaluators(self):
        """Reset the per-dataset evaluators before the next evaluation
        (run every EVAL_PERIOD steps)."""
        assert (not self.trainer.distributed_backend
                or self.trainer.distributed_backend.lower() in [
                    "ddp",
                    "ddp_cpu",
                ]), "Only DDP and DDP_CPU distributed backends are supported"

        def _get_inference_dir_name(base_dir, inference_type, dataset_name,
                                    model_tag: ModelTag):
            next_eval_iter = self.trainer.global_step + self.cfg.TEST.EVAL_PERIOD
            if self.trainer.global_step == 0:
                next_eval_iter -= 1
            return os.path.join(
                base_dir,
                inference_type,
                model_tag,
                str(next_eval_iter),
                dataset_name,
            )

        for tag, dataset_evaluators in self.dataset_evaluators.items():
            dataset_evaluators.clear()
            assert self.cfg.OUTPUT_DIR, "Expect OUTPUT_DIR to be specified in the config"
            for dataset_name in self.cfg.DATASETS.TEST:
                # setup evaluator for each dataset
                output_folder = _get_inference_dir_name(
                    self.cfg.OUTPUT_DIR, "inference", dataset_name, tag)
                evaluator = self.get_evaluator(self.cfg,
                                               dataset_name,
                                               output_folder=output_folder)
                evaluator.reset()
                dataset_evaluators.append(evaluator)
                # TODO: add visualization evaluator

    def _evaluation_dataloader(self):
        # TODO: Support subsample n images
        assert len(self.cfg.DATASETS.TEST), "DATASETS.TEST must not be empty"

        dataloaders = []
        for dataset_name in self.cfg.DATASETS.TEST:
            dataloaders.append(
                self.build_detection_test_loader(self.cfg, dataset_name))

        self._reset_dataset_evaluators()
        return dataloaders

    def test_dataloader(self):
        return self._evaluation_dataloader()

    def val_dataloader(self):
        return self._evaluation_dataloader()

    def forward(self, input):
        return self.model(input)

    # ---------------------------------------------------------------------------
    # Runner methods
    # ---------------------------------------------------------------------------
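    # These wrappers mirror the Detectron2GoRunner API so that existing D2Go
    # tooling can drive this Lightning task; most of them simply delegate.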
    def setup(self, stage: str):
        setup_after_lightning_launch(self.cfg, self.cfg.OUTPUT_DIR)

    def register(self, cfg: CfgNode):
        inject_coco_datasets(cfg)
        register_dynamic_datasets(cfg)
        update_cfg_if_using_adhoc_dataset(cfg)
        patch_d2_meta_arch()

    @classmethod
    def build_model(cls, cfg: CfgNode, eval_only=False):
        """Builds D2go model instance from config.
        NOTE: kept for backward compatibility with existing D2Go tools. Prefer
        `from_config` for other use cases.

        Args:
            cfg: D2go config node.
            eval_only: True if model should be in eval mode.
        """
        task = cls.from_config(cfg, eval_only)
        if hasattr(task, PREPARED):
            task = getattr(task, PREPARED)
        return task.model

    @classmethod
    def get_default_cfg(cls):
        return Detectron2GoRunner.get_default_cfg()

    @staticmethod
    def _initialize(cfg: CfgNode):
        pass

    @staticmethod
    def get_evaluator(cfg: CfgNode, dataset_name: str, output_folder: str):
        return Detectron2GoRunner.get_evaluator(cfg=cfg,
                                                dataset_name=dataset_name,
                                                output_folder=output_folder)

    @staticmethod
    def get_mapper(cfg, is_train):
        return Detectron2GoRunner.get_mapper(cfg, is_train)

    @staticmethod
    def build_detection_train_loader(cfg, *args, mapper=None, **kwargs):
        return Detectron2GoRunner.build_detection_train_loader(
            cfg, *args, mapper=mapper, **kwargs)

    @staticmethod
    def build_detection_test_loader(cfg, dataset_name, mapper=None):
        return Detectron2GoRunner.build_detection_test_loader(
            cfg, dataset_name, mapper)

    # ---------------------------------------------------------------------------
    # Hooks
    # ---------------------------------------------------------------------------
    def on_pretrain_routine_end(self) -> None:
        if self.cfg.MODEL_EMA.ENABLED:
            if self.ema_state and self.ema_state.has_inited():
                # ema_state could have been loaded from checkpoint
                return
            self.ema_state = EMAState.from_model(
                self.model,
                decay=self.cfg.MODEL_EMA.DECAY,
                device=self.cfg.MODEL_EMA.DEVICE or self.cfg.MODEL.DEVICE,
            )

    def on_train_batch_end(self, outputs, batch, batch_idx,
                           dataloader_idx) -> None:
        if self.ema_state:
            self.ema_state.update(self.model)

    def on_test_epoch_start(self):
        self._on_evaluation_epoch_start()

    def on_validation_epoch_start(self):
        self._on_evaluation_epoch_start()

    def _on_evaluation_epoch_start(self):
        if self.ema_state:
            self.ema_state.apply_to(self.model_ema)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if self.ema_state:
            checkpoint["model_ema"] = self.ema_state.state_dict()

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """
        Called before model state is restored. Explicitly handles old model
        states so we can resume training from D2Go checkpoints transparently.

        Args:
            checkpointed_state: The raw checkpoint state as returned by torch.load
                or equivalent.
        """
        # If this is a non-Lightning checkpoint, we need to convert it.
        if not _is_lightning_checkpoint(
                checkpointed_state) and not _is_d2go_checkpoint(
                    checkpointed_state):
            raise ValueError(
                f"Invalid checkpoint state with keys: {checkpointed_state.keys()}"
            )
        if not _is_lightning_checkpoint(checkpointed_state):
            _convert_to_lightning(checkpointed_state)

        maybe_prepare_for_quantization(self, checkpointed_state)

        if self.ema_state:
            if "model_ema" not in checkpointed_state:
                rank_zero_info(
                    "EMA is enabled but EMA state is not found in given checkpoint"
                )
            else:
                self.ema_state = EMAState()
                self.ema_state.load_state_dict(checkpointed_state["model_ema"])
                if not self.ema_state.device:
                    # EMA state device not given, move to module device
                    self.ema_state.to(self.device)

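    # Quantization hooks: defer to the wrapped model when it provides its own
    # prepare_for_quant / prepare_for_quant_convert, and fall back to the
    # d2go defaults otherwise.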
    def prepare_for_quant(self) -> pl.LightningModule:
        if hasattr(self.model, "prepare_for_quant"):
            self.model = self.model.prepare_for_quant(self.cfg)
        else:
            self.model = default_prepare_for_quant(self.cfg, self.model)
        return self

    def prepare_for_quant_convert(self) -> pl.LightningModule:
        if hasattr(self.model, "prepare_for_quant_convert"):
            self.model = self.model.prepare_for_quant_convert(self.cfg)
        else:
            self.model = default_prepare_for_quant_convert(
                self.cfg, self.model)
        return self
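
A minimal usage sketch (not part of the original example): loading a trained
task for evaluation only. The checkpoint path is hypothetical.

cfg = DefaultTask.get_default_cfg()
cfg.MODEL.WEIGHTS = "/path/to/last.ckpt"  # hypothetical checkpoint
task = DefaultTask.from_config(cfg, eval_only=True)
trainer = pl.Trainer()
trainer.test(task)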