Beispiel #1
0
    def __init__(self, cfg, model_build_func):
        """
        Assemble everything the training loop needs: data loader, model,
        optimizer, LR scheduler, checkpointer and hooks.

        Args:
            cfg (BaseConfig): configuration object. This constructor reads
                ``SOLVER.LR_SCHEDULER`` (``MAX_ITER``, ``MAX_EPOCH`` and the
                optional ``EPOCH_WISE`` flag), ``SOLVER.BATCH_SUBDIVISIONS``
                and ``OUTPUT_DIR``, and hands ``cfg`` on to the builders.
            model_build_func (callable): invoked as ``model_build_func(cfg)``
                to create the model.
        """
        logger = logging.getLogger("cvpods")
        info_enabled = logger.isEnabledFor(logging.INFO)
        if not info_enabled:
            # INFO is disabled on the "cvpods" logger, i.e. nobody has
            # configured logging yet; do it here.
            setup_logger()

        self.start_iter = 0

        data_loader = self.build_train_loader(cfg)
        # NOTE(review): MAX_ITER / MAX_EPOCH are read only after this call;
        # presumably adjust_epoch_and_iter updates them in cfg -- confirm.
        iters_per_epoch = adjust_epoch_and_iter(cfg, data_loader)
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
        self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH

        model = maybe_convert_module(model_build_func(cfg))
        logger.info(f"Model structure: {model}")

        # The optimizer is built on the bare model, before any DDP wrapping.
        optimizer = self.build_optimizer(cfg, model)

        # Multi-process training: wrap the model in DDP.
        # Single-process runs keep the plain model.
        is_distributed = comm.get_world_size() > 1
        if is_distributed:
            model = DistributedDataParallel(
                model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False,
                find_unused_parameters=True,
            )

        # TODO: @wangfeng02, `batch_subdivisions`
        super().__init__(
            model, data_loader, optimizer, cfg.SOLVER.BATCH_SUBDIVISIONS
        )

        # The scheduler only gets a real epoch length in epoch-wise mode;
        # otherwise it is told -1.
        epoch_wise = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False)
        epoch_iters = iters_per_epoch if epoch_wise else -1
        self.scheduler = self.build_lr_scheduler(
            cfg, optimizer, epoch_iters=epoch_iters
        )

        # Checkpoints are written to cfg.OUTPUT_DIR and cover the model,
        # optimizer and scheduler. Stateful hooks are not checkpointed (yet).
        self.checkpointer = DetectionCheckpointer(
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )

        self.cfg = cfg
        hooks = self.build_hooks()
        self.register_hooks(hooks)
Beispiel #2
0
    def __init__(self, cfg, build_model):
        """
        Build the data loader, model, optimizer, LR scheduler and
        checkpointer, then register the hooks.

        Args:
            cfg (config dict): configuration. This constructor reads
                ``TRAINER.FP16`` (``ENABLED``, ``TYPE``, ``OPTS.OPT_LEVEL``),
                ``TRAINER.WINDOW_SIZE``, ``MODEL.DDP_BACKEND``,
                ``SOLVER.LR_SCHEDULER`` (``MAX_ITER``, ``MAX_EPOCH``, optional
                ``EPOCH_WISE`` / ``EPOCH_ITERS``) and ``OUTPUT_DIR``.
            build_model (callable): invoked as ``build_model(cfg)`` to create
                the model.

        Raises:
            ValueError: when running with more than one process and
                ``cfg.MODEL.DDP_BACKEND`` is neither "torch" nor "apex".
        """
        # NOTE(review): `logger` is not defined in this method; presumably a
        # module-level logger -- confirm.

        # Construction order matters: loader, model, then optimizer.
        self.data_loader = self.build_train_loader(cfg)
        self.model = maybe_convert_module(build_model(cfg))
        logger.info(f"Model: \n{self.model}")

        self.optimizer = self.build_optimizer(cfg, self.model)

        # Mixed precision: only the "APEX" type needs extra initialisation
        # here; apex is imported lazily so it stays an optional dependency.
        self.mixed_precision = bool(cfg.TRAINER.FP16.ENABLED)
        if self.mixed_precision and cfg.TRAINER.FP16.TYPE == "APEX":
            from apex import amp
            amp_model, amp_optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=cfg.TRAINER.FP16.OPTS.OPT_LEVEL)
            self.model, self.optimizer = amp_model, amp_optimizer

        # Multi-process training: wrap the model with the configured DDP
        # backend. Single-process runs keep the plain model.
        if comm.get_world_size() > 1:
            torch.cuda.set_device(comm.get_local_rank())
            backend = cfg.MODEL.DDP_BACKEND
            if backend == "torch":
                self.model = DistributedDataParallel(
                    self.model,
                    device_ids=[comm.get_local_rank()],
                    broadcast_buffers=False,
                    find_unused_parameters=True,
                )
            elif backend == "apex":
                from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
                self.model = ApexDistributedDataParallel(self.model)
            else:
                raise ValueError("non-supported DDP backend: {}".format(
                    cfg.MODEL.DDP_BACKEND))

        super().__init__(self.model, self.data_loader, self.optimizer)

        # In epoch-wise mode the scheduler gets the configured epoch length;
        # otherwise it is told -1.
        if cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False):
            epoch_iters = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_ITERS")
            logger.warning(f"Setup LR Scheduler in EPOCH mode: {epoch_iters}")
        else:
            epoch_iters = -1

        auto_scale_config(cfg, self.data_loader)
        self.scheduler = self.build_lr_scheduler(
            cfg, self.optimizer, epoch_iters=epoch_iters
        )

        # Checkpoints are written to cfg.OUTPUT_DIR and cover the model,
        # optimizer and scheduler. Stateful hooks are not checkpointed (yet).
        self.checkpointer = DefaultCheckpointer(
            self.model,
            cfg.OUTPUT_DIR,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )

        # Loop bookkeeping; the limits are read after auto_scale_config ran.
        self.start_iter = 0
        self.start_epoch = 0
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER
        self.max_epoch = cfg.SOLVER.LR_SCHEDULER.MAX_EPOCH
        self.window_size = cfg.TRAINER.WINDOW_SIZE

        self.cfg = cfg

        hooks = self.build_hooks()
        self.register_hooks(hooks)