def __init__(self, dataset, num_replicas=None, rank=None, limit_number_of_volumes=None):
    if num_replicas is None:
        num_replicas = communication.get_world_size()
    if rank is None:
        rank = communication.get_rank()

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank

    filenames = list(self.dataset.volume_indices.keys())  # volume_indices is an OrderedDict.
    if limit_number_of_volumes:
        filenames = filenames[:limit_number_of_volumes]

    chunked_filenames = list(chunks(filenames, self.num_replicas))

    # Collect the indices belonging to the filenames assigned to this rank.
    # Guard the chunk lookup as well: with more replicas than volumes, high
    # ranks receive no chunk at all and there is nothing to fill.
    self.indices = []
    if self.rank < len(chunked_filenames):
        for filename in chunked_filenames[self.rank]:
            self.indices.extend(list(self.dataset.volume_indices[filename]))
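# The `chunks` helper is not shown in this file. From its use above it must split a
# sequence into at most `num_replicas` contiguous, roughly equal parts. A minimal
# sketch consistent with that usage (an assumption, not necessarily the repo's version):
def chunks(sequence, n):
    """Yield up to `n` contiguous, near-equal chunks of `sequence`."""
    step = max(1, -(-len(sequence) // n))  # ceil(len / n), at least 1.
    for start in range(0, len(sequence), step):
        yield sequence[start:start + step]
# With fewer volumes than replicas this yields fewer than `n` chunks, which is
# exactly the case the `self.rank < len(chunked_filenames)` guard above handles.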
def __init__(
    self,
    size: int,
    shuffle: bool = True,
    seed: Optional[int] = None,
):
    """
    Parameters
    ----------
    size : int
        Size of underlying dataset.
    shuffle : bool
        If true, the indices will be shuffled.
    seed : int
        Initial seed of the shuffle, must be the same across all workers!
    """
    self._size = size
    if self._size <= 0:
        raise ValueError(f"Size should be a positive integer, got {size}.")
    self._shuffle = shuffle
    if seed is None:
        seed = communication.shared_random_seed()
    self._seed = int(seed)

    self._rank = communication.get_rank()
    self._world_size = communication.get_world_size()
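# A sketch (not the repo's actual implementation) of the __iter__ that would pair
# with this state, assuming `torch` is imported at module level: every worker
# shuffles with the same shared seed, so the permutation is identical everywhere,
# and each rank then takes its own strided slice of the stream.
def __iter__(self):
    generator = torch.Generator()
    generator.manual_seed(self._seed)
    while True:  # Infinite stream; the training loop decides when to stop.
        if self._shuffle:
            indices = torch.randperm(self._size, generator=generator).tolist()
        else:
            indices = list(range(self._size))
        # Shard: rank r takes indices r, r + world_size, r + 2 * world_size, ...
        yield from indices[self._rank::self._world_size]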
def train(self,
          optimizer: torch.optim.Optimizer,
          lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
          training_data: Dataset,
          experiment_directory: pathlib.Path,
          validation_data: Dataset = None,
          resume: bool = False,
          num_workers: int = 0) -> None:
    # TODO: Does not need to be member of self.
    self.__optimizer = optimizer
    # TODO: Optimizer and LR scheduler need to be resumed too.
    self.__lr_scheduler = lr_scheduler

    training_sampler = self.build_sampler(training_data, 'random')  # TODO: Configurable.
    training_loader = self.build_loader(
        training_data,
        sampler=training_sampler,
        batch_size=self.cfg.training.batch_size,
        num_workers=num_workers,
        drop_last=True)

    validation_loader = None  # Guard: training_loop() receives None when there is no validation set.
    if validation_data:
        validation_sampler = self.build_sampler(
            validation_data, 'sequential', limit_number_of_volumes=None)
        # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
        batch_sampler = BatchSampler(
            validation_sampler,
            batch_size=8 * self.cfg.training.batch_size,
            drop_last=False)
        validation_loader = self.build_loader(
            validation_data,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
        )

    self.model = self.model.to(self.device)

    # Optimizer
    self.__optimizer.zero_grad()  # type: ignore

    # Mixed precision setup. This requires the model to be on the gpu.
    extra_checkpointing = {}
    opt_level = None  # Only set when mixed precision is enabled; avoids a NameError below.
    if self.mixed_precision > 0:
        opt_level = f'O{self.mixed_precision}'
        self.logger.info(f'Using apex level {opt_level}.')
        self.model, self.__optimizer = amp.initialize(
            self.model, self.__optimizer, opt_level=opt_level)
        extra_checkpointing['amp'] = amp
        extra_checkpointing['opt_level'] = opt_level
    git_hash = direct.utils.git_hash()
    extra_checkpointing['__author__'] = git_hash if git_hash else 'N/A'

    self.checkpointer = Checkpointer(
        self.model,
        experiment_directory,
        save_to_disk=communication.is_main_process(),
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        **extra_checkpointing)

    # Load checkpoint
    start_iter = 0
    if resume:
        self.logger.info('Attempting to resume...')
        # This changes the model inplace.
        checkpoint = self.checkpointer.load(
            iteration='latest',
            checkpointable_objects=['amp'] if self.mixed_precision > 0 else [])
        if not checkpoint:
            self.logger.info('No checkpoint found. Starting from scratch.')
        else:
            start_iter = checkpoint['iteration'] + 1
            self.logger.info(f'Starting from iteration: {start_iter}.')

        if '__author__' in checkpoint:
            self.logger.info(f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint['__author__'] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if '__datetime__' in checkpoint:
            self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if 'opt_level' in checkpoint and checkpoint['opt_level'] != opt_level:
            self.logger.warning(
                f"Mixed precision opt-levels do not match. "
                f"Requested {opt_level} but got {checkpoint['opt_level']} from the checkpoint. "
                f"This will almost surely lead to performance degradation.")

    self.logger.info(f'World size: {communication.get_world_size()}.')
    if communication.get_world_size() > 1:
        self.model = DistributedDataParallel(
            self.model,
            device_ids=[communication.get_rank()],
            broadcast_buffers=False)
    # World size is > 1 in distributed mode; otherwise allow a DataParallel
    # fallback, which can be convenient for debugging.
    elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
        self.model = DataParallel(self.model)

    self.__writers = ([
        JSONWriter(experiment_directory / 'metrics.json'),
        CommonMetricPrinter(self.cfg.training.num_iterations),
        TensorboardWriter(experiment_directory / 'tensorboard')
    ] if communication.is_main_process() else [])

    with EventStorage(start_iter):
        self.training_loop(training_loader, start_iter, validation_loader)

    self.logger.info('Training completed.')
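# Hypothetical usage sketch of this signature. `engine`, `train_dataset`, and
# `val_dataset` are placeholders, not names defined in this file: the optimizer
# and scheduler are built externally and handed to train().
import pathlib

import torch

optimizer = torch.optim.Adam(engine.model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)

engine.train(
    optimizer,
    lr_scheduler,
    training_data=train_dataset,
    experiment_directory=pathlib.Path('experiments/run_0'),
    validation_data=val_dataset,
    resume=False,
    num_workers=4,
)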
def train(
    self,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
    training_datasets: List[Dataset],
    experiment_directory: pathlib.Path,
    validation_data: Optional[List[Dataset]] = None,
    resume: bool = False,
    initialization: Optional[PathOrString] = None,
    num_workers: int = 6,
) -> None:
    self.logger.info("Starting training.")
    # TODO: Does not need to be member of self.
    self.__optimizer = optimizer
    # TODO: Optimizer and LR scheduler need to be resumed too.
    self.__lr_scheduler = lr_scheduler

    training_data = ConcatDataset(training_datasets)
    self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
    self.logger.info(
        f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
    )
    training_sampler = self.build_batch_sampler(
        training_datasets, self.cfg.training.batch_size, "random")
    training_loader = self.build_loader(
        training_data,
        batch_sampler=training_sampler,
        num_workers=num_workers,
    )

    if validation_data:
        validation_loaders = []
        for curr_validation_data in validation_data:
            text_dataset_description = curr_validation_data.text_description
            self.logger.info(
                f"Building dataloader for dataset: {text_dataset_description}.")
            curr_batch_sampler = self.build_batch_sampler(
                curr_validation_data,
                batch_size=self.cfg.validation.batch_size,
                sampler_type="sequential",
                limit_number_of_volumes=None,
            )
            validation_loaders.append((
                text_dataset_description,
                self.build_loader(
                    curr_validation_data,
                    batch_sampler=curr_batch_sampler,
                    num_workers=0,  # TODO(jt): A higher number of workers seems to choke the validation.
                ),
            ))
    else:
        validation_loaders = None

    self.models_to_device()

    # Optimizer
    self.__optimizer.zero_grad()  # type: ignore

    # Mixed precision setup. This requires the model to be on the gpu.
    git_hash = direct.utils.git_hash()
    extra_checkpointing = {
        "__author__": git_hash if git_hash else "N/A",
        "__version__": direct.__version__,
        "__mixed_precision__": self.mixed_precision,
    }
    if self.mixed_precision:
        # TODO(jt): Check if on GPU.
        self.logger.info("Using mixed precision training.")

    self.checkpointer = Checkpointer(
        self.model,
        experiment_directory,
        save_to_disk=communication.is_main_process(),
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        scaler=self._scaler,
        **self.models,
        **extra_checkpointing,
    )

    # Load checkpoint
    start_iter = 0
    checkpoint = {}
    if resume:
        self.logger.info("Attempting to resume...")
        # This changes the model inplace.
        checkpoint = self.checkpointer.load(iteration="latest")
        if not checkpoint:
            self.logger.info("No checkpoint found. Starting from scratch.")
        else:
            start_iter = checkpoint["iteration"] + 1
            self.logger.info(f"Starting from iteration: {start_iter}.")

    if start_iter > 0 and initialization:
        self.logger.warning(
            f"Initialization checkpoint set to {initialization},"
            f" but model will resume training from previous checkpoint. Initialization ignored."
        )
    elif initialization:
        self.logger.info(f"Initializing from {initialization}...")
        self.checkpointer.load_from_file(initialization)

    if "__version__" in checkpoint:
        self.logger.info(
            f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
        if checkpoint["__version__"] != direct.__version__:
            self.logger.warning(
                f"Current DIRECT version {direct.__version__} is different from the one "
                f"this checkpoint is saved with ({checkpoint['__version__']}). This can be fine, "
                f"but beware that this can be a source of confusion.")

    if "__author__" in checkpoint:
        self.logger.info(f"Git hash of checkpoint: {checkpoint['__author__']}.")
        if checkpoint["__author__"] != direct.utils.git_hash():
            self.logger.warning(
                f"Current git hash {direct.utils.git_hash()} is different from the one "
                f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                f"but beware that this can be a source of confusion.")

    if "__datetime__" in checkpoint:
        self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.")

    if "__mixed_precision__" in checkpoint:
        if not self.mixed_precision and checkpoint["__mixed_precision__"]:
            self.logger.warning(
                "Mixed precision training is not enabled, yet the saved checkpoint requests this. "
                "Will now enable mixed precision.")
            self.mixed_precision = True
        elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
            self.logger.warning(
                "Mixed precision levels of training and loaded checkpoint do not match. "
                "Requested mixed precision but checkpoint is saved without. "
                "This will almost surely lead to performance degradation.")

    self.logger.info(f"World size: {communication.get_world_size()}.")
    self.logger.info(f"Device count: {torch.cuda.device_count()}.")
    if communication.get_world_size() > 1:
        self.model = DistributedDataParallel(
            self.model,
            device_ids=[communication.get_local_rank()],
            broadcast_buffers=False,
        )
    # World size is > 1 in distributed mode; otherwise allow a DataParallel
    # fallback, which can be convenient for debugging.
    elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
        self.model = DataParallel(self.model)

    self.__writers = ([
        JSONWriter(experiment_directory / "metrics.json"),
        CommonMetricPrinter(self.cfg.training.num_iterations),
        TensorboardWriter(experiment_directory / "tensorboard"),
    ] if communication.is_main_process() else [])

    with EventStorage(start_iter):
        self.training_loop(
            training_loader,
            start_iter,
            validation_loaders,
            experiment_directory=experiment_directory,
        )

    self.logger.info("Training completed.")
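# This version checkpoints a GradScaler (`self._scaler`) instead of apex state.
# A minimal sketch of the training step it implies, using torch.cuda.amp. All
# argument names are placeholders, not names from this file:
import torch


def mixed_precision_step(model, loss_fn, data, target, optimizer, scaler, enabled=True):
    # Run the forward pass in reduced precision where safe.
    with torch.cuda.amp.autocast(enabled=enabled):
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # Backward on the scaled loss.
    scaler.step(optimizer)         # Unscales gradients, then steps the optimizer.
    scaler.update()                # Adjusts the scale factor for the next iteration.
    optimizer.zero_grad()
    return loss.detach()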
def train(
    self,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
    training_datasets: List[Dataset],
    experiment_directory: pathlib.Path,
    validation_datasets: Optional[List[Dataset]] = None,
    resume: bool = False,
    start_with_validation: bool = False,
    initialization: Optional[PathOrString] = None,
    num_workers: int = 6,
) -> None:
    self.logger.info("Starting training.")
    # Can consider not to make these members of self, but that requires that the
    # optimizer is passed to training_loop().
    self.__optimizer = optimizer
    self.__lr_scheduler = lr_scheduler

    self.models_to_device()

    # Optimizer
    self.__optimizer.zero_grad()  # type: ignore

    # Mixed precision setup. This requires the model to be on the gpu.
    git_hash = direct.utils.git_hash()
    checkpointing_metadata = {
        "__author__": git_hash if git_hash else "N/A",
        "__version__": direct.__version__,
        "__mixed_precision__": self.mixed_precision,
    }
    if self.mixed_precision:
        # TODO(jt): Check if on GPU.
        self.logger.info("Using mixed precision training.")

    self.checkpointer = Checkpointer(
        save_directory=experiment_directory,
        save_to_disk=communication.is_main_process(),
        model=self.model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        scaler=self._scaler,
        **checkpointing_metadata,  # type: ignore
        **self.models,  # type: ignore
    )

    # Load checkpoint
    start_iter = 0
    checkpoint = {}
    if resume:
        self.logger.info("Attempting to resume...")
        # This changes the model inplace.
        checkpoint = self.checkpointer.load(iteration="latest")
        if not checkpoint:
            self.logger.info("No checkpoint found. Starting from scratch.")
        else:
            start_iter = checkpoint["iteration"] + 1
            self.logger.info(f"Starting from iteration: {start_iter}.")

    if start_iter > 0 and initialization:
        self.logger.warning(
            f"Initialization checkpoint set to {initialization},"
            f" but model will resume training from previous checkpoint. Initialization ignored."
        )
    elif initialization:
        self.logger.info(f"Initializing from {initialization}...")
        self.checkpointer.load_models_from_file(initialization)
        start_with_validation = True
        self.logger.info("Setting start_with_validation to True.")

    if "__version__" in checkpoint:
        self.logger.info(
            f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
        if checkpoint["__version__"] != direct.__version__:
            self.logger.warning(
                f"Current DIRECT version {direct.__version__} is different from the one "
                f"this checkpoint is saved with: {checkpoint['__version__']}. This can be fine, "
                f"but beware that this can be a source of confusion.")

    if "__author__" in checkpoint:
        self.logger.info(f"Git hash of checkpoint: {checkpoint['__author__']}.")
        if checkpoint["__author__"] != direct.utils.git_hash():
            self.logger.warning(
                f"Current git hash {direct.utils.git_hash()} is different from the one "
                f"this checkpoint is saved with: {checkpoint['__author__']}. This can be fine, "
                f"but beware that this can be a source of confusion.")

    if "__datetime__" in checkpoint:
        self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.")

    if "__mixed_precision__" in checkpoint:
        if not self.mixed_precision and checkpoint["__mixed_precision__"]:
            self.logger.warning(
                "Mixed precision training is not enabled, yet the saved checkpoint requests this. "
                "Will now enable mixed precision.")
            self.mixed_precision = True
        elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
            self.logger.warning(
                "Mixed precision levels of training and loaded checkpoint do not match. "
                "Requested mixed precision but checkpoint is saved without. "
                "This will almost surely lead to performance degradation.")

    if start_with_validation:
        self.logger.info("Requested to start with validation.")

    self.logger.info(f"World size: {communication.get_world_size()}.")
    self.logger.info(f"Device count: {torch.cuda.device_count()}.")
    if communication.get_world_size() > 1:
        self.model = DistributedDataParallel(
            self.model,
            device_ids=[communication.get_local_rank()],
            broadcast_buffers=False,
        )
    # World size is > 1 in distributed mode; otherwise allow a DataParallel
    # fallback, which can be convenient for debugging.
    elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
        self.model = DataParallel(self.model)

    self.__writers = ([
        JSONWriter(experiment_directory / "metrics.json"),
        CommonMetricPrinter(self.cfg.training.num_iterations),  # type: ignore
        TensorboardWriter(experiment_directory / "tensorboard"),
    ] if communication.is_main_process() else [])

    with EventStorage(start_iter):
        self.training_loop(
            training_datasets,
            start_iter,
            validation_datasets,
            experiment_directory=experiment_directory,
            num_workers=num_workers,
            start_with_validation=start_with_validation,
        )

    self.logger.info("Training completed.")