def _check_head(self, runner, dataset):
    """Check whether the ``num_classes`` in head matches the length of
    ``CLASSES`` in ``dataset``.

    Args:
        runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
        dataset (obj: `BaseDataset`): the dataset to check.
    """
    model = runner.model
    if dataset.CLASSES is None:
        # Cannot validate without CLASSES; warn instead of failing.
        runner.logger.warning(
            f'Please set `CLASSES` '
            f'in the {dataset.__class__.__name__} and '
            f'check if it is consistent with the `num_classes` '
            f'of head')
    else:
        assert is_seq_of(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset.__class__.__name__} '
             f'should be a tuple of str.')
        # Walk every submodule: any module exposing `num_classes`
        # (typically the head) must agree with the dataset's class count.
        for _, module in model.named_modules():
            if hasattr(module, 'num_classes'):
                assert module.num_classes == len(dataset.CLASSES), \
                    (f'The `num_classes` ({module.num_classes}) in '
                     f'{module.__class__.__name__} of '
                     f'{model.__class__.__name__} does not match '
                     f'the length of `CLASSES` '
                     f'({len(dataset.CLASSES)}) in '
                     f'{dataset.__class__.__name__}')
def _concat_dataset(cfg, default_args=None):
    """Build a :class:`ConcatDataset` from a config whose fields hold
    per-dataset values.

    Each entry of ``cfg['ann_file']`` yields one sub-dataset; list-valued
    fields (`type`, `img_prefix`, `dataset_info`, `num_joints`,
    `dataset_channel`) are indexed per sub-dataset, scalar values are
    shared by all of them.

    Args:
        cfg (dict): dataset config with list-valued fields.
        default_args (dict | None): default args forwarded to
            ``build_dataset``.

    Returns:
        ConcatDataset: the concatenation of all built sub-datasets.
    """
    type_field = cfg['type']
    ann_file_list = cfg['ann_file']
    img_prefix_field = cfg.get('img_prefix', None)
    dataset_info_field = cfg.get('dataset_info', None)
    num_joints_field = cfg['data_cfg'].get('num_joints', None)
    channel_field = cfg['data_cfg'].get('dataset_channel', None)

    sub_datasets = []
    for idx, ann_file in enumerate(ann_file_list):
        # Deep-copy so each sub-dataset config can be specialized
        # without mutating the shared template.
        sub_cfg = copy.deepcopy(cfg)
        sub_cfg['ann_file'] = ann_file
        if isinstance(type_field, (list, tuple)):
            sub_cfg['type'] = type_field[idx]
        if isinstance(img_prefix_field, (list, tuple)):
            sub_cfg['img_prefix'] = img_prefix_field[idx]
        if isinstance(dataset_info_field, (list, tuple)):
            sub_cfg['dataset_info'] = dataset_info_field[idx]
        if isinstance(num_joints_field, (list, tuple)):
            sub_cfg['data_cfg']['num_joints'] = num_joints_field[idx]
        if is_seq_of(channel_field, list):
            sub_cfg['data_cfg']['dataset_channel'] = channel_field[idx]
        sub_datasets.append(build_dataset(sub_cfg, default_args))

    return ConcatDataset(sub_datasets)
def __init__(self,
             dataloader,
             start=None,
             interval=1,
             by_epoch=True,
             save_best=None,
             rule=None,
             test_fn=None,
             greater_keys=None,
             less_keys=None,
             **eval_kwargs):
    """Initialize the evaluation hook.

    Args:
        dataloader (DataLoader): pytorch DataLoader used for evaluation.
        start (int | None): epoch/iteration to begin evaluating at.
            Must be non-negative when given.
        interval (int): evaluation period; must be positive.
        by_epoch (bool): whether ``interval`` counts epochs (True) or
            iterations (False).
        save_best (str | None): metric key used to save the best
            checkpoint, or None to disable.
        rule (str | None): comparison rule forwarded to ``_init_rule``.
        test_fn (callable | None): function used to run the evaluation;
            defaults to ``mmcv.engine.single_gpu_test``.
        greater_keys (str | list[str] | tuple[str] | None): metric keys
            where greater is better; falls back to the class default.
        less_keys (str | list[str] | tuple[str] | None): metric keys
            where less is better; falls back to the class default.
        **eval_kwargs: extra kwargs forwarded to the dataset's
            ``evaluate`` method.

    Raises:
        TypeError: if ``dataloader`` is not a pytorch DataLoader.
        ValueError: if ``interval`` is not positive or ``start`` is
            negative.
    """
    if not isinstance(dataloader, DataLoader):
        raise TypeError(f'dataloader must be a pytorch DataLoader, '
                        f'but got {type(dataloader)}')

    if interval <= 0:
        raise ValueError(f'interval must be a positive number, '
                         f'but got {interval}')

    assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'

    if start is not None and start < 0:
        raise ValueError(f'The evaluation start epoch {start} is smaller '
                         f'than 0')

    self.dataloader = dataloader
    self.interval = interval
    self.start = start
    self.by_epoch = by_epoch

    # Message fixed to use the file's double-backtick convention
    # (was a garbled '""save_best""').
    assert isinstance(save_best, str) or save_best is None, \
        '``save_best`` should be a str or None ' \
        f'rather than {type(save_best)}'
    self.save_best = save_best
    self.eval_kwargs = eval_kwargs
    # True until the first evaluation has been triggered.
    self.initial_flag = True

    if test_fn is None:
        # Imported lazily to avoid a hard dependency at module load.
        from mmcv.engine import single_gpu_test
        self.test_fn = single_gpu_test
    else:
        self.test_fn = test_fn

    if greater_keys is None:
        self.greater_keys = self._default_greater_keys
    else:
        if not isinstance(greater_keys, (list, tuple)):
            # Normalize a single key to a 1-tuple.
            greater_keys = (greater_keys, )
        assert is_seq_of(greater_keys, str)
        self.greater_keys = greater_keys

    if less_keys is None:
        self.less_keys = self._default_less_keys
    else:
        if not isinstance(less_keys, (list, tuple)):
            # Normalize a single key to a 1-tuple.
            less_keys = (less_keys, )
        assert is_seq_of(less_keys, str)
        self.less_keys = less_keys

    if self.save_best is not None:
        self.best_ckpt_path = None
        self._init_rule(rule, self.save_best)