Exemple #1
0
    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        The parsed COCO object is stored on ``self.coco``, and
        ``self.cat_ids``, ``self.cat2label`` and ``self.img_ids`` are
        populated from it. When ``self.select_first_k`` is positive, at most
        that many image records are returned; ``self.img_ids`` itself is
        never truncated.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api. Each dict additionally
                carries a ``filename`` key copied from ``file_name``.

        Raises:
            Exception: If ``self.ann_file_backend`` is not ``'disk'`` and the
                installed mmcv is older than 1.3.16.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            # Non-disk backends are read through a local path supplied by the
            # file client's context manager.
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()

        # A non-positive value means "no cap".
        limit = self.select_first_k
        data_infos = []
        for img_id in self.img_ids:
            info = self.coco.load_imgs([img_id])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            # Checked after the append, so the comparison must be `>=` for the
            # result to hold exactly `limit` records rather than `limit + 1`.
            if 0 < limit <= len(data_infos):
                break
        return data_infos
Exemple #2
0
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build loaders, runner and hooks from ``cfg``, then run training.

    Args:
        model: Model to train. It is wrapped in ``MMDistributedDataParallel``
            when ``distributed`` is true, otherwise in ``MMDataParallel``.
        dataset: A single dataset or a list/tuple of datasets; one data
            loader is built per dataset.
        cfg: Config object, read through attribute access, ``in`` and
            ``get``. A default ``runner`` entry is written into it when
            missing.
        distributed (bool): Selects the distributed model wrapper, the
            distributed eval hook and, for an ``EpochBasedRunner``, a
            ``DistSamplerSeedHook``.
        validate (bool): When true, a validation loader is built from
            ``cfg.data.val`` and an eval hook is registered.
        timestamp: Value assigned to ``runner.timestamp``.
        meta: Forwarded to the runner as its ``meta`` default argument.
    """
    logger = get_root_logger(cfg.log_level)

    # One loader is built per dataset, so always work with a sequence.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Shared loader settings, lowest priority first: built-in defaults,
    # parrots-specific defaults, then whatever cfg.data overrides.
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    if torch.__version__ != 'parrots':
        parrots_defaults = {}
    else:
        parrots_defaults = dict(prefetch_num=2, pin_memory=False)
    loader_cfg.update(parrots_defaults)
    overridable_keys = (
        'samples_per_gpu',
        'workers_per_gpu',
        'shuffle',
        'seed',
        'drop_last',
        'prefetch_num',
        'pin_memory',
        'persistent_workers',
    )
    for key in overridable_keys:
        if key in cfg.data:
            loader_cfg[key] = cfg.data[key]

    # cfg.data.train_dataloader has the highest priority for training.
    train_overrides = cfg.data.get('train_dataloader', {})
    train_loader_cfg = dict(loader_cfg, **train_overrides)
    data_loaders = [
        build_dataloader(single_ds, **train_loader_cfg)
        for single_ds in dataset
    ]

    # Wrap the model for the selected execution mode.
    if not distributed:
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
    else:
        # Forwarded as the `find_unused_parameters` argument of the
        # distributed wrapper.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    optimizer = build_optimizer(model, cfg.optimizer)

    # Configs without a `runner` section fall back to an epoch-based runner
    # driven by `total_epochs`.
    if 'runner' not in cfg:
        cfg.runner = dict(
            type='EpochBasedRunner', max_epochs=cfg.total_epochs)
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    elif 'total_epochs' in cfg:
        assert cfg.total_epochs == cfg.runner.max_epochs

    runner_defaults = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    # Workaround so the .log and .log.json filenames are the same.
    runner.timestamp = timestamp

    # Choose the optimizer hook: fp16 hook, plain hook, or the raw config.
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None:
        if distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config
    else:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)

    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    if validate:
        val_overrides = cfg.data.get('val_dataloader', {})
        default_batch = cfg.data.get('samples_per_gpu', 1)
        val_samples_per_gpu = val_overrides.get('samples_per_gpu',
                                                default_batch)
        if val_samples_per_gpu > 1:
            # Support batch_size > 1 in test for text recognition by
            # disabling MultiRotateAugOCR, which is useless in most cases.
            cfg = disable_text_recog_aug_test(cfg)
            cfg = replace_image_to_tensor(cfg)

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        # `cfg` may have been rebound just above, so the val_dataloader
        # overrides are read from it again instead of reusing the earlier
        # lookup.
        val_loader_cfg = {
            **loader_cfg,
            'shuffle': False,
            'drop_last': False,
            **cfg.data.get('val_dataloader', {}),
            'samples_per_gpu': val_samples_per_gpu,
        }
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        if distributed:
            eval_hook = DistEvalHook
        else:
            eval_hook = EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # Resuming takes precedence over loading a checkpoint.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
def test_digit_version():
    """Check ``digit_version`` parsing, equality and ordering of versions."""
    # Plain and release-candidate strings map to fixed 6-tuples.
    parsed_cases = (
        ('0.2.16', (0, 2, 16, 0, 0, 0)),
        ('1.2.3', (1, 2, 3, 0, 0, 0)),
        ('1.2.3rc0', (1, 2, 3, 0, -1, 0)),
        ('1.2.3rc1', (1, 2, 3, 0, -1, 1)),
        ('1.0rc0', (1, 0, 0, 0, -1, 0)),
    )
    for text, expected in parsed_cases:
        assert digit_version(text) == expected

    # Different spellings that must compare equal.
    equal_pairs = (
        ('1.0', '1.0.0'),
        ('1.5.0+cuda90_cudnn7.6.3_lms', '1.5'),
    )
    for left, right in equal_pairs:
        assert digit_version(left) == digit_version(right)

    # Each pair is (older, newer) and must compare strictly ascending.
    ascending_pairs = (
        ('1.0.0dev', '1.0.0a'),
        ('1.0.0a', '1.0.0a1'),
        ('1.0.0a', '1.0.0b'),
        ('1.0.0b', '1.0.0rc'),
        ('1.0.0rc1', '1.0.0'),
        ('1.0.0', '1.0.0post'),
        ('1.0.0post', '1.0.0post1'),
    )
    for older, newer in ascending_pairs:
        assert digit_version(older) < digit_version(newer)

    # A leading "v" is accepted.
    prefixed_cases = (
        ('v1', (1, 0, 0, 0, 0, 0)),
        ('v1.1.5', (1, 1, 5, 0, 0, 0)),
    )
    for text, expected in prefixed_cases:
        assert digit_version(text) == expected