def prepare(self, pin_memory: bool = False):
    """
    Prepares the task:
    - dataloaders
    - model
    - copy model to correct device
    - meters
    - loss
    - optimizer
    - LR schedulers
    - AMP state
    - resume from a checkpoint if available

    Args:
        pin_memory: forwarded as-is to ``self.build_dataloaders``.

    Returns:
        Whatever ``self._update_classy_state`` returns when called with the
        ``"classy_state_dict"`` entry of the loaded checkpoint (``None`` when
        no checkpoint path is configured or the entry is absent).

    Raises:
        ValueError: if ``self.checkpoint_path`` is set but the checkpoint
            loader returns ``None``.
    """
    self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
    self.phases = self._build_phases()
    # Count the train phases right after building them; the value is only
    # published on the task once the remaining components are built.
    num_train_phases = sum(1 for phase in self.phases if phase["train"])

    self.base_model = self._build_model()
    self._set_ddp_options()
    self.base_loss = self._build_loss()
    self.meters = self._build_meters()
    self.optimizer = self._build_optimizer()
    self.optimizer_schedulers = self._build_optimizer_schedulers()
    self.num_train_phases = num_train_phases

    self.base_loss = self.base_loss.to(self.device)
    if self.device.type == "cuda":
        self.base_model = copy_model_to_gpu(self.base_model)

    # initialize the pytorch optimizer now since the model has been moved to
    # the appropriate device.
    self.prepare_optimizer()

    # Enable mixed precision grad scalers
    if self.amp_type == AmpType.APEX:
        # Allow Apex Amp to perform casts as specified by the amp_args.
        # This updates the model and the PyTorch optimizer (which is wrapped
        # by the ClassyOptimizer in self.optimizer).
        # NOTE: this must happen before loading the checkpoint. See
        # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
        self.base_model, self.optimizer.optimizer = apex.amp.initialize(
            self.base_model, self.optimizer.optimizer, **self.amp_args
        )

    # Restore a hypothetical checkpoint
    vissl_state_dict = None
    if self.checkpoint_path is not None:
        self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
            checkpoint_folder=self.checkpoint_folder,
            checkpoint_path=self.checkpoint_path,
            device=torch.device("cpu"),
        )
        # Fail with an explicit message instead of a TypeError from
        # subscripting None when nothing could be loaded.
        if self.checkpoint is None:
            raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")

        self.iteration = self.checkpoint["iteration"]
        self.local_iteration_num = self.checkpoint["iteration_num"]
        vissl_state_dict = self.checkpoint.get("classy_state_dict")
        if "loss" in self.checkpoint:
            self.base_loss.load_state_dict(self.checkpoint["loss"])
            logging.info("======Loaded loss state from checkpoint======")

    return self._update_classy_state(vissl_state_dict)
def prepare(self, pin_memory: bool = False):
    """
    Set up every component the task needs, in this order:
    - phases and model (with DDP options)
    - meters, optimizer and LR schedulers
    - model copied to GPU when the device is CUDA
    - PyTorch optimizer and Apex AMP state
    - optional EMA model
    - checkpoint restore, if a checkpoint path is configured
    - datasets and dataloaders for the phase to resume from
    - loss (built, moved to device, optionally restored)

    Args:
        pin_memory: forwarded as-is to ``self.build_dataloaders``.

    Returns:
        The result of ``self._update_classy_state`` called with the
        ``"classy_state_dict"`` entry of the checkpoint (``None`` when no
        checkpoint path is configured or the entry is absent).

    Raises:
        ValueError: if ``self.checkpoint_path`` is set but loading the
            checkpoint yields ``None``.
    """
    self.phases = self._build_phases()
    self.num_phases = len(self.phases)
    self.base_model = self._build_model()
    self._set_ddp_options()
    self.meters = self._build_meters()
    self.optimizer = self._build_optimizer()
    self.optimizer_schedulers = self._build_optimizer_schedulers()

    on_cuda = self.device.type == "cuda"
    if on_cuda:
        self.base_model = copy_model_to_gpu(self.base_model)

    # The model now lives on its final device, so the underlying PyTorch
    # optimizer can be initialized.
    self.prepare_optimizer()

    # Mixed precision setup.
    if self.amp_type == AmpType.APEX:
        # Apex Amp applies the casts requested through amp_args to both the
        # model and the PyTorch optimizer wrapped by self.optimizer.
        # NOTE: keep this ahead of the checkpoint restore below, see
        # https://nvidia.github.io/apex/amp.html#checkpointing
        self.base_model, self.optimizer.optimizer = apex.amp.initialize(
            self.base_model, self.optimizer.optimizer, **self.amp_args
        )

    # An EMA copy of the model is only created when one of the EMA hook
    # options asks for it.
    ema_settings = self.config["HOOKS"]["EMA_MODEL"]
    wants_ema = ema_settings["ENABLE_EMA_METERS"] or ema_settings["SAVE_EMA_MODEL"]
    if wants_ema:
        self._create_ema_model()

    # Restore from a checkpoint when a path is configured.
    restored_state = None
    if self.checkpoint_path is not None:
        self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
            checkpoint_folder=self.checkpoint_folder,
            checkpoint_path=self.checkpoint_path,
            device=torch.device("cpu"),
        )
        if self.checkpoint is None:
            raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")
        self.iteration = self.checkpoint["iteration"]
        self.local_iteration_num = self.checkpoint["iteration_num"]
        restored_state = self.checkpoint.get("classy_state_dict")

    # Resume at the phase after the one recorded in the restored state, or
    # start from phase 0 when there is no restored state.
    if restored_state:
        resume_phase_idx = restored_state["train_phase_idx"] + 1
    else:
        resume_phase_idx = 0

    self.datasets, self.data_and_label_keys = self.build_datasets(resume_phase_idx)

    # The train dataset state has to be set before the dataloaders are
    # built so that they capture the checkpoint info.
    if restored_state and "train" in self.datasets:
        train_iterator_state = restored_state.get("train_dataset_iterator")
        self.datasets["train"].set_classy_state(train_iterator_state)

    self.dataloaders = self.build_dataloaders(
        pin_memory=pin_memory,
        current_train_phase_idx=resume_phase_idx,
    )

    # Loss: build, move to the task device, then restore its state if the
    # checkpoint carries one.
    self.base_loss = self._build_loss()
    self.base_loss = self.base_loss.to(self.device)
    if self.checkpoint and "loss" in self.checkpoint:
        self.base_loss.load_state_dict(self.checkpoint["loss"])
        logging.info("======Loaded loss state from checkpoint======")

    return self._update_classy_state(restored_state)