Example #1
    def _get_sampler(train_set, test_set, val_set, train_sampler, test_sampler,
                     val_sampler, start_epoch):
        if train_sampler is None:
            if is_distributed():
                train_sampler = DistributedSampler(
                    train_set,
                    num_replicas=get_world_size(),
                    rank=get_global_rank())
                train_sampler.set_epoch(start_epoch)
            else:
                # fall back to random sampling (with replacement) when not distributed
                train_sampler = RandomSampler(train_set, True)
        else:
            train_sampler = train_sampler(train_set)

        if test_sampler is None:
            if is_distributed():
                test_sampler = DistributedSampler(
                    test_set,
                    num_replicas=get_world_size(),
                    rank=get_global_rank())
        else:
            test_sampler = test_sampler(test_set)

        if val_set is not None:
            if val_sampler is None and is_distributed():
                val_sampler = DistributedSampler(val_set,
                                                 num_replicas=get_world_size(),
                                                 rank=get_global_rank())
                val_sampler.set_epoch(start_epoch)
            elif val_sampler is not None:
                val_sampler = val_sampler(val_set)

        return train_sampler, test_sampler, val_sampler
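A minimal sketch of the non-distributed path, assuming `_get_sampler` is callable directly (it takes no `self`, so it is presumably a staticmethod); the dummy datasets and the DataLoader construction below are illustrative, not part of the original code.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy datasets standing in for real train/test splits (illustrative only).
train_set = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
test_set = TensorDataset(torch.randn(20, 3), torch.randint(0, 2, (20,)))

# With no custom samplers, no validation set, and no distributed setup,
# the helper returns a RandomSampler for training and None otherwise.
train_sampler, test_sampler, val_sampler = _get_sampler(
    train_set, test_set, None, None, None, None, start_epoch=0)

train_loader = DataLoader(train_set, batch_size=16, sampler=train_sampler)
test_loader = DataLoader(test_set, batch_size=32, sampler=test_sampler)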
Example #2
def vision_loaders(name: str,
                   batch_size: int,
                   train_da: Optional[List] = None,
                   test_da: Optional[List] = None,
                   norm: Optional[List] = None,
                   val_size: int = 0,
                   download: bool = False,
                   num_workers: int = -1,
                   non_training_bs_factor: int = 2,
                   distributed: bool = False,
                   drop_last: bool = False,
                   pin_memory: bool = True,
                   return_num_classes: bool = False,
                   test_batch_size: Optional[int] = None
                   ) -> Tuple:
    """ Get data loaders for registered vision datasets. homura expects datasets are in `~/.torch/data/DATASET_NAME`.
    Link path if necessary, e.g. `ln -s /original/path $HOME/.torch`. Datasets can be registered
    using `homura.vision.register_dataset`

    :param name: name of the registered dataset
    :param batch_size: batch size during training
    :param train_da: custom train-time data augmentation
    :param test_da: custom test-time data augmentation
    :param norm: custom normalization applied after train_da/test_da
    :param val_size: if `val_size > 0`, split a validation set of this size off the training set
    :param download: if the dataset needs downloading
    :param num_workers: number of workers in the data loaders
    :param non_training_bs_factor: batch size scale factor for the test/validation loaders
    :param distributed: if distributed training is used (DistributedSampler is set up accordingly)
    :param drop_last: if the last incomplete batch is dropped
    :param pin_memory: if memory pinning is used
    :param return_num_classes: if the number of classes is returned as the last value
    :param test_batch_size: test-time batch size; if None, non_training_bs_factor * batch_size is used
    :return: (train_loader, test_loader, [val_loader], [num_classes])
    """

    if name not in _DATASETS.keys():
        raise RuntimeError(f'Unknown dataset name {name}.')
    dataset = _DATASETS[name]
    train_set, test_set = dataset.instantiate(train_da, test_da, norm, download)
    if test_batch_size is None:
        test_batch_size = non_training_bs_factor * batch_size
    if val_size > 0:
        train_set, val_set = _split_dataset(train_set, val_size)
        val_set.transform = test_set.transform

    samplers = [None, None, None]  # [train, val, test]
    if distributed:
        import homura

        kwargs = dict(num_replicas=homura.get_world_size(), rank=homura.get_global_rank())
        samplers[0] = DistributedSampler(train_set, **kwargs)
        samplers[2] = DistributedSampler(test_set, **kwargs)
    else:
        samplers[0] = RandomSampler(train_set, True)

    shared_kwargs = dict(drop_last=drop_last, num_workers=num_workers, pin_memory=pin_memory,
                         collate_fn=dataset.collate_fn)
    train_loader = DataLoader(train_set, batch_size, sampler=samplers[0], **shared_kwargs)
    test_loader = DataLoader(test_set, test_batch_size, sampler=samplers[2], **shared_kwargs)
    ret = [train_loader, test_loader]
    if val_size > 0:
        if distributed:
            samplers[1] = DistributedSampler(val_set, **kwargs)
        val_loader = DataLoader(val_set, test_batch_size, sampler=samplers[1], **shared_kwargs)
        ret.append(val_loader)

    if return_num_classes:
        ret.append(dataset.num_classes)

    return tuple(ret)
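A short usage sketch; the dataset name "cifar10", the sizes, and the unpacking below are assumptions for illustration and depend on which datasets are registered in your homura version.

# With val_size > 0 and return_num_classes=True the returned tuple
# has four elements: train, test, val loaders and the class count.
train_loader, test_loader, val_loader, num_classes = vision_loaders(
    "cifar10",            # assumed to be a registered dataset name
    batch_size=128,
    val_size=5000,        # hold out 5,000 training samples for validation
    download=True,
    num_workers=4,
    return_num_classes=True,
)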
Example #3
    def __init__(self,
                 model: nn.Module or Dict[str, nn.Module],
                 optimizer: Optional[Partial or Optimizer
                                     or Dict[str, Optimizer]],
                 loss_f: Optional[Callable or Dict[str, Callable]] = None,
                 *,
                 reporters: Optional[_ReporterBase
                                     or List[_ReporterBase]] = None,
                 scheduler: Optional[Partial or Scheduler
                                     or Dict[str, Scheduler]] = None,
                 device: Optional[torch.device or str] = None,
                 quiet: bool = True,
                 disable_cudnn_benchmark: bool = False,
                 disable_cuda_nonblocking: bool = False,
                 logger=None,
                 use_sync_bn: bool = False,
                 tqdm_ncols: int = 120,
                 debug: bool = False,
                 **kwargs):

        if kwargs.get("update_scheduler_by_epoch"):
            raise DeprecationWarning(
                "update_scheduler_by_epoch is deprecated, users need to step")

        if kwargs.get("callbacks"):
            raise DeprecationWarning(
                "callback is deprecated, if you need, use homura before v2020.8"
            )

        self.logger = logger or get_logger(__name__)
        self.device = device or (
            torch.device(GPU) if torch.cuda.is_available() else torch.device(CPU))
        self._is_debug = debug

        if self._is_debug:
            self.logger.warning(
                "Trainer is set to debug mode, which may affect performance"
            )
            set_verb_level("debug")

        # setup for distributed
        self._use_sync_bn = use_sync_bn
        if is_distributed():
            if self._use_sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
                self.logger.info(
                    "BNs of model are converted to nn.SyncBatchNorm")

            rank = get_local_rank()
            torch.cuda.set_device(rank)
            if get_global_rank() > 0:
                # silence non-zero ranks to avoid their output overwriting rank 0's
                quiet = True

        self.loss_f = loss_f
        self._verbose = not quiet

        # setup model
        if isinstance(model, nn.Module):
            self.model = model
        elif isinstance(model, dict):
            self.model = nn.ModuleDict(model)
            self.logger.debug(f"model is nn.ModuleDict of {self.model.keys()}")
        else:
            raise TypeError(
                f"Unknown type for `model`. Expected nn.Module or Dict[str, Module], but got {type(model)}"
            )

        if GPU in str(self.device):
            self.model.to(self.device)
            torch.backends.cudnn.benchmark = not disable_cudnn_benchmark
            self._cuda_nonblocking = not disable_cuda_nonblocking
            self.logger.debug(
                f"cuda: True, cudnn.benchmark: {not disable_cudnn_benchmark}, "
                f"cuda.nonblocking: {not disable_cuda_nonblocking}")
        else:
            self._cuda_nonblocking = False
            # usually, this is not expected
            self.logger.info(
                f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})"
            )

        if is_distributed():
            self.model = nn.parallel.DistributedDataParallel(self.model,
                                                             device_ids=[rank])
            self.logger.debug(
                f"model converted to DistributedDataParallel at rank={rank}")

        # self.accessible_model is useful for e.g., checkpointing
        if isinstance(self.model,
                      (nn.parallel.DistributedDataParallel, nn.DataParallel)):
            self.accessible_model = self.model.module
        else:
            self.accessible_model = self.model

        # setup optimizer and scheduler
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.set_optimizer()
        self.set_scheduler()

        if reporters is not None and not isinstance(reporters, Iterable):
            reporters = [reporters]
        reporters = reporters or []

        if not any(isinstance(rep, TQDMReporter) for rep in reporters):
            # append a TQDMReporter if reporters does not contain one
            reporters.append(TQDMReporter(ncols=tqdm_ncols))
        self.logger.debug(f"reporter is ready: {reporters}")
        self.reporter = ReporterList(reporters)

        # called via property
        # _step and _epoch are set to -1 because they are incremented before each iteration and epoch
        self._step = -1
        self._epoch = -1
        self._is_train = True

        # to nest, leave=False (https://github.com/tqdm/tqdm/blob/master/examples/simple_examples.py#L19)
        self._tqdm = lambda x: x
        if self._verbose:
            self._tqdm = Partial(tqdm, ncols=tqdm_ncols, leave=False)
            set_tqdm_stdout_stderr()
            self.logger.debug("verbose: setup tqdm")
        else:
            self.logger.debug("quiet: no tqdm")

        for k, v in kwargs.items():
            if hasattr(self, k):
                raise AttributeError(f"{self} already has {k}")
            if isinstance(v, torch.Tensor):
                v = v.to(self.device)
            if isinstance(v, nn.Module):
                v.to(self.device)
            setattr(self, k, v)
            self.logger.debug(f"trainer sets {k} as a new attribute")
Example #4
    def get_dataloader(
        self,
        batch_size: int,
        train_da: Optional[List] = None,
        test_da: Optional[List] = None,
        norm: Optional[List] = None,
        train_size: Optional[int] = None,
        test_size: Optional[int] = None,
        val_size: Optional[int] = None,
        download: bool = False,
        num_workers: int = 1,
        non_training_bs_factor: int = 2,
        drop_last: bool = False,
        pin_memory: bool = True,
        return_num_classes: bool = False,
        test_batch_size: Optional[int] = None,
        pre_default_train_da: Optional[List] = None,
        post_default_train_da: Optional[List] = None,
        post_norm_train_da: Optional[List] = None,
        use_prefetcher: bool = False,
        start_epoch: int = 0
    ) -> (Tuple[DataLoader, DataLoader] or Tuple[DataLoader, DataLoader,
                                                 DataLoader] or
          Tuple[DataLoader, DataLoader, int] or Tuple[DataLoader, DataLoader,
                                                      DataLoader, int]):
        """ Get Dataloader. This will automatically handle distributed setting

        :param batch_size: Batch size
        :param train_da: Data Augmentation for training
        :param test_da: Data Augmentation for testing and validation
        :param norm: Normalization after train_da and test_da
        :param train_size: Size of training dataset. If None, full dataset will be available.
        :param test_size: Size of test dataset. If None, full dataset will be available.
        :param val_size: Size of the validation dataset, randomly split from the training dataset.
        If None, no validation loader is returned.
        :param download: Whether to download the dataset if needed
        :param num_workers: Number of workers in data loaders
        :param non_training_bs_factor: Batch size scale factor during non-training phases. For example,
        no backward pass is cached at test time, so the batch size can typically be doubled.
        :param drop_last: Whether to drop the last incomplete batch
        :param pin_memory: Whether to use pinned memory
        :param return_num_classes: Whether to return the number of classes as the last return value
        :param test_batch_size: Test-time batch size. If None, non_training_bs_factor * batch_size is used.
        :param pre_default_train_da: Data Augmentation before the default data augmentation
        :param post_default_train_da: Data Augmentation after the default data augmentation
        :param post_norm_train_da: Data Augmentation after normalization (i.e., norm)
        :param use_prefetcher: Whether to use the prefetcher
        :param start_epoch: Epoch at start time (used for the samplers' set_epoch)
        :return: train_loader, test_loader, [val_loader], [num_classes]
        """

        train_set, test_set, val_set = self.get_dataset(
            train_size,
            test_size,
            val_size,
            train_da,
            test_da,
            norm,
            download,
            pre_default_train_da=pre_default_train_da,
            post_default_train_da=post_default_train_da,
            post_norm_train_da=post_norm_train_da)
        if test_batch_size is None:
            test_batch_size = non_training_bs_factor * batch_size

        samplers = [None, None, None]
        if is_distributed():
            import homura

            dist_sampler_kwargs = dict(num_replicas=homura.get_world_size(),
                                       rank=homura.get_global_rank())
            samplers[0] = DistributedSampler(train_set, **dist_sampler_kwargs)
            samplers[2] = DistributedSampler(test_set, **dist_sampler_kwargs)
            samplers[0].set_epoch(start_epoch)
            samplers[2].set_epoch(start_epoch)
        else:
            samplers[0] = RandomSampler(train_set, True)

        shared_kwargs = dict(drop_last=drop_last,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             collate_fn=self.collate_fn)
        train_loader = DataLoader(train_set,
                                  batch_size,
                                  sampler=samplers[0],
                                  **shared_kwargs)
        test_loader = DataLoader(test_set,
                                 test_batch_size,
                                 sampler=samplers[2],
                                 **shared_kwargs)
        if use_prefetcher:
            train_loader = DataPrefetchWrapper(train_loader, start_epoch)
            test_loader = DataPrefetchWrapper(test_loader, start_epoch)

        ret = [train_loader, test_loader]

        if val_set is not None:
            if is_distributed():
                samplers[1] = DistributedSampler(val_set,
                                                 **dist_sampler_kwargs)
                samplers[1].set_epoch(start_epoch)
            val_loader = DataLoader(val_set,
                                    test_batch_size,
                                    sampler=samplers[1],
                                    **shared_kwargs)
            if use_prefetcher:
                val_loader = DataPrefetchWrapper(val_loader, start_epoch)
            ret.append(val_loader)

        if return_num_classes:
            ret.append(self.num_classes)

        return tuple(ret)
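A usage sketch; `vision_set` stands for an instance of the dataset wrapper class that defines this method, and the argument values are illustrative.

# Build loaders with a validation split and also get the number of classes.
train_loader, test_loader, val_loader, num_classes = vision_set.get_dataloader(
    batch_size=128,
    val_size=5000,
    download=True,
    num_workers=4,
    return_num_classes=True,
    start_epoch=0,
)

for images, labels in train_loader:
    ...  # training step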
Example #5
    def __new__(cls, *args, **kwargs):
        # only rank 0 builds a real instance; other ranks get a no-op object
        if homura.get_global_rank() > 0:
            return _NoOpWrapper()
        else:
            return object.__new__(cls)
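A self-contained sketch of the same pattern with a hypothetical `Reporter` class; `_NoOpWrapper` and the rank query are stand-ins for the originals (in the real code the rank comes from `homura.get_global_rank()`).

class _NoOpWrapper:
    # Absorbs any attribute access and turns it into a do-nothing call
    # (stand-in for the original helper).
    def __getattr__(self, name):
        return lambda *args, **kwargs: None


def get_global_rank() -> int:
    # Stand-in for homura.get_global_rank(); 0 means single process / rank 0.
    return 0


class Reporter:
    # Hypothetical class using the pattern above: only rank 0 builds a real
    # instance, every other rank silently receives a no-op object.
    def __new__(cls, *args, **kwargs):
        if get_global_rank() > 0:
            return _NoOpWrapper()
        return object.__new__(cls)

    def report(self, value):
        print(f"value={value}")


Reporter().report(1.0)  # prints on rank 0, silently ignored on other ranks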