def before_run(self, runner):
    """Construct the averaged model which will keep track of the running
    averages of the parameters of the model."""
    model = runner.model
    self.model = AveragedModel(model)

    self.log_buffer = LogBuffer()
    self.meta = runner.meta
    if self.meta is None:
        self.meta = dict()
        self.meta.setdefault('hook_msgs', dict())
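A minimal sketch of how a LogBuffer like the one created above is typically filled and read out, assuming mmcv's LogBuffer API (update(), average(), output, clear_output()):

from mmcv.runner import LogBuffer  # assumed import path for the LogBuffer used above

buf = LogBuffer()
buf.update(dict(loss=0.9, acc=0.50), count=8)  # stats from a batch of 8 samples
buf.update(dict(loss=0.7, acc=0.75), count=8)
buf.average()          # weighted average over the whole history
print(buf.output)      # loss -> 0.8, acc -> 0.625
buf.clear_output()     # reset the averaged output, keep the raw history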
Example #2
def train_model(model,
                optimizer,
                train_loader,
                lr_scheduler,
                optim_cfg,
                start_epoch,
                total_epochs,
                start_iter,
                rank,
                logger,
                ckpt_save_dir,
                lr_warmup_scheduler=None,
                ckpt_save_interval=1,
                max_ckpt_save_num=50,
                log_interval=20):
    accumulated_iter = start_iter

    log_buffer = LogBuffer()

    for cur_epoch in range(start_epoch, total_epochs):

        trained_epoch = cur_epoch + 1
        accumulated_iter = train_one_epoch(
            model,
            optimizer,
            train_loader,
            lr_scheduler=lr_scheduler,
            lr_warmup_scheduler=lr_warmup_scheduler,
            accumulated_iter=accumulated_iter,
            train_epoch=trained_epoch,
            optim_cfg=optim_cfg,
            rank=rank,
            logger=logger,
            log_buffer=log_buffer,
            log_interval=log_interval)

        # save trained model
        if trained_epoch % ckpt_save_interval == 0 and rank == 0:

            ckpt_list = glob.glob(
                os.path.join(ckpt_save_dir, 'checkpoint_epoch_*.pth'))
            ckpt_list.sort(key=os.path.getmtime)

            if len(ckpt_list) >= max_ckpt_save_num:
                for cur_file_idx in range(
                        len(ckpt_list) - max_ckpt_save_num + 1):
                    os.remove(ckpt_list[cur_file_idx])

            ckpt_name = os.path.join(ckpt_save_dir,
                                     ('checkpoint_epoch_%d' % trained_epoch))
            save_checkpoint(
                checkpoint_state(model, optimizer, trained_epoch,
                                 accumulated_iter),
                filename=ckpt_name,
            )
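train_one_epoch itself is not shown on this page; the sketch below illustrates the kind of per-iteration logging that the log_buffer, log_interval and rank arguments suggest. The helper name and message format are illustrative, and mmcv's LogBuffer API is assumed.

def maybe_log(log_buffer, logger, rank, accumulated_iter, log_interval):
    # hypothetical helper: only rank 0 logs, once every `log_interval` iterations
    if rank != 0 or accumulated_iter % log_interval != 0:
        return
    log_buffer.average(log_interval)  # mean over the last `log_interval` updates
    stats = ', '.join(f'{k}: {v:.4f}' for k, v in log_buffer.output.items())
    logger.info(f'iter {accumulated_iter}: {stats}')
    log_buffer.clear_output()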
Example #3
    def __init__(self, model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None):
        self.optimizer = []
        if optimizer is not None:
            for opt in optimizer:
                self.optimizer.append(self.init_optimizer(opt))
        self.model = model
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        if logger is None:
            self.logger = self.init_logger(work_dir, log_level)
        else:
            self.logger = logger
        self.log_buffer = LogBuffer()

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
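A hypothetical call of this constructor with two optimizer configs; the class name Runner, the toy model and the batch_processor are placeholders, and init_optimizer() is assumed to accept mmcv-style optimizer config dicts.

import logging
import torch.nn as nn

def batch_processor(model, data, train_mode=True, **kwargs):
    # placeholder batch processor: returns a loss plus log_vars/num_samples
    loss = model(data).mean()
    return dict(loss=loss, log_vars=dict(loss=float(loss)), num_samples=data.size(0))

optimizer_cfgs = [dict(type='SGD', lr=0.01, momentum=0.9),
                  dict(type='Adam', lr=1e-4)]  # one config per optimizer
runner = Runner(nn.Linear(8, 2), batch_processor, optimizer=optimizer_cfgs,
                work_dir='./work_dir', log_level=logging.INFO)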
Example #4
class MyRunner(runner.Runner):

    def __init__(self, model,
                 batch_processor,
                 optimizer=None,
                 work_dir=None,
                 log_level=logging.INFO,
                 logger=None):
        self.optimizer = []
        if optimizer is not None:
            for opt in optimizer:
                self.optimizer.append(self.init_optimizer(opt))
        self.model = model
        self.batch_processor = batch_processor

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        if logger is None:
            self.logger = self.init_logger(work_dir, log_level)
        else:
            self.logger = logger
        self.log_buffer = LogBuffer()

        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list: Current learning rate of all param groups.
        """
        if not self.optimizer:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        lr_list = []
        for opt in self.optimizer:
            lr_list.append([group['lr'] for group in opt.param_groups])
        return lr_list

    def register_logger_hooks(self, log_config):
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = obj_from_dict(
                info, hooks, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_lr_hooks(self, lr_config):
        if isinstance(lr_config, MyLrUpdaterHook):
            self.register_hook(lr_config)
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            # from .hooks import lr_updater
            hook_name = lr_config['policy'].title() + 'LrUpdaterHook'
            if not hasattr(lr_updater, hook_name):
                raise ValueError('"{}" does not exist'.format(hook_name))
            hook_cls = getattr(lr_updater, hook_name)
            self.register_hook(hook_cls(**lr_config))
        else:
            raise TypeError('"lr_config" must be either a LrUpdaterHook object'
                            ' or dict, not {}'.format(type(lr_config)))

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook(s)
        """
        if optimizer_config is None:
            optimizer_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(self.build_hook(optimizer_config, MyOptimizerHook))
        self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        linkpath = osp.join(out_dir, 'latest.pth')
        optimizer = None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # use relative symlink
        mmcv.symlink(filename, linkpath)

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            outputs = self.batch_processor(
                self.model, data_batch, epoch=self.epoch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')

        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('batch_processor() must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an epoch'.
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
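A hypothetical end-to-end use of MyRunner; model, batch_processor, optimizer_cfg, lr_hook (a MyLrUpdaterHook instance), train_loader and val_loader are all assumed to be defined elsewhere, and the logger hook config is illustrative.

runner = MyRunner(model, batch_processor, optimizer=[optimizer_cfg],
                  work_dir='./work_dir')
runner.register_training_hooks(
    lr_config=lr_hook,  # an already constructed MyLrUpdaterHook
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
# alternate one training epoch and one validation epoch until 12 epochs are reached
runner.run([train_loader, val_loader], [('train', 1), ('val', 1)], max_epochs=12)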
Example #5
class SWAHook(Hook):
    r"""SWA Object Detection Hook.

        This hook works together with SWA training config files to train
        SWA object detectors <https://arxiv.org/abs/2012.12645>.

        Args:
            swa_eval (bool): Whether to evaluate the swa model.
                Defaults to True.
            eval_hook (Hook): Hook class that contains evaluation functions.
                Defaults to None.
            swa_interval (int): The epoch interval to perform swa
    """
    def __init__(self, swa_eval=True, eval_hook=None, swa_interval=1):
        if not isinstance(swa_eval, bool):
            raise TypeError('swa_eval must be a bool, but got '
                            f'{type(swa_eval)}')
        if swa_eval:
            if not isinstance(eval_hook, (EvalHook, DistEvalHook)):
                raise TypeError('eval_hook must be either an EvalHook or a '
                                'DistEvalHook when swa_eval = True, but got '
                                f'{type(eval_hook)}')
        self.swa_eval = swa_eval
        self.eval_hook = eval_hook
        self.swa_interval = swa_interval

    def before_run(self, runner):
        """Construct the averaged model which will keep track of the running
        averages of the parameters of the model."""
        model = runner.model
        self.model = AveragedModel(model)

        self.meta = runner.meta
        if self.meta is None:
            self.meta = dict()
        if isinstance(self.meta, dict):
            self.meta.setdefault('hook_msgs', dict())
        self.log_buffer = LogBuffer()

    def after_train_epoch(self, runner):
        """Update the parameters of the averaged model, save and evaluate the
        updated averaged model."""
        model = runner.model
        # Whether to perform swa
        swa_flag = (runner.epoch + 1) % self.swa_interval == 0
        # update the parameters of the averaged model
        if swa_flag:
            self.model.update_parameters(model)

            # save the swa model
            runner.logger.info(
                f'Saving swa model at swa-training {runner.epoch + 1} epoch')
            filename = 'swa_model_{}.pth'.format(runner.epoch + 1)
            filepath = osp.join(runner.work_dir, filename)
            optimizer = runner.optimizer
            self.meta['hook_msgs']['last_ckpt'] = filepath
            save_checkpoint(self.model.module,
                            filepath,
                            optimizer=optimizer,
                            meta=self.meta)

        # evaluate the swa model
        if self.swa_eval and swa_flag:
            self.work_dir = runner.work_dir
            self.rank = runner.rank
            self.epoch = runner.epoch
            self.logger = runner.logger
            self.meta['hook_msgs']['last_ckpt'] = filename
            self.eval_hook.after_train_epoch(self)
            for name, val in self.log_buffer.output.items():
                name = 'swa_' + name
                runner.log_buffer.output[name] = val
            runner.log_buffer.ready = True
            self.log_buffer.clear()

    def after_run(self, runner):
        # since BN layers in the backbone are frozen,
        # we do not need to update the BN for the swa model
        pass

    def before_epoch(self, runner):
        pass
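A sketch of how SWAHook would typically be attached to an existing mmcv-style runner; runner and val_dataloader are assumed to exist, and the choice between EvalHook and DistEvalHook depends on whether training is distributed.

swa_eval_hook = EvalHook(val_dataloader)  # use DistEvalHook for distributed training
runner.register_hook(
    SWAHook(swa_eval=True, eval_hook=swa_eval_hook, swa_interval=1))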
Example #6
    def __init__(
        self,
        model,
        batch_processor=None,
        optimizer=None,
        work_dir=None,
        logger=None,
        meta=None,
        max_iters=None,
        max_epochs=None,
    ):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError("batch_processor must be callable, "
                                f"but got {type(batch_processor)}")
            warnings.warn("batch_processor is deprecated, please implement "
                          "train_step() and val_step() in the model instead.")
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, "train_step") or hasattr(_model, "val_step"):
                raise RuntimeError(
                    "batch_processor and model.train_step()/model.val_step() "
                    "cannot be both available.")
        else:
            assert hasattr(model, "train_step")

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f"optimizer must be a dict of torch.optim.Optimizers, "
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            pass
            # raise TypeError(
            #     f'optimizer must be a torch.optim.Optimizer object '
            #     f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f"logger must be a logging.Logger object, "
                            f"but got {type(logger)}")

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f"meta must be a dict or None, but got {type(meta)}")

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, "module"):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                "Only one of `max_epochs` or `max_iters` can be set.")

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
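A hypothetical call of this constructor showing the dict-of-optimizers form accepted by the type check above; the class name Runner and the stub train_step are placeholders.

import logging
import torch
import torch.nn as nn

gen, disc = nn.Linear(8, 8), nn.Linear(8, 1)
model = nn.Sequential(gen, disc)
model.train_step = lambda data, optimizer, **kw: dict(loss=torch.tensor(0.))  # stub
runner = Runner(
    model,
    optimizer=dict(generator=torch.optim.Adam(gen.parameters(), lr=1e-4),
                   discriminator=torch.optim.SGD(disc.parameters(), lr=1e-3)),
    logger=logging.getLogger('train'),
    work_dir='./work_dir',
    max_epochs=12)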
Example #7
def main():
    args = parse_args()
    cfg = Config.fromfile(args.cfg)
    work_dir = cfg.work_dir
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(device_id) for device_id in cfg.device_ids)
    log_dir = os.path.join(work_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    logger = init_logger(log_dir)
    seed = cfg.seed
    logger.info('Set random seed to {}'.format(seed))
    set_random_seed(seed)

    train_dataset = get_dataset(cfg.data.train)
    train_data_loader = build_dataloader(
        train_dataset,
        cfg.data.imgs_per_gpu,
        cfg.data.workers_per_gpu,
        len(cfg.device_ids),
        dist=False,
    )
    val_dataset = get_dataset(cfg.data.val)
    val_data_loader = build_dataloader(val_dataset,
                                       1,
                                       cfg.data.workers_per_gpu,
                                       1,
                                       dist=False,
                                       shuffle=False)

    model = build_detector(cfg.model,
                           train_cfg=cfg.train_cfg,
                           test_cfg=cfg.test_cfg)
    model = MMDataParallel(model).cuda()
    optimizer = obj_from_dict(cfg.optimizer, torch.optim,
                              dict(params=model.parameters()))
    lr_scheduler = obj_from_dict(cfg.lr_scedule, LRschedule,
                                 dict(optimizer=optimizer))

    checkpoint_dir = os.path.join(cfg.work_dir, 'checkpoint_dir')
    os.makedirs(checkpoint_dir, exist_ok=True)

    start_epoch = cfg.start_epoch
    if cfg.resume_from:
        checkpoint = load_checkpoint(model, cfg.resume_from)
        start_epoch = 0
        logger.info('resumed epoch {}, from {}'.format(start_epoch,
                                                       cfg.resume_from))

    log_buffer = LogBuffer()
    for epoch in range(start_epoch, cfg.end_epoch):
        train(train_data_loader, model, optimizer, epoch, lr_scheduler,
              log_buffer, cfg, logger)
        tmp_checkpoint_file = os.path.join(checkpoint_dir, 'tmp_val.pth')
        meta_dict = cfg._cfg_dict
        logger.info('save tmp checkpoint to {}'.format(tmp_checkpoint_file))
        save_checkpoint(model, tmp_checkpoint_file, optimizer, meta=meta_dict)
        if len(cfg.device_ids) == 1:
            sensitivity = val(val_data_loader, model, cfg, logger, epoch)
        else:
            model_args = cfg.model.copy()
            model_args.update(train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
            model_type = getattr(detectors, model_args.pop('type'))
            results = parallel_test(
                cfg,
                model_type,
                model_args,
                tmp_checkpoint_file,
                val_dataset,
                np.arange(len(cfg.device_ids)).tolist(),
                workers_per_gpu=1,
            )

            sensitivity = evaluate_deep_lesion(results, val_dataset,
                                               cfg.cfg_3dce, logger)
        save_file = os.path.join(
            checkpoint_dir, 'epoch_{}_sens@4FP_{:.5f}_{}.pth'.format(
                epoch + 1, sensitivity,
                time.strftime('%m-%d-%H-%M', time.localtime(time.time()))))
        os.rename(tmp_checkpoint_file, save_file)
        logger.info('save checkpoint to {}'.format(save_file))
        if epoch > cfg.lr_scedule.T_max:
            os.remove(save_file)
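A sketch of the config fields main() reads; every value below is illustrative, and the key names (including the lr_scedule spelling) mirror what the script accesses. The model, train_cfg, test_cfg and cfg_3dce entries are also required but omitted here.

work_dir = './work_dirs/example'
device_ids = [0, 1]
seed = 0
start_epoch = 0
end_epoch = 12
resume_from = None
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
lr_scedule = dict(type='CosineAnnealingLR', T_max=12)  # key spelled as in the script
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(type='CustomDataset', ann_file='train.json'),  # placeholder dataset cfg
    val=dict(type='CustomDataset', ann_file='val.json'))      # placeholder dataset cfg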