    def after_train_epoch(self, runner):
        self.update_attr(runner.model)

        # save ema model
        if not self.every_n_epochs(runner, self.interval):
            return
        if not self.out_dir:
            self.out_dir = runner.work_dir

        meta = runner.meta

        if meta is None:
            meta = dict(epoch=runner.epoch + 1, iter=runner.iter)
        else:
            meta.update(epoch=runner.epoch + 1, iter=runner.iter)

        filename = 'epoch_ema_{}.pth'.format(runner.epoch + 1)
        filepath = osp.join(self.out_dir, filename)
        optimizer = runner.optimizer if self.save_optimizer else None
        save_checkpoint(self.ema, filepath, optimizer=optimizer, meta=meta)
        if self.create_symlink:
            mmcv.symlink(filename, osp.join(self.out_dir, 'latest_ema.pth'))

        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            filename_tmpl = self.args.get('filename_tmpl', 'epoch_ema_{}.pth')
            current_epoch = runner.epoch + 1
            for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
                ckpt_path = osp.join(self.out_dir,
                                     filename_tmpl.format(epoch))
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                else:
                    break
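
The `save_checkpoint` these hooks wrap can also be called directly. A minimal
sketch (mmcv must be installed; the model and path here are illustrative only):

    import torch.nn as nn
    from mmcv.runner import save_checkpoint

    model = nn.Conv2d(3, 3, 1)
    # writes a dict with 'meta' and 'state_dict' keys (plus 'optimizer'
    # when an optimizer is passed) to the given path
    save_checkpoint(model, '/tmp/demo.pth', meta=dict(epoch=1, iter=100))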

    def after_train_epoch(self, runner):
        """Update the parameters of the averaged model, save and evaluate the
        updated averaged model."""
        model = runner.model
        # update the parameters of the averaged model
        self.model.update_parameters(model)

        # save the swa model
        runner.logger.info(
            f'Saving swa model at swa-training epoch {runner.epoch + 1}')
        filename = 'swa_model_{}.pth'.format(runner.epoch + 1)
        filepath = osp.join(runner.work_dir, filename)
        optimizer = runner.optimizer
        self.meta['hook_msgs']['last_ckpt'] = filepath
        save_checkpoint(self.model.module,
                        filepath,
                        optimizer=optimizer,
                        meta=self.meta)

        # evaluate the swa model
        if self.swa_eval:
            self.work_dir = runner.work_dir
            self.rank = runner.rank
            self.epoch = runner.epoch
            self.logger = runner.logger
            self.meta['hook_msgs']['last_ckpt'] = filename
            self.eval_hook.after_train_epoch(self)
            for name, val in self.log_buffer.output.items():
                name = 'swa_' + name
                runner.log_buffer.output[name] = val
            runner.log_buffer.ready = True
            self.log_buffer.clear()
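
The `self.model` used above matches the interface of
`torch.optim.swa_utils.AveragedModel` (note the `.module` access when saving).
A sketch of how such an averaged model is constructed and updated:

    import torch
    from torch.optim.swa_utils import AveragedModel

    net = torch.nn.Linear(4, 2)
    swa_model = AveragedModel(net)    # keeps a running average of net's weights
    swa_model.update_parameters(net)  # call once per epoch (or per step)
    # the averaged network itself is exposed as swa_model.module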
Example #3
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):

        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        linkpath = osp.join(out_dir, 'latest.pth')
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # use relative symlink
        mmcv.symlink(filename, linkpath)

        filename_tmpl = 'adv_' + filename_tmpl
        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        linkpath = osp.join(out_dir, 'adv_latest.pth')
        optimizer = self.adv_optimizer if save_optimizer else None
        save_checkpoint(self.adv_model,
                        filepath,
                        optimizer=optimizer,
                        meta=meta)
        # use relative symlink
        mmcv.symlink(filename, linkpath)

    def after_run(self, runner):
        runner.logger.info('Saving ema parameters')

        # swap in the EMA parameters, save, then swap the originals back
        self._swap_ema_parameters()
        filepath = os.path.join(runner.work_dir, 'ema.pth')
        save_checkpoint(runner.model, filepath, optimizer=None,
                        meta=runner.meta)
        self._swap_ema_parameters()
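
`_swap_ema_parameters` is not shown in this example. A common implementation
(a sketch; the `ema_state` dict of shadow weights is an assumption, not part
of the source) swaps each live parameter with its EMA copy in place, so
calling it a second time restores the original weights:

    def _swap_ema_parameters(self):
        # exchange every parameter tensor with its EMA shadow copy
        for name, param in self.model.named_parameters():
            tmp = param.data.clone()
            param.data.copy_(self.ema_state[name])
            self.ema_state[name].copy_(tmp)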
Example #5
    def _save_ckpt(self, model, filepath, meta, runner):
        save_checkpoint(model, filepath, runner.optimizer, meta)

        # verify that the checkpoint can be read back; retry the save if not
        for _ in range(20):
            try:
                torch.load(filepath, map_location='cpu')
                runner.logger.info(
                    f'Successfully saved swa model at swa-training epoch '
                    f'{runner.epoch + 1}')
                break
            except Exception as e:
                runner.logger.warning(f'Checkpoint unreadable ({e}); resaving')
                save_checkpoint(model, filepath, runner.optimizer, meta)
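
The reload-and-retry loop above guards against partially written files on
unreliable storage. An alternative pattern (not from the source) writes to a
temporary file first and renames it atomically:

    import os

    def atomic_save_checkpoint(model, filepath, optimizer, meta):
        tmp_path = filepath + '.tmp'
        save_checkpoint(model, tmp_path, optimizer=optimizer, meta=meta)
        os.replace(tmp_path, filepath)  # atomic rename on the same filesystem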
Example #6
def test_load_from_local():
    import os
    home_path = os.path.expanduser('~')
    checkpoint_path = os.path.join(
        home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = load_from_local(
        '~/dummy_checkpoint_used_to_test_load_from_local.pth',
        map_location=None)
    assert_tensor_equal(checkpoint['state_dict']['block.conv.weight'],
                        model.block.conv.weight)
    os.remove(checkpoint_path)
Example #7
def test_load_classes_name():
    import os

    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = load_checkpoint(model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    model.CLASSES = ('class1', 'class2')
    save_checkpoint(model, checkpoint_path)
    checkpoint = load_checkpoint(model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
    assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')

    model = Model()
    wrapped_model = DDPWrapper(model)
    save_checkpoint(wrapped_model, checkpoint_path)
    checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    wrapped_model.module.CLASSES = ('class1', 'class2')
    save_checkpoint(wrapped_model, checkpoint_path)
    checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
    assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')

    # remove the temp file
    os.remove(checkpoint_path)

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
Example #9
    def save_checkpoint(
        self,
        out_dir,
        filename_tmpl="epoch_{}.pth",
        save_optimizer=True,
        meta=None,
        create_symlink=True,
    ):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f"meta should be a dict or None, but got {type(meta)}")
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        self.model.save_checkpoint(out_dir, tag="ds", client_state=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, "latest.pth")
            if platform.system() != "Windows":
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)
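
The second save call appears to be DeepSpeed's engine method,
`save_checkpoint(save_dir, tag=None, client_state={})`, which writes the
sharded engine state under `save_dir/<tag>/` alongside the plain `.pth` file
that mmcv produces.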
Example #10
    def after_train_epoch(self, runner):
        """Called after each training epoch to evaluate the model."""
        if not self.every_n_epochs(runner, self.interval):
            return

        start_epoch = self.eval_kwargs.get('start_epoch', -1)
        if start_epoch != -1 and (runner.epoch + 1) < start_epoch:
            return

        current_ckpt_path = osp.join(runner.work_dir,
                                     f'epoch_{runner.epoch + 1}.pth')
        json_path = osp.join(runner.work_dir, 'best.json')

        if osp.exists(json_path) and len(self.best_json) == 0:
            self.best_json = mmcv.load(json_path)
            self.best_score = self.best_json['best_score']
            self.best_ckpt = self.best_json['best_ckpt']
            self.key_indicator = self.best_json['key_indicator']

        from mmpose.apis import multi_gpu_test
        results = multi_gpu_test(runner.model,
                                 self.dataloader,
                                 tmpdir=osp.join(runner.work_dir,
                                                 '.eval_hook'),
                                 gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
            key_score = self.evaluate(runner, results)
            if (self.save_best
                    and self.compare_func(key_score, self.best_score)):
                self.best_score = key_score
                self.logger.info(
                    f'Now best checkpoint is epoch_{runner.epoch + 1}.pth')
                save_checkpoint(runner.model,
                                osp.join(runner.work_dir, 'model_best.pth'))
                self.logger.info(
                    f'Saved best model at epoch {runner.epoch + 1} !')
                self.best_json['best_score'] = self.best_score
                self.best_json['best_ckpt'] = current_ckpt_path
                self.best_json['key_indicator'] = self.key_indicator
                mmcv.dump(self.best_json, json_path)
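
`compare_func` is configured elsewhere in this hook; a plausible definition
(an illustration only, not shown in the source) maps a rule name to a
comparison, so accuracy-like metrics use `>` and loss-like metrics use `<`:

    rule_map = {
        'greater': lambda new, best: new > best,  # e.g. mAP, accuracy
        'less': lambda new, best: new < best,     # e.g. loss
    }
    self.compare_func = rule_map[rule]  # `rule` chosen per key_indicator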
Example #11
    def print_model(self, runner, print_flops_acts=True, print_channel=True):
        """Print the related information of the current model.

        Args:
            runner (Runner): Runner in mmcv
            print_flops_acts (bool): Print the remaining percentage of
                FLOPs and activations (acts).
            print_channel (bool): Print information about
                the number of reserved channels.
        """

        if print_flops_acts:
            flops, acts = self.compute_flops_acts()
            runner.logger.info('Flops: {:.2f}%, Acts: {:.2f}%'.format(
                flops * 100, acts * 100))
            if self.save_flops_thr:
                flops_thr = self.save_flops_thr[0]
                if flops < flops_thr:
                    self.save_flops_thr.pop(0)
                    path = osp.join(
                        runner.work_dir, 'flops_{:.0f}_acts_{:.0f}.pth'.format(
                            flops * 100, acts * 100))
                    save_checkpoint(runner.model, filename=path)
            if self.save_acts_thr:
                acts_thr = self.save_acts_thr[0]
                if acts < acts_thr:
                    self.save_acts_thr.pop(0)
                    path = osp.join(
                        runner.work_dir, 'acts_{:.0f}_flops_{:.0f}.pth'.format(
                            acts * 100, flops * 100))
                    save_checkpoint(runner.model, filename=path)
        if print_channel:
            for module, name in self.conv_names.items():
                chans_i = int(module.in_mask.sum().cpu().numpy())
                chans_o = int(module.out_mask.sum().cpu().numpy())
                runner.logger.info(
                    '{}: input_channels: {}/{}, out_channels: {}/{}'.format(
                        name, chans_i, module.in_channels, chans_o,
                        module.out_channels))
Example #12
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        offset=1,
                        iter_no_offset=False):
        # avoid a mutable default argument
        if meta is None:
            meta = dict()
        global_step = get_global_step()
        # when saving after the epoch has finished, the iteration counter
        # has already been incremented, so apply the same offset to it
        iter_offset = 0 if iter_no_offset else offset
        meta.update(epoch=self.epoch + offset,
                    iter=self.iter + iter_offset,
                    inner_iter=self.inner_iter + iter_offset,
                    global_step=global_step,
                    batchsize=self.batchsize,
                    initial_lr=self.initial_lr)

        filename = osp.join(out_dir, filename_tmpl.format(self.epoch + offset))
        local_filename = "./{}".format(
            filename_tmpl.format(self.epoch + offset))
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
        mmcv.symlink(local_filename, osp.join(out_dir, 'latest.pth'))
Example #13
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None):
        """remove symlink to avoid error in windows file system"""
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        linkpath = osp.join(out_dir, 'latest.pth')
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        if self.ema_model is not None:
            save_checkpoint(self.ema_model, f'{filepath}-ema.pth')
        # use relative symlink
        try:
            mmcv.symlink(filename, linkpath)
        except OSError:
            print('Failed to symlink from {} to {}.'.format(
                filename, linkpath))
Example #14
def test_checkpoint_loader():
    import os

    import tempfile

    from mmcv.runner import CheckpointLoader, _load_checkpoint, save_checkpoint
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = _load_checkpoint(checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
    # remove the temp file
    os.remove(checkpoint_path)

    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth',
        'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth',
        'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local', 'load_from_ceph',
        'load_from_ceph', 'load_from_local', 'load_from_local'
    ]

    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # test print function name
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
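
The final assertion holds because registered prefixes are matched
longest-first, so `load_from_abc` (prefix `a/b/c`) takes precedence over
`load_from_ab` (prefix `a/b`) for the path `a/b/c/d`.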
Example #15
def test_load_checkpoint_metadata():
    import os

    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    class ModelV1(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv1 = nn.Conv2d(3, 3, 1)
            self.conv2 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv1.weight)
            nn.init.normal_(self.conv2.weight)

    class ModelV2(nn.Module):
        _version = 2

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv0 = nn.Conv2d(3, 3, 1)
            self.conv1 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv0.weight)
            nn.init.normal_(self.conv1.weight)

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  *args, **kwargs):
            """load checkpoints."""

            # The names of some parameters have been changed in version 2.
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                state_dict_keys = list(state_dict.keys())
                convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
                for k in state_dict_keys:
                    for ori_str, new_str in convert_map.items():
                        if k.startswith(prefix + ori_str):
                            new_key = k.replace(ori_str, new_str)
                            state_dict[new_key] = state_dict[k]
                            del state_dict[k]

            super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                          *args, **kwargs)

    model_v1 = ModelV1()
    model_v1_conv0_weight = model_v1.conv1.weight.detach()
    model_v1_conv1_weight = model_v1.conv2.weight.detach()
    model_v2 = ModelV2()
    model_v2_conv0_weight = model_v2.conv0.weight.detach()
    model_v2_conv1_weight = model_v2.conv1.weight.detach()
    ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth')
    ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth')

    # Save checkpoint
    save_checkpoint(model_v1, ckpt_v1_path)
    save_checkpoint(model_v2, ckpt_v2_path)

    # test load v1 model
    load_checkpoint(model_v2, ckpt_v1_path)
    assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)

    # test load v2 model
    load_checkpoint(model_v2, ckpt_v2_path)
    assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
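
The conversion above works because `nn.Module.state_dict()` records each
submodule's `_version` in the state dict's `_metadata` attribute, and PyTorch
hands that entry back as `local_metadata` during loading. A quick check:

    sd = ModelV2().state_dict()
    # the top-level module's metadata is stored under the empty prefix ''
    assert sd._metadata['']['version'] == 2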
Example #16
                ['{:.4f}'.format(iou * 100) for iou in IU_array])
            logging.info(val_info)
            val_log.write(val_info)

            if mean_IU > max_mIU:
                max_mIU, max_mIU_array, max_mIU_step = mean_IU, IU_array, step
                if step >= 35000:
                    filename = 'CS_scenes_step-{:d}_maxIU-{:.4f}.pth'.format(
                        step, mean_IU)
                    filepath = os.path.join(args.snapshot_dir, filename)
                    torch.save(model.student.state_dict(), filepath)
                    filename = 'mmseg_step-{:d}_maxIU-{:.4f}.pth'.format(
                        step, mean_IU)
                    filepath = os.path.join(args.snapshot_dir, filename)
                    save_checkpoint(model.student,
                                    filepath,
                                    optimizer=None,
                                    meta=None)

            if step % 10000 == 0 or step in [100, 200, 300, 1000]:
                filename = 'CS_scenes_step-{:d}_mIU-{:.4f}.pth'.format(
                    step, mean_IU)
                filepath = os.path.join(args.snapshot_dir, filename)
                torch.save(model.student.state_dict(), filepath)
                filename = 'mmseg_step-{:d}_mIU-{:.4f}.pth'.format(
                    step, mean_IU)
                filepath = os.path.join(args.snapshot_dir, filename)
                save_checkpoint(model.student,
                                filepath,
                                optimizer=None,
                                meta=None)
Example #17
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:  # merge cfg_options into cfg
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(
            args.gpus)  # default to a single GPU
    # if cfg.pdb_debug:
    #     from tools.pdb_install_handle import install_pdb_handler
    #     install_pdb_handler()
    #     print('CTRL+\\ breaks into pdb.')
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir,
                      osp.basename(args.config)))  # dump the merged config
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    model = build_detector(cfg.model,
                           train_cfg=cfg.train_cfg,
                           test_cfg=cfg.test_cfg)
    # model = model.to(memory_format=torch.channels_last)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmdet_version=__version__ +
                                          get_git_hash()[:7],
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    save_checkpoint(
        model=model,
        filename="./work_dirs/paa_r50fulconv_fpn2x_coco/paa_r50_lite.pth")
    train_detector(model,
                   datasets,
                   cfg,
                   distributed=distributed,
                   validate=(not args.no_validate),
                   timestamp=timestamp,
                   meta=meta)
Example #18
def test_save_checkpoint(tmp_path):
    model = Model()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # meta is not a dict
    with pytest.raises(TypeError):
        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')

    # 1. save to disk
    filename = str(tmp_path / 'checkpoint1.pth')
    save_checkpoint(model, filename)

    filename = str(tmp_path / 'checkpoint2.pth')
    save_checkpoint(model, filename, optimizer)

    filename = str(tmp_path / 'checkpoint3.pth')
    save_checkpoint(model, filename, meta={'test': 'test'})

    filename = str(tmp_path / 'checkpoint4.pth')
    save_checkpoint(model, filename, file_client_args={'backend': 'disk'})

    # 2. save to petrel oss
    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path/of/your/checkpoint1.pth'
        save_checkpoint(model, filename)
    mock_method.assert_called()

    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path//of/your/checkpoint2.pth'
        save_checkpoint(model,
                        filename,
                        file_client_args={'backend': 'petrel'})
    mock_method.assert_called()