Example #1
    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        # pad each group so that every replica receives a whole number of
        # `samples_per_gpu`-sized batches from that group
        for size in self.group_sizes:
            self.num_samples += int(
                math.ceil(size * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas
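
To make the padding arithmetic above concrete, here is a small self-contained sketch (the flag values and the samples_per_gpu/num_replicas settings are made up) that reproduces how num_samples and total_size are derived from the per-group counts:

import math

import numpy as np

# made-up flag array: two aspect-ratio groups of sizes 7 and 3
flag = np.array([0, 0, 0, 1, 1, 0, 0, 1, 0, 0])
group_sizes = np.bincount(flag)               # [7, 3]
samples_per_gpu, num_replicas = 2, 2

num_samples = sum(
    int(math.ceil(size * 1.0 / samples_per_gpu / num_replicas)) *
    samples_per_gpu for size in group_sizes)  # 4 + 2 = 6 per replica
total_size = num_samples * num_replicas       # 12 padded indices overall
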
Example #2
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    # pop `shuffle` so it is not forwarded to DataLoader together with a sampler
    shuffle = kwargs.pop('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(dataset,
                                         world_size,
                                         rank,
                                         shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             sampler=sampler,
                             num_workers=num_workers,
                             collate_fn=partial(collate,
                                                samples_per_gpu=imgs_per_gpu),
                             pin_memory=False,
                             **kwargs)

    return data_loader
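
A minimal usage sketch, assuming build_dataloader is in scope; my_dataset is a placeholder for any dataset the samplers accept (the group samplers additionally require a flag attribute):

# `my_dataset` is a placeholder, not part of the source above
data_loader = build_dataloader(my_dataset,
                               imgs_per_gpu=2,
                               workers_per_gpu=2,
                               num_gpus=1,
                               dist=False,
                               shuffle=True)
for data_batch in data_loader:
    ...  # each batch holds num_gpus * imgs_per_gpu samples
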
Example #3
    def __call__(self, results):
        rank, _ = get_dist_info()
        if isinstance(self.height, int):
            dst_height = self.height
            dst_min_width = self.min_width
            dst_max_width = self.max_width
        else:
            # Multi-scale resize used in distributed training: choose one
            # (height, width) pair per rank id.
            idx = rank % len(self.height)
            dst_height = self.height[idx]
            dst_min_width = self.min_width[idx]
            dst_max_width = self.max_width[idx]

        img_shape = results['img_shape']
        ori_height, ori_width = img_shape[:2]
        valid_ratio = 1.0
        resize_shape = list(img_shape)
        pad_shape = list(img_shape)

        if self.keep_aspect_ratio:
            new_width = math.ceil(float(dst_height) / ori_height * ori_width)
            width_divisor = int(1 / self.width_downsample_ratio)
            # make sure new_width is an integral multiple of width_divisor.
            if new_width % width_divisor != 0:
                new_width = round(new_width / width_divisor) * width_divisor
            if dst_min_width is not None:
                new_width = max(dst_min_width, new_width)
            if dst_max_width is not None:
                valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
                resize_width = min(dst_max_width, new_width)
                img_resize = cv2.resize(results['img'],
                                        (resize_width, dst_height))
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
                if new_width < dst_max_width:
                    img_resize = mmcv.impad(img_resize,
                                            shape=(dst_height, dst_max_width),
                                            pad_val=self.img_pad_value)
                    pad_shape = img_resize.shape
            else:
                img_resize = cv2.resize(results['img'],
                                        (new_width, dst_height))
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
        else:
            img_resize = cv2.resize(results['img'],
                                    (dst_max_width, dst_height))
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape

        results['img'] = img_resize
        results['resize_shape'] = resize_shape
        results['pad_shape'] = pad_shape
        results['valid_ratio'] = valid_ratio

        return results
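
The width arithmetic of the keep_aspect_ratio branch can be traced with plain numbers; the shapes below are made up and only mirror the formulas above:

import math

dst_height, dst_max_width = 32, 160     # assumed target shape
ori_height, ori_width = 64, 200         # assumed input shape
width_downsample_ratio = 0.25

new_width = math.ceil(float(dst_height) / ori_height * ori_width)  # 100
width_divisor = int(1 / width_downsample_ratio)                    # 4
if new_width % width_divisor != 0:
    new_width = round(new_width / width_divisor) * width_divisor
valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)            # 0.625
resize_width = min(dst_max_width, new_width)                       # 100
# the image is resized to width 100 and right-padded to width 160,
# so only the left 62.5 % of the padded width carries real content
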
Example #4
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map with shape of (N, C, H, W).

        Returns:
            Tensor: Output feature map with shape of (N, C+1, H, W).
        """

        if self.sync_std:
            # concatenate all features
            all_features = torch.cat(AllGatherLayer.apply(x), dim=0)
            # get the exact features we need in calculating std-dev
            rank, ws = get_dist_info()
            local_bs = all_features.shape[0] // ws
            start_idx = local_bs * rank
            # avoid the case where the start index lands near the tail of the
            # gathered features
            if start_idx + self.sync_groups > all_features.shape[0]:
                start_idx = all_features.shape[0] - self.sync_groups
            end_idx = min(local_bs * rank + self.sync_groups,
                          all_features.shape[0])

            x = all_features[start_idx:end_idx]

        # batch size should be smaller than or equal to group size. Otherwise,
        # batch size should be divisible by the group size.
        assert x.shape[0] <= self.group_size or \
            x.shape[0] % self.group_size == 0, (
                'Batch size should be smaller than or equal to group size. '
                'Otherwise, batch size should be divisible by the group size. '
                f'But got batch size {x.shape[0]}, '
                f'group size {self.group_size}.')
        assert x.shape[1] % self.channel_groups == 0, (
            'The number of feature channels must be divisible by '
            f'"channel_groups". channel_groups: {self.channel_groups}, '
            f'feature channels: {x.shape[1]}')

        n, c, h, w = x.shape
        group_size = min(n, self.group_size)
        # [G, M, Gc, C', H, W]
        y = torch.reshape(x, (group_size, -1, self.channel_groups,
                              c // self.channel_groups, h, w))
        y = torch.var(y, dim=0, unbiased=False)
        y = torch.sqrt(y + self.eps)
        # [M, 1, 1, 1]
        y = y.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)
        y = y.repeat(group_size, 1, h, w)
        return torch.cat([x, y], dim=1)
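
A shape-by-shape trace of the statistics branch with a toy tensor; the sizes are arbitrary and channel_groups=1 is chosen so the output matches the (N, C+1, H, W) shape stated in the docstring:

import torch

n, c, h, w = 4, 6, 2, 2
group_size, channel_groups, eps = 2, 1, 1e-8

x = torch.randn(n, c, h, w)
g = min(n, group_size)                                              # 2
y = torch.reshape(x, (g, -1, channel_groups, c // channel_groups, h, w))
y = torch.sqrt(torch.var(y, dim=0, unbiased=False) + eps)           # [2, 1, 6, 2, 2]
y = y.mean(dim=(2, 3, 4), keepdim=True).squeeze(2)                  # [2, 1, 1, 1]
y = y.repeat(g, 1, h, w)                                            # [4, 1, 2, 2]
out = torch.cat([x, y], dim=1)                                      # [4, 7, 2, 2]
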
Example #5
def load_state_dict(module,
                    in_state,
                    class_maps=None,
                    strict=False,
                    logger=None,
                    force_matching=False,
                    show_converted=False,
                    ignore_prefixes=None,
                    ignore_suffixes=None):
    rank, _ = get_dist_info()

    unexpected_keys = []
    converted_pairs = []
    shape_mismatch_pairs = []
    shape_casted_pairs = []

    own_state = module.state_dict()
    for name, in_param in in_state.items():
        ignored_prefix = ignore_prefixes is not None and name.startswith(
            ignore_prefixes)
        ignored_suffix = ignore_suffixes is not None and name.endswith(
            ignore_suffixes)
        if ignored_prefix or ignored_suffix:
            continue

        if name not in own_state:
            unexpected_keys.append(name)
            continue

        out_param = own_state[name]
        if isinstance(out_param, torch.nn.Parameter):
            out_param = out_param.data
        if isinstance(in_param, torch.nn.Parameter):
            in_param = in_param.data

        src_shape = in_param.size()
        trg_shape = out_param.size()
        if src_shape != trg_shape:
            if np.prod(src_shape) == np.prod(trg_shape):
                out_param.copy_(in_param.view(trg_shape))
                shape_casted_pairs.append(
                    [name, list(out_param.size()),
                     list(in_param.size())])
                continue

            is_valid = False
            if force_matching:
                is_valid = len(src_shape) == len(trg_shape)
                for i in range(len(src_shape)):
                    is_valid &= src_shape[i] >= trg_shape[i]

            if is_valid:
                if not (name.endswith('.weight') or name.endswith('.bias')):
                    continue

                if class_maps is not None and _is_cls_layer(name):
                    dataset_id = 0
                    if len(class_maps) > 1:
                        dataset_id = _get_dataset_id(name)
                    class_map = class_maps[dataset_id]

                    if 'fc_angular' in name:
                        for src_id, trg_id in class_map.items():
                            out_param[:, src_id] = in_param[:, trg_id]
                    else:
                        for src_id, trg_id in class_map.items():
                            out_param[src_id] = in_param[trg_id]
                else:
                    # use a tuple of slices for multi-dim indexing
                    ind = tuple(slice(0, d) for d in trg_shape)
                    out_param.copy_(in_param[ind])

                shape_casted_pairs.append(
                    [name, list(out_param.size()),
                     list(in_param.size())])
            else:
                shape_mismatch_pairs.append(
                    [name, list(out_param.size()),
                     list(in_param.size())])
        else:
            out_param.copy_(in_param)
            if show_converted:
                converted_pairs.append([name, list(out_param.size())])

    missing_keys = list(set(own_state.keys()) - set(in_state.keys()))

    err_msg = []
    if unexpected_keys:
        err_msg.append('unexpected key in source state_dict: {}\n'.format(
            ', '.join(unexpected_keys)))
    if missing_keys:
        err_msg.append('missing keys in source state_dict: {}\n'.format(
            ', '.join(missing_keys)))

    if shape_mismatch_pairs:
        casted_info = 'these keys have mismatched shape:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_mismatch_pairs
        table = AsciiTable(table_data)
        err_msg.append(casted_info + table.table)

    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)

    ok_message = []
    if converted_pairs:
        converted_info = 'These keys have been matched correctly:\n'
        header = ['key', 'shape']
        table_data = [header] + converted_pairs
        table = AsciiTable(table_data)
        ok_message.append(converted_info + table.table)

    if len(ok_message) > 0 and rank == 0:
        ok_message = '\n'.join(ok_message)
        if logger is not None:
            logger.info(ok_message)

    warning_msg = []
    if shape_casted_pairs:
        casted_info = 'these keys have been shape casted:\n'
        header = ['key', 'expected shape', 'loaded shape']
        table_data = [header] + shape_casted_pairs
        table = AsciiTable(table_data)
        warning_msg.append(casted_info + table.table)

    if len(warning_msg) > 0 and rank == 0:
        warning_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        warning_msg = '\n'.join(warning_msg)
        if logger is not None:
            logger.warning(warning_msg)
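
A hedged usage sketch for the matching-shape path; the two toy modules below are assumptions standing in for a source checkpoint and a target model:

import logging

import torch.nn as nn

# toy source/target modules (assumptions, not part of the source above)
src_model = nn.Sequential(nn.Linear(16, 10))
dst_model = nn.Sequential(nn.Linear(16, 10))

logger = logging.getLogger('load_state_dict_demo')
load_state_dict(dst_model,
                src_model.state_dict(),
                strict=False,
                logger=logger,
                show_converted=True)  # logs a table of the copied keys
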
Example #6
    def __init__(
        self,
        model,
        batch_processor=None,
        optimizer=None,
        work_dir=None,
        logger=None,
        meta=None,
        max_iters=None,
        max_epochs=None,
    ):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError("batch_processor must be callable, "
                                f"but got {type(batch_processor)}")
            warnings.warn("batch_processor is deprecated, please implement "
                          "train_step() and val_step() in the model instead.")
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, "train_step") or hasattr(_model, "val_step"):
                raise RuntimeError(
                    "batch_processor and model.train_step()/model.val_step() "
                    "cannot be both available.")
        else:
            assert hasattr(model, "train_step")

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f"optimizer must be a dict of torch.optim.Optimizers, "
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            pass
            # raise TypeError(
            #     f'optimizer must be a torch.optim.Optimizer object '
            #     f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f"logger must be a logging.Logger object, "
                            f"but got {type(logger)}")

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f"meta must be a dict or None, but got {type(meta)}")

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, "module"):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                "Only one of `max_epochs` or `max_iters` can be set.")

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
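
Construction could look like the sketch below; ToyModel and the concrete runner subclass name are assumptions, since the constructor shown above belongs to a base class that is not instantiated directly:

import logging

import torch.nn as nn

class ToyModel(nn.Module):
    """Minimal model exposing the train_step() the constructor asserts on."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def train_step(self, data, optimizer=None):
        # placeholder training logic
        return dict(loss=self.fc(data).sum())

logger = logging.getLogger('runner_demo')
# The subclass name below is an assumption; substitute whichever concrete
# runner the project provides.
# runner = EpochBasedRunner(ToyModel(),
#                           optimizer=None,
#                           work_dir='./work_dir',
#                           logger=logger,
#                           max_epochs=12)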