def _get_sampler(train_set, test_set, val_set, train_sampler, test_sampler, val_sampler, start_epoch): if train_sampler is None: if is_distributed(): train_sampler = DistributedSampler( train_set, num_replicas=get_world_size(), rank=get_global_rank()) train_sampler.set_epoch(start_epoch) else: train_sampler = RandomSampler(train_set, True) else: train_sampler = train_sampler(train_set) if test_sampler is None: if is_distributed(): test_sampler = DistributedSampler( test_set, num_replicas=get_world_size(), rank=get_global_rank()) else: test_sampler = test_sampler(test_set) if val_set is not None: if val_sampler is None and is_distributed(): val_sampler = DistributedSampler(val_set, num_replicas=get_world_size(), rank=get_global_rank()) val_sampler.set_epoch(start_epoch) elif val_sampler is not None: val_sampler = val_sampler(val_set) return train_sampler, test_sampler, val_sampler
def vision_loaders(name: str,
                   batch_size: int,
                   train_da: Optional[List] = None,
                   test_da: Optional[List] = None,
                   norm: Optional[List] = None,
                   val_size: int = 0,
                   download: bool = False,
                   num_workers: int = -1,
                   non_training_bs_factor=2,
                   distributed: bool = False,
                   drop_last: bool = False,
                   pin_memory: bool = True,
                   return_num_classes: bool = False,
                   test_batch_size: Optional[int] = None
                   ) -> Tuple:
    """ Get data loaders for registered vision datasets. homura expects datasets are in `~/.torch/data/DATASET_NAME`.
    Link path if necessary, e.g. `ln -s /original/path $HOME/.torch`. Datasets can be registered
    using `homura.vision.register_dataset`

    :param name: name of dataset.
    :param batch_size:
    :param train_da: custom train-time data augmentation
    :param test_da: custom test-time data augmentation
    :param norm: custom normalization after train_da/test_da
    :param val_size: If `val_size>0`, split train set
    :param download:
    :param num_workers:
    :param non_training_bs_factor:
    :param distributed:
    :param drop_last: If drop the last incomplete batch
    :param pin_memory: If pin memory in data loaders
    :param return_num_classes:
    :param test_batch_size: Test-time batch size. If None, non_training_bs_factor * batch_size is used.
    :return: (train_set, test_set, [val_set], [num_classes])
    """

    if name not in _DATASETS:
        raise RuntimeError(f'Unknown dataset name {name}.')

    dataset = _DATASETS[name]
    train_set, test_set = dataset.instantiate(train_da, test_da, norm, download)
    if test_batch_size is None:
        test_batch_size = non_training_bs_factor * batch_size

    if val_size > 0:
        train_set, val_set = _split_dataset(train_set, val_size)
        # validation uses test-time transforms
        val_set.transform = test_set.transform

    samplers = [None, None, None]
    if distributed:
        import homura

        kwargs = dict(num_replicas=homura.get_world_size(),
                      rank=homura.get_global_rank())
        samplers[0] = DistributedSampler(train_set, **kwargs)
        samplers[2] = DistributedSampler(test_set, **kwargs)
    else:
        samplers[0] = RandomSampler(train_set, True)
    shared_kwargs = dict(drop_last=drop_last,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=dataset.collate_fn)
    train_loader = DataLoader(train_set, batch_size, sampler=samplers[0], **shared_kwargs)
    test_loader = DataLoader(test_set, test_batch_size, sampler=samplers[2], **shared_kwargs)
    ret = [train_loader, test_loader]
    if val_size > 0:
        if distributed:
            # BUGFIX: the validation sampler was previously built over
            # `test_set`, which sharded the wrong dataset's indices.
            samplers[1] = DistributedSampler(val_set, **kwargs)
        val_loader = DataLoader(val_set, test_batch_size, sampler=samplers[1], **shared_kwargs)
        ret.append(val_loader)
    if return_num_classes:
        ret.append(dataset.num_classes)

    return tuple(ret)
def __init__(self,
             model: nn.Module or Dict[str, nn.Module],
             optimizer: Optional[Partial or Optimizer or Dict[str, Optimizer]],
             loss_f: Optional[Callable or Dict[str, Callable]] = None,
             *,
             reporters: Optional[_ReporterBase or List[_ReporterBase]] = None,
             scheduler: Optional[Partial or Scheduler or Dict[str, Scheduler]] = None,
             device: Optional[torch.device or str] = None,
             quiet: bool = True,
             disable_cudnn_benchmark: bool = False,
             disable_cuda_nonblocking: bool = False,
             logger=None,
             use_sync_bn: bool = False,
             tqdm_ncols: int = 120,
             debug: bool = False,
             **kwargs):
    """Set up the trainer: model (optionally as a ModuleDict), optimizer,
    scheduler, reporters, device placement and — when running distributed —
    SyncBatchNorm conversion and DistributedDataParallel wrapping.

    :param model: a single module or a dict of named modules
    :param optimizer: optimizer, partially-applied optimizer, or dict thereof
    :param loss_f: loss function(s); stored as-is on ``self.loss_f``
    :param reporters: reporter or list of reporters; a TQDMReporter is appended
        if none is present
    :param scheduler: LR scheduler, partial, or dict thereof
    :param device: target device; defaults to GPU when CUDA is available
    :param quiet: suppress tqdm progress bars (forced True on non-zero ranks)
    :param disable_cudnn_benchmark: disable ``cudnn.benchmark`` on CUDA
    :param disable_cuda_nonblocking: disable non-blocking host-to-device copies
    :param logger: custom logger; defaults to ``get_logger(__name__)``
    :param use_sync_bn: convert BatchNorm layers to SyncBatchNorm (distributed)
    :param tqdm_ncols: width of the tqdm progress bar
    :param debug: enable debug-level logging (may affect performance)
    :param kwargs: extra objects attached to the trainer as attributes;
        tensors/modules among them are moved to ``self.device``
    :raises DeprecationWarning: for removed ``update_scheduler_by_epoch`` /
        ``callbacks`` arguments
    :raises TypeError: if ``model`` is neither a module nor a dict of modules
    :raises AttributeError: if a ``kwargs`` key collides with an existing attribute
    """

    # reject arguments removed in earlier releases
    if kwargs.get("update_scheduler_by_epoch"):
        raise DeprecationWarning("update_scheduler_by_epoch is deprecated, users need to step")
    if kwargs.get("callbacks"):
        raise DeprecationWarning("callback is deprecated, if you need, use homura before v2020.8")

    self.logger = logger or get_logger(__name__)
    # prefer GPU when available, unless the caller pins a device explicitly
    self.device = device or (torch.device(GPU) if torch.cuda.is_available() else torch.device(CPU))
    self._is_debug = debug
    if self._is_debug:
        self.logger.warning("Trainer is set to be debug mode, which may affect the performance")
        set_verb_level("debug")

    # setup for distributed
    self._use_sync_bn = use_sync_bn
    if is_distributed():
        if self._use_sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            self.logger.info("BNs of model are converted to nn.SyncBatchNorm")
        # bind this process to its local GPU
        rank = get_local_rank()
        torch.cuda.set_device(rank)
        if get_global_rank() > 0:
            # to avoid overwriting
            quiet = True

    self.loss_f = loss_f
    self._verbose = not quiet

    # setup model
    if isinstance(model, nn.Module):
        self.model = model
    elif isinstance(model, dict):
        # multiple named models are kept in a single ModuleDict
        self.model = nn.ModuleDict(model)
        self.logger.debug(f"model is nn.ModuleDict of {self.model.keys()}")
    else:
        raise TypeError(f"Unknown type for `model`. Expected nn.Module or Dict[str, Module], but got {type(model)}")

    if GPU in str(self.device):
        self.model.to(self.device)
        torch.backends.cudnn.benchmark = not disable_cudnn_benchmark
        self._cuda_nonblocking = not disable_cuda_nonblocking
        self.logger.debug(f"cuda: True, cudnn.benchmark: {not disable_cudnn_benchmark}, "
                          f"cuda.nonblocking: {not disable_cuda_nonblocking}")
    else:
        self._cuda_nonblocking = False
        # usually, this is not expected
        self.logger.info(f"cuda: False (torch.cuda.is_available()={torch.cuda.is_available()})")

    if is_distributed():
        # `rank` was set in the distributed-setup branch above
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[rank])
        self.logger.debug(f"model converted to DistributedDataParallel at rank={rank}")

    # self.accessible_model is useful for e.g., checkpointing
    if isinstance(self.model, nn.parallel.DistributedDataParallel) or isinstance(self.model, nn.DataParallel):
        self.accessible_model = self.model.module
    else:
        self.accessible_model = self.model

    # setup optimizer and scheduler
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.set_optimizer()
    self.set_scheduler()

    if reporters is not None and not isinstance(reporters, Iterable):
        reporters = [reporters]
    reporters = reporters or []
    if not any([isinstance(rep, TQDMReporter) for rep in reporters]):
        # if reporters not contain TQDMReporter
        reporters.append(TQDMReporter(ncols=tqdm_ncols))
    self.logger.debug(f"reporter is ready: {reporters}")
    self.reporter = ReporterList(reporters)

    # called via property
    # _step and _epoch are set to -1 because they are incremented before each iteration and epoch
    self._step = -1
    self._epoch = -1
    self._is_train = True

    # to nest, leave=False (https://github.com/tqdm/tqdm/blob/master/examples/simple_examples.py#L19)
    self._tqdm = lambda x: x
    if self._verbose:
        self._tqdm = Partial(tqdm, ncols=tqdm_ncols, leave=False)
        set_tqdm_stdout_stderr()
        self.logger.debug("verbose: setup tqdm")
    else:
        self.logger.debug("quiet: no tqdm")

    # attach any remaining kwargs as trainer attributes, moving tensors
    # and modules to the trainer's device
    for k, v in kwargs.items():
        if hasattr(self, k):
            raise AttributeError(f"{self} already has {k}")
        if isinstance(v, torch.Tensor):
            v = v.to(self.device)
        if isinstance(v, nn.Module):
            v.to(self.device)
        setattr(self, k, v)
        self.logger.debug(f"trainer sets {k} as a new attribute")
def get_dataloader(self,
                   batch_size: int,
                   train_da: Optional[List] = None,
                   test_da: Optional[List] = None,
                   norm: Optional[List] = None,
                   train_size: Optional[int] = None,
                   test_size: Optional[int] = None,
                   val_size: Optional[int] = None,
                   download: bool = False,
                   num_workers: int = 1,
                   non_training_bs_factor=2,
                   drop_last: bool = False,
                   pin_memory: bool = True,
                   return_num_classes: bool = False,
                   test_batch_size: Optional[int] = None,
                   pre_default_train_da: Optional[List] = None,
                   post_default_train_da: Optional[List] = None,
                   post_norm_train_da: Optional[List] = None,
                   use_prefetcher: bool = False,
                   start_epoch: int = 0
                   ) -> (Tuple[DataLoader, DataLoader]
                         or Tuple[DataLoader, DataLoader, DataLoader]
                         or Tuple[DataLoader, DataLoader, int]
                         or Tuple[DataLoader, DataLoader, DataLoader, int]):
    """ Get Dataloader. This will automatically handle distributed setting

    :param batch_size: Batch size
    :param train_da: Data Augmentation for training
    :param test_da: Data Augmentation for testing and validation
    :param norm: Normalization after train_da and test_da
    :param train_size: Size of training dataset. If None, full dataset will be available.
    :param test_size: Size of test dataset. If None, full dataset will be available.
    :param val_size: Size of validation dataset, randomly split from the training dataset.
        If None, None will be returned.
    :param download: If dataset needs downloading
    :param num_workers: Number of workers in data loaders
    :param non_training_bs_factor: Batch size scale factor during non training. For example,
        testing time requires no backward cache, so basically batch size can be doubled.
    :param drop_last: If drop last batch or not
    :param pin_memory: If pin memory or not
    :param return_num_classes: If return number of classes as the last return value
    :param test_batch_size: Test time batch size. If None, non_training_bs_factor * batch_size is used.
    :param pre_default_train_da: Data Augmentation before the default data augmentation
    :param post_default_train_da: Data Augmentation after the default data augmentation
    :param post_norm_train_da: Data Augmentation after normalization (i.e., norm)
    :param use_prefetcher: Use prefetcher or Not
    :param start_epoch: Epoch at start time
    :return: train_loader, test_loader, [val_loader], [num_classes]
    """

    train_set, test_set, val_set = self.get_dataset(train_size, test_size, val_size,
                                                    train_da, test_da, norm, download,
                                                    pre_default_train_da=pre_default_train_da,
                                                    post_default_train_da=post_default_train_da,
                                                    post_norm_train_da=post_norm_train_da)
    if test_batch_size is None:
        test_batch_size = non_training_bs_factor * batch_size

    samplers = [None, None, None]
    if is_distributed():
        import homura

        dist_sampler_kwargs = dict(num_replicas=homura.get_world_size(),
                                   rank=homura.get_global_rank())
        samplers[0] = DistributedSampler(train_set, **dist_sampler_kwargs)
        samplers[2] = DistributedSampler(test_set, **dist_sampler_kwargs)
        # seed the shuffling so resumed runs don't repeat epoch-0 order
        samplers[0].set_epoch(start_epoch)
        samplers[2].set_epoch(start_epoch)
    else:
        samplers[0] = RandomSampler(train_set, True)

    shared_kwargs = dict(drop_last=drop_last,
                         num_workers=num_workers,
                         pin_memory=pin_memory,
                         collate_fn=self.collate_fn)
    train_loader = DataLoader(train_set, batch_size, sampler=samplers[0], **shared_kwargs)
    test_loader = DataLoader(test_set, test_batch_size, sampler=samplers[2], **shared_kwargs)
    if use_prefetcher:
        train_loader = DataPrefetchWrapper(train_loader, start_epoch)
        test_loader = DataPrefetchWrapper(test_loader, start_epoch)
    ret = [train_loader, test_loader]

    if val_set is not None:
        if is_distributed():
            samplers[1] = DistributedSampler(val_set, **dist_sampler_kwargs)
            samplers[1].set_epoch(start_epoch)
        val_loader = DataLoader(val_set, test_batch_size, sampler=samplers[1], **shared_kwargs)
        if use_prefetcher:
            # BUGFIX: previously wrapped `test_loader` (and dropped
            # `start_epoch`), so the returned val loader iterated test data.
            val_loader = DataPrefetchWrapper(val_loader, start_epoch)
        ret.append(val_loader)

    if return_num_classes:
        ret.append(self.num_classes)

    return tuple(ret)
def __new__(cls, *args, **kwargs):
    # On non-zero ranks, hand back a no-op stand-in instead of a real
    # instance, so only the global rank-0 process does actual work.
    rank = homura.get_global_rank()
    return _NoOpWrapper() if rank > 0 else object.__new__(cls)