    def evaluate(self, runner, results):
        self.eval_conf(runner, results, bins_number=100)

        error_log_buffer = LogBuffer()
        for result in results:
            error_log_buffer.update(result['Error'])
        error_log_buffer.average()

        # export to the runner's log buffer so logger hooks (e.g. tensorboard) pick it up
        for key in error_log_buffer.output.keys():
            runner.log_buffer.output[key] = error_log_buffer.output[key]

        # for better visualization, format into pandas
        format_output_dict = disp_output_evaluation_in_pandas(
            error_log_buffer.output)

        runner.logger.info(
            "Epoch [{}] Evaluation Result: \t".format(runner.epoch + 1))

        log_items = []
        for key, val in format_output_dict.items():
            if isinstance(val, pd.DataFrame):
                log_items.append("\n{}:\n{} \n".format(key, val))
            else:
                if isinstance(val, float):
                    val = "{:.4f}".format(val)
                log_items.append("{}: {}".format(key, val))

        log_str = ", ".join(log_items)
        runner.logger.info(log_str)
        runner.log_buffer.ready = True
        error_log_buffer.clear()
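
All of these evaluate snippets funnel per-batch 'Error' dicts through mmcv's LogBuffer. Below is a minimal sketch of the update/average/clear semantics they rely on; MiniLogBuffer is a hypothetical stand-in, and mmcv's real class additionally computes the weighted averages with numpy while exposing the same `ready` flag consumed by logger hooks:

from collections import OrderedDict

class MiniLogBuffer:
    """Hypothetical sketch of the LogBuffer contract used above."""

    def __init__(self):
        self.val_history = OrderedDict()  # key -> list of recorded values
        self.n_history = OrderedDict()    # key -> list of sample counts
        self.output = OrderedDict()       # key -> averaged value
        self.ready = False                # set once averages are computed

    def update(self, vars, count=1):
        # record one batch worth of scalars, weighted by sample count
        for key, val in vars.items():
            self.val_history.setdefault(key, []).append(val)
            self.n_history.setdefault(key, []).append(count)

    def average(self, n=0):
        # average the latest n records per key (all records when n == 0)
        for key, values in self.val_history.items():
            nums = self.n_history[key]
            if n > 0:
                values, nums = values[-n:], nums[-n:]
            total = sum(v * c for v, c in zip(values, nums))
            self.output[key] = total / max(sum(nums), 1)
        self.ready = True

    def clear(self):
        self.val_history.clear()
        self.n_history.clear()
        self.output.clear()
        self.ready = False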
Example #2
    def __init__(self,
                 model,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):

        self.model = model
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self.model_name = self.model.module.__class__.__name__
        else:
            self.model_name = self.model.__class__.__name__

        self.rank, self.world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self.hooks = []
        self.epoch = 0
        self.iter = 0
        self.inner_iter = 0

        self.max_epochs = max_epochs
        self.max_iters = max_iters

        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
Example #3
    def evaluate(self, runner, results):
        self.eval_conf(runner, results, bins_number=100)

        error_log_buffer = LogBuffer()
        for result in results:
            error_log_buffer.update(result['Error'])
        error_log_buffer.average()
        log_items = []
        for key in error_log_buffer.output.keys():
            runner.log_buffer.output[key] = error_log_buffer.output[key]

            val = error_log_buffer.output[key]
            if isinstance(val, float):
                val = "{:.4f}".format(val)
            log_items.append("{}: {}".format(key, val))

        # runner.epoch starts at 0
        log_str = "Epoch [{}] Evaluation Result: \t".format(runner.epoch + 1)
        log_str += ", ".join(log_items)
        runner.logger.info(log_str)
        runner.log_buffer.ready = True
        error_log_buffer.clear()
Example #4
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 num_stages=4,
                 max_epochs=120):
        # ---- BaseRunner init ----
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            # warnings.warn('batch_processor is deprecated, please implement '
            #               'train_step() and val_step() in the model instead.')  # suppressed
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
        # ---- end BaseRunner init ----

        self.num_stages = num_stages
        assert self._max_epochs % num_stages == 0
        self.epochs_per_stage = self._max_epochs // num_stages
        self.just_resumed = False
        if hasattr(self.model, 'module'):
            self._epoch = self.model.module.start_block * self.epochs_per_stage
        else:
            self._epoch = self.model.start_block * self.epochs_per_stage
Example #5
class MultiStageRunner(EpochBasedRunner):
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 num_stages=4,
                 max_epochs=120):
        # ---- BaseRunner init ----
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            # warnings.warn('batch_processor is deprecated, please implement '
            #               'train_step() and val_step() in the model instead.')  # suppressed
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
        # ---- end BaseRunner init ----

        self.num_stages = num_stages
        assert self._max_epochs % num_stages == 0
        self.epochs_per_stage = self._max_epochs // num_stages
        self.just_resumed = False
        if hasattr(self.model, 'module'):
            self._epoch = self.model.module.start_block * self.epochs_per_stage
        else:
            self._epoch = self.model.start_block * self.epochs_per_stage

    @property
    def epoch(self):
        return self._epoch % self.epochs_per_stage

    @property
    def iter(self):
        return self._iter % self.iters_per_stage

    @property
    def max_epochs(self):
        return self.epochs_per_stage

    @property
    def max_iters(self):
        return self.iters_per_stage

    @property
    def stage(self):
        return int(self._epoch // self.epochs_per_stage + 1)
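    # Worked example of the stage bookkeeping above, assuming the defaults
    # max_epochs=120 and num_stages=4 (so epochs_per_stage == 30):
    #   _epoch == 75  ->  stage == 75 // 30 + 1 == 3,
    #                     stage-relative epoch == 75 % 30 == 15.
    # iter/max_iters are reduced the same way through iters_per_stage.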

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(self.model,
                                           data_batch,
                                           train_mode=train_mode,
                                           **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.log_buffer.update({'Stage': self.stage})
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        if self.just_resumed:
            self.call_hook('after_train_epoch')
            self._epoch += 1
            self.just_resumed = False
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.iters_per_stage = self._max_iters // self.num_stages
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs
            assert self._max_epochs % self.num_stages == 0
            self.epochs_per_stage = self._max_epochs // self.num_stages

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                self.iters_per_stage = self._max_iters // self.num_stages
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self._epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self._epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='stage{}_epoch{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains placeholders for the stage and epoch numbers.
                Defaults to 'stage{}_epoch{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self._epoch + 1, iter=self._iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self._epoch + 1, iter=self._iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.stage, self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)
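        # e.g. with the default template, saving at stage 2 and stage-relative
        # epoch 29 writes 'stage2_epoch30.pth' (epochs are 1-based in the
        # filename) and repoints 'latest.pth' at it.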

    def resume(self, checkpoint, resume_optimizer=True, map_location='cpu'):
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(checkpoint,
                                              map_location=map_location)

        self._epoch = checkpoint['meta']['epoch'] - 1
        self.epochs_per_stage = self._max_epochs // self.num_stages
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')
        if hasattr(self.model, 'module'):
            print('start_block', self.model.module.start_block, 'stage',
                  self.stage)
            self.model.module.start_block = self.stage - 1
        else:
            print('start_block', self.model.start_block, 'stage', self.stage)
            self.model.start_block = self.stage - 1
        self.logger.info('resumed stage %d, epoch %d', self.stage,
                         self.epoch + 1)
        self.just_resumed = True
def main():
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    if args.checkpoint is not None:
        cfg.checkpoint = args.checkpoint
    if args.out_dir is not None:
        cfg.out_dir = args.out_dir
    if args.gpus is not None:
        cfg.gpus = args.gpus
    cfg.show = args.show

    mkdir_or_exist(cfg.out_dir)

    # init the logger before other steps and set up the training logger
    logger = get_root_logger(cfg.out_dir,
                             cfg.log_level,
                             filename="test_log.txt")
    logger.info("Using {} GPUs".format(cfg.gpus))
    logger.info('Distributed training: {}'.format(distributed))

    # log environment info
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info(args)

    logger.info("Running with config:\n{}".format(cfg.text))

    # build the dataset
    test_dataset = build_dataset(cfg, 'test')

    # build the model and load checkpoint
    model = build_model(cfg)
    checkpoint = load_checkpoint(model, cfg.checkpoint, map_location='cpu')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, test_dataset, cfg, args.show)
    else:
        model = MMDistributedDataParallel(model.cuda())
        outputs = multi_gpu_test(model,
                                 test_dataset,
                                 cfg,
                                 args.show,
                                 tmpdir=osp.join(cfg.out_dir, 'temp'))

    rank, _ = get_dist_info()
    if cfg.out_dir is not None and rank == 0:
        result_path = osp.join(cfg.out_dir, 'result.pkl')
        logger.info('\nwriting results to {}'.format(result_path))
        mmcv.dump(outputs, result_path)

        if args.evaluate:
            error_log_buffer = LogBuffer()
            for result in outputs:
                error_log_buffer.update(result['Error'])
            error_log_buffer.average()
            log_items = []
            for key in error_log_buffer.output.keys():
                val = error_log_buffer.output[key]
                if isinstance(val, float):
                    val = '{:.4f}'.format(val)
                log_items.append('{}: {}'.format(key, val))

            if len(error_log_buffer.output) == 0:
                log_items.append('nothing to evaluate!')

            log_str = 'Evaluation Result: \t'
            log_str += ', '.join(log_items)
            logger.info(log_str)
            error_log_buffer.clear()
Example #7
class Trainer:
    def __init__(self,
                 model,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):

        self.model = model
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self.model_name = self.model.module.__class__.__name__
        else:
            self.model_name = self.model.__class__.__name__

        self.rank, self.world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self.hooks = []
        self.epoch = 0
        self.iter = 0
        self.inner_iter = 0

        self.max_epochs = max_epochs
        self.max_iters = max_iters

        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    def run_iter(self, data_batch, train_mode, **kwargs):
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

        if not isinstance(outputs, dict):
            raise TypeError('"model.train_step()" and "model.val_step()" '
                            'must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kws):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self.max_iters = self.max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')

        time.sleep(0.5)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kws)
            self.call_hook('after_train_iter')
            self.iter += 1

        self.call_hook('after_train_epoch')
        self.epoch += 1

    def val(self, data_loader, **kws):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(0.5)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                self.run_iter(data_batch, train_mode=False, **kws)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, **kws):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        assert self.max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self.max_iters = self.max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self.max_epochs)
        self.call_hook('before_run')

        while self.epoch < self.max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self.max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kws)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
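    # e.g. runner.run([train_loader, val_loader], [('train', 2), ('val', 1)])
    # trains two epochs, then validates for one, and repeats the cycle until
    # self.max_epochs training epochs have completed.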

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):

        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)

        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            mmcv.symlink(filename, dst_file)

    def current_lr(self):
        """Get current learning rates.
        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr
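        # e.g. a single SGD optimizer with one param group at lr=0.01 returns
        # [0.01]; a dict such as {'backbone': opt_b, 'head': opt_h} returns
        # {'backbone': [...], 'head': [...]} instead.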

    def current_momentum(self):
        """Get current momentums.
        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """
        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    # SGD-style optimizers expose momentum directly
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    # Adam-style optimizers store it as the first beta
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        if not hook:
            return

        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self.hooks) - 1, -1, -1):
            if priority >= self.hooks[i].priority:
                self.hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self.hooks.insert(0, hook)
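        # e.g. with registered priorities [10, 50, 50] (lower value = higher
        # priority), a new hook with priority 50 is inserted after the last 50,
        # so equal-priority hooks keep their registration order.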

    def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self.hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(checkpoint,
                                              map_location=map_location)

        self.epoch = checkpoint['meta']['epoch']
        self.iter = checkpoint['meta']['iter']

        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def main():
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    if args.checkpoint is not None:
        cfg.checkpoint = args.checkpoint
    if args.out_dir is not None:
        cfg.out_dir = args.out_dir
    if args.gpus is not None:
        cfg.gpus = args.gpus
    cfg.show = args.show == 'True'

    mkdir_or_exist(cfg.out_dir)

    # init the logger before other steps and set up the training logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.out_dir, '{}_test_log.txt'.format(timestamp))
    logger = get_root_logger(cfg.out_dir, cfg.log_level, filename=log_file)
    logger.info("Using {} GPUs".format(cfg.gpus))
    logger.info('Distributed training: {}'.format(distributed))
    logger.info("Whether results will be saved to disk as images: {}".format(
        args.show))

    # log environment info
    logger.info("Collecting env info (might take some time)")
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line)
    logger.info("\n" + collect_env_info())
    logger.info('\n' + dash_line)

    logger.info(args)

    logger.info("Running with config:\n{}".format(cfg.text))

    # build the dataset
    test_dataset = build_dataset(cfg, 'test')

    # build the model and load checkpoint
    model = build_model(cfg)
    checkpoint = load_checkpoint(model, cfg.checkpoint, map_location='cpu')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, test_dataset, cfg, cfg.show)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model,
                                 test_dataset,
                                 cfg,
                                 cfg.show,
                                 tmpdir=osp.join(cfg.out_dir, 'temp'))

    rank, _ = get_dist_info()
    if cfg.out_dir is not None and rank == 0:
        result_path = osp.join(cfg.out_dir, 'result.pkl')
        logger.info('\nwriting results to {}'.format(result_path))
        mmcv.dump(outputs, result_path)

        if args.validate:
            error_log_buffer = LogBuffer()
            for result in outputs:
                error_log_buffer.update(result['Error'])
            error_log_buffer.average()

            task = cfg.get('task', 'stereo')
            # for better visualization, format into pandas
            format_output_dict = output_evaluation_in_pandas(
                error_log_buffer.output, task)

            log_items = []
            for key, val in format_output_dict.items():
                if isinstance(val, pd.DataFrame):
                    log_items.append("\n{}:\n{} \n".format(key, val))
                else:
                    if isinstance(val, float):
                        val = "{:.4f}".format(val)
                    log_items.append("{}: {}".format(key, val))

            if len(error_log_buffer.output) == 0:
                log_items.append('nothing to evaluate!')

            log_str = 'Evaluation Result: \t'
            log_str += ", ".join(log_items)
            logger.info(log_str)
            error_log_buffer.clear()
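
To close the listing, here is a minimal end-to-end sketch of driving the Trainer from Example #7. It assumes mmcv and torch are importable alongside the class above; ToyModel and loader are hypothetical names used only for illustration:

import logging
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # hypothetical model implementing the train_step()/val_step() contract the
    # Trainer expects: return a dict carrying 'log_vars' and 'num_samples'
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.fc(x) - y) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return dict(log_vars={'loss': loss.item()}, num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        x, y = data_batch
        loss = ((self.fc(x) - y) ** 2).mean()
        return dict(log_vars={'val_loss': loss.item()}, num_samples=x.size(0))

model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
logging.basicConfig(level=logging.INFO)
runner = Trainer(model, optimizer=optimizer,
                 logger=logging.getLogger('toy'), max_epochs=2)
# any iterable with a __len__ works as a data loader here
loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]
runner.run([loader], [('train', 1)])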