def test_update_classy_state(self):
    """Tests that update_classy_state successfully updates a task from a checkpoint."""
    config = get_fast_test_task_config()
    task = build_task(config)
    task_2 = build_task(config)
    task_2.prepare()

    trainer = LocalTrainer()
    trainer.train(task)

    update_classy_state(task_2, task.get_classy_state(deep_copy=True))
    self._compare_states(task.get_classy_state(), task_2.get_classy_state())
def prepare(self): """Prepares task for training, populates all derived attributes """ pin_memory = self.use_gpu and torch.cuda.device_count() > 1 self.phases = self._build_phases() self.train = False if self.test_only else self.train self.dataloaders = self.build_dataloaders( pin_memory=pin_memory, multiprocessing_context=mp.get_context(self.dataloader_mp_context), ) if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH: self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm( self.base_model) elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX: self.base_model = apex.parallel.convert_syncbn_model( self.base_model) # move the model and loss to the right device if self.use_gpu: self.base_model, self.loss = copy_model_to_gpu( self.base_model, self.loss) else: self.loss.cpu() self.base_model.cpu() if self.optimizer is not None: # initialize the pytorch optimizer now since the model has been moved to # the appropriate device self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss) if self.amp_args is not None: # Initialize apex.amp. This updates the model and the PyTorch optimizer ( # if training, which is wrapped by the ClassyOptimizer in self.optimizer). # Please note this must happen before loading the checkpoint, cause # there's amp state to be restored. if self.optimizer is None: self.base_model = apex.amp.initialize(self.base_model, optimizers=None, **self.amp_args) else: self.base_model, self.optimizer.optimizer = apex.amp.initialize( self.base_model, self.optimizer.optimizer, **self.amp_args) if self.checkpoint_path: self.checkpoint_dict = load_and_broadcast_checkpoint( self.checkpoint_path) classy_state_dict = (None if self.checkpoint_dict is None else self.checkpoint_dict["classy_state_dict"]) if classy_state_dict is not None: state_load_success = update_classy_state(self, classy_state_dict) assert (state_load_success ), "Update classy state from checkpoint was unsuccessful." self.init_distributed_data_parallel_model()
def prepare(self, num_dataloader_workers=0, dataloader_mp_context=None):
    """Prepares the task for training and populates all derived attributes.

    Args:
        num_dataloader_workers: Number of dataloading processes. If 0,
            dataloading is done on the main process.
        dataloader_mp_context: Determines how processes are spawned. Value must
            be one of None, "spawn", "fork", "forkserver". If None, the context
            is inherited from the parent process.
    """
    pin_memory = self.use_gpu and torch.cuda.device_count() > 1

    self.phases = self._build_phases()
    self.dataloaders = self.build_dataloaders(
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        multiprocessing_context=dataloader_mp_context,
    )

    if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
        self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
    elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
        self.base_model = apex.parallel.convert_syncbn_model(self.base_model)

    # move the model and loss to the right device
    if self.use_gpu:
        self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
    else:
        self.loss.cpu()
        self.base_model.cpu()

    # initialize the PyTorch optimizer now that the model has been moved to
    # the appropriate device
    self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)

    if self.amp_args is not None:
        # Initialize apex.amp. This updates the model and the PyTorch optimizer
        # (which is wrapped by the ClassyOptimizer in self.optimizer).
        # Note that this must happen before loading the checkpoint, because there
        # is amp state to be restored.
        self.base_model, self.optimizer.optimizer = apex.amp.initialize(
            self.base_model, self.optimizer.optimizer, **self.amp_args
        )

    classy_state_dict = (
        None if self.checkpoint is None else self.checkpoint.get("classy_state_dict")
    )

    if classy_state_dict is not None:
        state_load_success = update_classy_state(self, classy_state_dict)
        assert (
            state_load_success
        ), "Update classy state from checkpoint was unsuccessful."

    self.init_distributed_data_parallel_model()
def prepare(self):
    """Prepares the task for training and populates all derived attributes."""
    self.phases = self._build_phases()
    self.train = False if self.test_only else self.train

    if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
        self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
    elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
        sync_bn_process_group = apex.parallel.create_syncbn_process_group(
            self.batch_norm_sync_group_size
        )
        self.base_model = apex.parallel.convert_syncbn_model(
            self.base_model, process_group=sync_bn_process_group
        )

    # move the model and loss to the right device
    if self.use_gpu:
        self.base_model, self.base_loss = copy_model_to_gpu(
            self.base_model, self.base_loss
        )
    else:
        self.base_loss.cpu()
        self.base_model.cpu()

    if self.optimizer is not None:
        self.prepare_optimizer(
            optimizer=self.optimizer, model=self.base_model, loss=self.base_loss
        )

    if self.amp_args is not None:
        # Initialize apex.amp. This updates the model and the PyTorch optimizer
        # (if training, which is wrapped by the ClassyOptimizer in self.optimizer).
        # Note that this must happen before loading the checkpoint, because there
        # is amp state to be restored.
        if self.optimizer is None:
            self.base_model = apex.amp.initialize(
                self.base_model, optimizers=None, **self.amp_args
            )
        else:
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

    if self.checkpoint_path:
        self.checkpoint_dict = load_and_broadcast_checkpoint(self.checkpoint_path)

    classy_state_dict = (
        None
        if self.checkpoint_dict is None
        else self.checkpoint_dict["classy_state_dict"]
    )

    if classy_state_dict is not None:
        state_load_success = update_classy_state(self, classy_state_dict)
        assert (
            state_load_success
        ), "Update classy state from checkpoint was unsuccessful."

    self.init_distributed_data_parallel_model()
def prepare(
    self,
    num_dataloader_workers=0,
    pin_memory=False,
    use_gpu=False,
    dataloader_mp_context=None,
):
    """Prepares the task for training and populates all derived attributes.

    Args:
        num_dataloader_workers: Number of dataloading processes. If 0,
            dataloading is done on the main process.
        pin_memory: If True, pin memory on the GPU.
        use_gpu: If True, load the model, optimizer, loss, etc. on the GPU.
        dataloader_mp_context: Determines how processes are spawned. Value must
            be one of None, "spawn", "fork", "forkserver". If None, the context
            is inherited from the parent process.
    """
    self.phases = self._build_phases()
    self.dataloaders = self.build_dataloaders(
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        multiprocessing_context=dataloader_mp_context,
    )

    # move the model and loss to the right device
    if use_gpu:
        self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
    else:
        self.loss.cpu()
        self.base_model.cpu()

    # initialize the PyTorch optimizer now that the model has been moved to
    # the appropriate device
    self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)

    classy_state_dict = (
        None if self.checkpoint is None else self.checkpoint.get("classy_state_dict")
    )

    if classy_state_dict is not None:
        state_load_success = update_classy_state(self, classy_state_dict)
        assert (
            state_load_success
        ), "Update classy state from checkpoint was unsuccessful."

    if self.amp_opt_level is not None:
        # Initialize apex.amp. This updates the model and the PyTorch optimizer
        # (which is wrapped by the ClassyOptimizer in self.optimizer).
        self.base_model, self.optimizer.optimizer = apex.amp.initialize(
            self.base_model, self.optimizer.optimizer, opt_level=self.amp_opt_level
        )

    self.init_distributed_data_parallel_model()
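# A minimal sketch of driving the explicit-argument variant of prepare() above
# directly, before handing the task to a trainer. The import path and the helper
# name prepare_task_for_local_training are assumptions for illustration, not
# part of this file.
import torch
from classy_vision.tasks import build_task


def prepare_task_for_local_training(config, num_workers=4):
    task = build_task(config)
    use_gpu = torch.cuda.is_available()
    # pinning memory only pays off when batches are copied to a GPU
    task.prepare(
        num_dataloader_workers=num_workers,
        pin_memory=use_gpu,
        use_gpu=use_gpu,
        dataloader_mp_context="spawn",  # or None to inherit the parent's context
    )
    return task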