Ejemplo n.º 1
0
    def _check_head(self, runner, dataset):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
            dataset (obj: `BaseDataset`): the dataset to check.
        """
        model = runner.model
        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset.__class__.__name__} and'
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            assert is_seq_of(dataset.CLASSES, str), \
                (f'`CLASSES` in {dataset.__class__.__name__}'
                 f'should be a tuple of str.')
            for name, module in model.named_modules():
                if hasattr(module, 'num_classes'):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not matches '
                         f'the length of `CLASSES` '
                         f'{len(dataset.CLASSES)}) in '
                         f'{dataset.__class__.__name__}')
Ejemplo n.º 2
0
def _concat_dataset(cfg, default_args=None):
    """Build one dataset per entry of ``cfg['ann_file']`` and concatenate them.

    For every index ``i`` of ``cfg['ann_file']`` a deep copy of ``cfg`` is
    made whose ``ann_file`` is the ``i``-th annotation file. Any of ``type``,
    ``img_prefix``, ``dataset_info`` and ``data_cfg['num_joints']`` that is a
    list or tuple is replaced by its ``i``-th element, and
    ``data_cfg['dataset_channel']`` is replaced by its ``i``-th element when
    it is a sequence of lists. Values that are not sequences are shared
    unchanged by every sub-dataset.

    Args:
        cfg (dict): Dataset config; must contain ``type``, ``ann_file`` and
            ``data_cfg``.
        default_args (dict | None): Forwarded unchanged to ``build_dataset``.

    Returns:
        ConcatDataset: Wrapper around the individually built datasets.
    """
    types = cfg['type']
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    dataset_infos = cfg.get('dataset_info', None)

    data_cfg = cfg['data_cfg']
    num_joints = data_cfg.get('num_joints', None)
    dataset_channel = data_cfg.get('dataset_channel', None)

    # Top-level keys that may hold one value per sub-dataset.
    per_dataset_fields = (
        ('type', types),
        ('img_prefix', img_prefixes),
        ('dataset_info', dataset_infos),
    )

    def _config_for(idx):
        """Return a private copy of ``cfg`` specialised for dataset ``idx``."""
        sub_cfg = copy.deepcopy(cfg)
        sub_cfg['ann_file'] = ann_files[idx]

        for key, values in per_dataset_fields:
            if isinstance(values, (list, tuple)):
                sub_cfg[key] = values[idx]

        if isinstance(num_joints, (list, tuple)):
            sub_cfg['data_cfg']['num_joints'] = num_joints[idx]

        # Only a list-of-lists is treated as "one entry per dataset".
        if is_seq_of(dataset_channel, list):
            sub_cfg['data_cfg']['dataset_channel'] = dataset_channel[idx]

        return sub_cfg

    return ConcatDataset([
        build_dataset(_config_for(idx), default_args)
        for idx in range(len(ann_files))
    ])
Ejemplo n.º 3
0
    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=None,
                 less_keys=None,
                 **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError(f'dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')

        if interval <= 0:
            raise ValueError(f'interval must be a positive number, '
                             f'but got {interval}')

        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'

        if start is not None and start < 0:
            raise ValueError(f'The evaluation start epoch {start} is smaller '
                             f'than 0')

        self.dataloader = dataloader
        self.interval = interval
        self.start = start
        self.by_epoch = by_epoch

        assert isinstance(save_best, str) or save_best is None, \
            '""save_best"" should be a str or None ' \
            f'rather than {type(save_best)}'
        self.save_best = save_best
        self.eval_kwargs = eval_kwargs
        self.initial_flag = True

        if test_fn is None:
            from mmcv.engine import single_gpu_test
            self.test_fn = single_gpu_test
        else:
            self.test_fn = test_fn

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys

        if self.save_best is not None:
            self.best_ckpt_path = None
            self._init_rule(rule, self.save_best)