def __init_monitor_mode(self):
    # TODO: Update with MisconfigurationException when auto mode is removed in v1.3
    if self.mode not in self.mode_dict and self.mode != 'auto':
        if self.verbose > 0:
            rank_zero_warn(
                f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.',
                RuntimeWarning,
            )
        self.mode = 'auto'

    if self.mode == 'auto':
        rank_zero_warn(
            "mode='auto' is deprecated in v1.1 and will be removed in v1.3."
            " Default value for mode will be 'min' in v1.3.",
            DeprecationWarning,
        )

        if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
            self.mode = 'max'
        else:
            self.mode = 'min'

        if self.verbose > 0:
            rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
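# Illustrative usage sketch (not part of the original source): given the deprecation above, callers
# should pass `mode` explicitly instead of relying on 'auto' inference from the monitor name.
# Assumes PyTorch Lightning's public EarlyStopping/Trainer API; the metric name is hypothetical.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor="val_acc", mode="max", patience=3, verbose=True)
trainer = pl.Trainer(callbacks=[early_stop])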
def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
    if self.monitor is None or self.save_top_k == 0:
        return

    current = monitor_candidates.get(self.monitor)
    epoch = monitor_candidates.get("epoch")
    step = monitor_candidates.get("step")

    # when `val_loss` is being logged and no ModelCheckpoint is provided, `val_loss` will be
    # selected as the monitor and needs to be reduced to prevent divergence between processes.
    # TODO: Move this logic to logger_connector. This also needs to be fixed for any
    # other monitored logged value which isn't produced from a Metric.
    if self.monitor == "val_loss":
        current = trainer.training_type_plugin.reduce(current, reduce_op="mean")

    if self.check_monitor_top_k(current):
        self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
    elif self.verbose:
        rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
    """
    Called before model state is restored. Explicitly handles old model states so we can resume
    training from D2Go checkpoints transparently.

    Args:
        checkpointed_state: The raw checkpoint state as returned by torch.load or equivalent.
    """
    # If this is a non-Lightning checkpoint, we need to convert it.
    if not _is_lightning_checkpoint(checkpointed_state) and not _is_d2go_checkpoint(checkpointed_state):
        raise ValueError(f"Invalid checkpoint state with keys: {checkpointed_state.keys()}")
    if not _is_lightning_checkpoint(checkpointed_state):
        _convert_to_lightning(checkpointed_state)

    if self.ema_state:
        if "model_ema" not in checkpointed_state:
            rank_zero_info("EMA is enabled but EMA state is not found in given checkpoint")
        else:
            self.ema_state = EMAState()
            self.ema_state.load_state_dict(checkpointed_state["model_ema"])
            if not self.ema_state.device:
                # EMA state device not given, move to module device
                self.ema_state.to(self.device)
def _initialize_model_specific_parameters(self):
    task_specific_params = self.model.config.task_specific_params

    if task_specific_params:
        pars = task_specific_params.get(self.task, {})
        rank_zero_info(f"Overriding model parameters for {self.task} as defined within the model:\n {pars}")
        self.model.config.update(pars)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
    rank_zero_info("***** Validation results *****")
    metrics = trainer.callback_metrics
    # Log results
    for key in sorted(metrics):
        rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
    rank_zero_info("***** Validation results *****")
    metrics = trainer.callback_metrics
    # Log results
    for key in sorted(metrics):
        if key not in ["log", "progress_bar"]:
            rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
def resume_start(self) -> None:
    """
    Attempts to pre-load the checkpoint file into memory, with the source path determined in this priority:

    1. from HPC weights if found
    2. from `resume_from_checkpoint` file if provided
    3. don't restore

    Raises:
        FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist.
    """
    self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
    checkpoint_path = self.resume_checkpoint_path
    if not checkpoint_path:
        return

    # clear cache before restore
    torch.cuda.empty_cache()

    # Try to read the checkpoint file at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

    rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
    self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
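# Illustrative usage sketch (assumption, not from the original source): `resume_checkpoint_path` is
# normally populated from the era-appropriate `resume_from_checkpoint` Trainer argument; the path
# below is hypothetical.
import pytorch_lightning as pl

trainer = pl.Trainer(resume_from_checkpoint="checkpoints/last.ckpt")
# trainer.fit(model) then triggers resume_start(), which restores from this file unless an HPC
# checkpoint is found, in which case the HPC checkpoint takes priority.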
def _register_function(
    self,
    fn: Callable,
    name: Optional[str] = None,
    override: bool = False,
    metadata: Optional[Dict[str, Any]] = None,
):
    if not isinstance(fn, (FunctionType, partial)):
        raise MisconfigurationException(f"You can only register a function, found: {fn}")

    name = name or fn.__name__

    if self._verbose:
        rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")

    item = {"fn": fn, "name": name, "metadata": metadata or {}}

    matching_index = self._find_matching_index(item)
    if override and matching_index is not None:
        self.functions[matching_index] = item
    else:
        if matching_index is not None:
            raise MisconfigurationException(
                f"Function with name: {name} and metadata: {metadata} is already present within {self}."
                " HINT: Use `override=True`."
            )
        self.functions.append(item)
def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file by reading the file and
    restoring its states. All restored states are listed in the return value description of `dump_checkpoint`.
    """
    # Try to read the checkpoint file at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Starting from scratch.")
        return False

    # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    # acquire the model
    model = self.trainer.get_model()

    # restore model and datamodule state
    self.restore_model_state(model, checkpoint)

    if on_gpu:
        model.cuda(self.trainer.root_gpu)

    # restore training state
    self.restore_training_state(checkpoint)

    rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
    return True
def configure_optimizers(self) -> Dict:
    if self.instantiator is None:
        rank_zero_warn(
            "You haven't specified an optimizer or lr scheduler. "
            "Defaulting to AdamW with an lr of 1e-5 and linear warmup for 10% of steps. "
            "To change this, either use Hydra configs or override ``configure_optimizers`` in the Task. "
            "For more information: <todo>"
        )
        self._set_default_optimizer_scheduler()
        return super().configure_optimizers()

    self.optimizer = self.instantiator.optimizer(self.model, self.optimizer_cfg)
    # compute_warmup needs the datamodule to be available when `self.num_training_steps`
    # is called, which is why this is done here and not in `__init__`.
    self.scheduler_cfg.num_training_steps, self.scheduler_cfg.num_warmup_steps = self.compute_warmup(
        num_training_steps=self.scheduler_cfg.num_training_steps,
        num_warmup_steps=self.scheduler_cfg.num_warmup_steps,
    )
    rank_zero_info(f"Inferring number of training steps, set to {self.scheduler_cfg.num_training_steps}")
    rank_zero_info(f"Inferring number of warmup steps from ratio, set to {self.scheduler_cfg.num_warmup_steps}")
    self.scheduler = self.instantiator.scheduler(self.scheduler_cfg, self.optimizer)

    return super().configure_optimizers()
def on_train_batch_start(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Applies model transforms as specified during training."""
    apply_only_once = []
    current_step = trainer.global_step
    for i, transform in enumerate(self.transforms):
        if (transform.step is not None and transform.step <= current_step) or (
            transform.interval is not None and current_step % transform.interval == 0
        ):
            self.prepared.apply(transform.fn)
            rank_zero_info(f"[QAT] {transform.message} at step={trainer.global_step}.")
        if transform.step is not None and transform.step <= current_step:
            apply_only_once.append(i)

    if apply_only_once:
        self.transforms = [
            transform for i, transform in enumerate(self.transforms) if i not in set(apply_only_once)
        ]
def __validate_init_configuration(self):
    if self.save_top_k is not None and self.save_top_k < -1:
        raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
    if self._every_n_train_steps < 0:
        raise MisconfigurationException(
            f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
        )
    if self._every_n_val_epochs < 0:
        raise MisconfigurationException(
            f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
        )
    if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
        raise MisconfigurationException(
            f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
            f' and every_n_val_epochs={self._every_n_val_epochs}.'
            ' Both cannot be enabled at the same time.'
        )
    if self.monitor is None:
        # None: save last epoch, -1: save all epochs, 0: nothing is saved
        if self.save_top_k not in (None, -1, 0):
            raise MisconfigurationException(
                f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                ' configuration. No quantity for top_k to track.'
            )
        if self.save_last:
            rank_zero_warn(
                'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
            )
        if self.save_top_k == -1 and self.save_last:
            rank_zero_info(
                'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
                ' will duplicate the last checkpoint saved.'
            )
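# Illustrative configurations (assumption, not from the original source) showing what the checks
# above accept or reject; assumes PyTorch Lightning's public ModelCheckpoint API.
from pytorch_lightning.callbacks import ModelCheckpoint

ModelCheckpoint(monitor="val_loss", save_top_k=3)  # valid: keep the 3 best checkpoints by val_loss
ModelCheckpoint(monitor=None, save_top_k=-1)       # valid: save a checkpoint every epoch
# ModelCheckpoint(monitor=None, save_top_k=5)      # raises MisconfigurationException:
#                                                  # no quantity for top_k to track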
def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file by reading the file and
    restoring its states. All restored states are listed in the return value description of `dump_checkpoint`.
    """
    # Try to read the checkpoint file at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

    checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
        checkpoint_path, map_location=lambda storage, loc: storage
    )

    model = self.trainer.lightning_module

    if on_gpu:
        model.cuda(self.trainer.root_gpu)

    # restore training state
    self.restore_training_state(checkpoint, load_optimizer_states)

    rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
    return True
def restore_weights(self, model: LightningModule) -> None:
    """
    Attempt to restore a checkpoint (e.g. weights) in this priority:

    1. from HPC weights
    2. from `resume_from_checkpoint` file
    3. don't restore
    """
    # clear cache before restore
    if self.trainer.on_gpu:
        torch.cuda.empty_cache()

    # 1. Attempt to restore states from HPC checkpoint
    dir_path_hpc = str(self.trainer.weights_save_path)
    max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
    if max_suffix is not None:
        checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
        self.hpc_load(checkpoint_path, self.trainer.on_gpu)
        rank_zero_info(f'restored hpc model from: {checkpoint_path}')

    # 2. Attempt to restore states from `resume_from_checkpoint` file
    elif self.trainer.resume_from_checkpoint is not None:
        self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

    # wait for all to catch up
    self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')

    # clear cache after restore
    if self.trainer.on_gpu:
        torch.cuda.empty_cache()
def configure_slurm_ddp(self):
    # extract SLURM flag vars
    # whenever we have the correct number of tasks, we let SLURM manage the processes
    # otherwise we launch the required number of processes ourselves
    if self.use_ddp or self.use_ddp2:
        num_requested_gpus = self.num_gpus * self.num_nodes
        num_slurm_tasks = 0
        try:
            num_slurm_tasks = int(os.environ["SLURM_NTASKS"])
            self.is_slurm_managing_tasks = num_slurm_tasks == num_requested_gpus

            # enable SLURM on CPU
            if num_requested_gpus == 0:
                self.is_slurm_managing_tasks = num_slurm_tasks == self.num_processes

            # in interactive mode we don't manage tasks
            job_name = os.environ["SLURM_JOB_NAME"]
            if job_name == "bash":
                self.is_slurm_managing_tasks = False

        except Exception:
            # likely not on SLURM, so set the SLURM-managed flag to false
            self.is_slurm_managing_tasks = False

    # used for tests only, set this flag to simulate SLURM managing a task
    try:
        should_fake = int(os.environ["FAKE_SLURM_MANAGING_TASKS"])
        if should_fake:
            self.is_slurm_managing_tasks = True
    except Exception:
        pass

    # notify the user that SLURM is managing tasks
    if self.is_slurm_managing_tasks:
        rank_zero_info("Multi-processing is handled by Slurm.")
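# Illustrative sketch (assumption, not from the original source) of the environment variables the
# logic above consumes. SLURM normally sets the first two; the last is the test-only override.
import os

os.environ["SLURM_NTASKS"] = "8"               # must match the requested GPUs (or processes on CPU)
os.environ["SLURM_JOB_NAME"] = "my_job"        # a value of "bash" marks the session as interactive
os.environ["FAKE_SLURM_MANAGING_TASKS"] = "1"  # tests only: forces is_slurm_managing_tasks=True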
def on_epoch_end(self, trainer, pl_module):
    logs = trainer.callback_metrics
    self.epochs_since_last_check += 1

    if self.epochs_since_last_check >= self.period:
        self.epochs_since_last_check = 0

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                f"Can save best module state only with {self.monitor} available, skipping.",
                RuntimeWarning,
            )
        else:
            if isinstance(current, torch.Tensor):
                current = current.item()

            if self.check_monitor_top(current):
                self.best_module_state = deepcopy(pl_module.module.state_dict())
                self.best_module_metric_val = current

                if self.verbose:
                    rank_zero_info(
                        f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached."
                        f" Module best state updated."
                    )
def _check_and_init_precision(self) -> PrecisionPlugin:
    self._validate_precision_choice()
    if isinstance(self._precision_plugin_flag, PrecisionPlugin):
        return self._precision_plugin_flag

    if isinstance(self.accelerator, IPUAccelerator):
        return IPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, HPUAccelerator):
        return HPUPrecisionPlugin(self._precision_flag)  # type: ignore
    if isinstance(self.accelerator, TPUAccelerator):
        if self._precision_flag == 32:
            return TPUPrecisionPlugin()
        elif self._precision_flag in (16, "bf16"):
            if self._precision_flag == 16:
                rank_zero_warn(
                    "You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
                    " is not supported with TPUs. Using `precision='bf16'` instead."
                )
            return TPUBf16PrecisionPlugin()
    if isinstance(self.strategy, DeepSpeedStrategy):
        return DeepSpeedPrecisionPlugin(
            self._precision_flag, self._amp_type_flag, self._amp_level_flag  # type: ignore
        )

    if self._precision_flag == 32:
        return PrecisionPlugin()
    if self._precision_flag == 64:
        return DoublePrecisionPlugin()

    if self._precision_flag == 16 and self._accelerator_flag == "cpu":
        rank_zero_warn(
            "You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
            " Using `precision='bf16'` instead."
        )
        self._precision_flag = "bf16"

    if self._precision_flag in (16, "bf16"):
        rank_zero_info(
            f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
            if self._precision_flag == 16
            else "Using bfloat16 Automatic Mixed Precision (AMP)"
        )

        if self._amp_type_flag == AMPType.NATIVE:
            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

            if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            if isinstance(self.strategy, DDPFullyShardedStrategy):
                return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
            return NativeMixedPrecisionPlugin(self._precision_flag, device)

        if self._amp_type_flag == AMPType.APEX:
            self._amp_level_flag = self._amp_level_flag or "O2"
            return ApexMixedPrecisionPlugin(self._amp_level_flag)

    raise RuntimeError("No precision set")
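# Illustrative Trainer configurations (assumption, not from the original source) that exercise the
# branches above; assumes the PyTorch Lightning 1.6-style `accelerator`/`precision` flags.
import pytorch_lightning as pl

pl.Trainer(accelerator="gpu", devices=1, precision=16)  # NativeMixedPrecisionPlugin on "cuda" (GPU machine)
pl.Trainer(accelerator="cpu", precision="bf16")         # bfloat16 AMP on CPU
pl.Trainer(accelerator="cpu", precision=16)             # warns and falls back to precision="bf16"
pl.Trainer(precision=64)                                # DoublePrecisionPlugin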
def on_init_start(
    self, limit_train_batches, limit_val_batches, limit_test_batches, val_check_interval, overfit_batches,
    fast_dev_run
):
    self.trainer.fast_dev_run = fast_dev_run
    if self.trainer.fast_dev_run:
        limit_train_batches = 1
        limit_val_batches = 1
        limit_test_batches = 1
        self.trainer.num_sanity_val_steps = 0
        self.trainer.max_epochs = 1
        rank_zero_info(
            'Running in fast_dev_run mode: will run a full train, val and test loop using a single batch'
        )

    self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
    self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
    self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
    self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
    self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
    self.determine_data_use_amount(self.trainer.overfit_batches)
def infer_test_tta(config: Config, test_source: DataSource, transforms, experiment_name):
    experiment_root = get_my_isic2020_experiments_root() / experiment_name
    experiment_index_path = experiment_root / "index.json"
    experiment = json.load(experiment_index_path.open("r"))

    test_loader = DataLoader(
        MelanomaDataset(test_source, train=False, transforms=transforms),
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )

    sample_submission_path = get_my_isic2020_csv_root() / "sample_submission.csv"
    label = cast(pd.DataFrame, pd.read_csv(sample_submission_path))["image_name"].iloc[:len(test_source.df)]

    inferences = []
    model_type = experiment["model_type"]
    for ckpt in experiment["checkpoints"]:
        fold_index = ckpt["fold_index"]
        n_fold = ckpt["n_fold"]
        ckpt_path = experiment_root / ckpt["file"]
        rank_zero_info(f"Infer test data using {str(ckpt_path)} - {fold_index} / {n_fold}")

        model = load_from_checkpoint(model_type, ckpt_path)
        inference = Classifier(model, tta_epochs=config.tta_epochs).predict(test_loader)
        inferences.append((f"fold-{fold_index}", inference.ravel()))

    result = pd.concat([label, pd.DataFrame(dict(inferences))], axis=1)
    result.to_csv(f"cv_test_{model_type}.csv", index=False)
def _attach_model_callbacks(self) -> None:
    """Attaches the callbacks defined in the model.

    If a callback returned by the model's `configure_callbacks` method has the same type as one or several
    callbacks already present in the trainer callbacks list, it will replace them.
    In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
    will be pushed to the end of the list, ensuring they run last.
    """
    model_callbacks = self.trainer.call_hook("configure_callbacks")
    if not model_callbacks:
        return

    model_callback_types = {type(c) for c in model_callbacks}
    trainer_callback_types = {type(c) for c in self.trainer.callbacks}
    override_types = model_callback_types.intersection(trainer_callback_types)
    if override_types:
        rank_zero_info(
            "The following callbacks returned in `LightningModule.configure_callbacks` will override"
            " existing callbacks passed to Trainer:"
            f" {', '.join(sorted(t.__name__ for t in override_types))}"
        )

    # remove all callbacks with a type that occurs in model callbacks
    all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types]
    all_callbacks.extend(model_callbacks)
    all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
    # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
    self.trainer.callbacks = all_callbacks
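# Illustrative sketch (assumption, not from the original source): a LightningModule whose
# `configure_callbacks` hook returns an EarlyStopping callback; per the logic above it replaces any
# EarlyStopping instance already passed to the Trainer, while other Trainer callbacks are kept.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


class LitModel(pl.LightningModule):
    def configure_callbacks(self):
        return [EarlyStopping(monitor="val_loss", patience=5)]


# The patience=5 callback returned by the model wins over the patience=1 one given to the Trainer.
trainer = pl.Trainer(callbacks=[EarlyStopping(monitor="val_loss", patience=1)])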
def resume_end(self) -> None:
    """Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
    assert self.trainer.state.fn is not None
    if self.resume_checkpoint_path:
        if self.trainer.state.fn == TrainerFn.FITTING:
            rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
        elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
            rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")

    # TODO: remove resume_from_checkpoint_fit_path in v1.7
    if (
        self.trainer.state.fn == TrainerFn.FITTING
        and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
    ):
        self.resume_from_checkpoint_fit_path = None

    self.resume_checkpoint_path = None
    self._loaded_checkpoint = {}

    # clear cache after restore
    torch.cuda.empty_cache()

    # wait for all to catch up
    self.trainer.strategy.barrier("CheckpointConnector.resume_end")
def __validate_init_configuration(self) -> None:
    if self.save_top_k < -1:
        raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")
    if self._every_n_train_steps < 0:
        raise MisconfigurationException(
            f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
        )
    if self._every_n_epochs < 0:
        raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0")

    every_n_train_steps_triggered = self._every_n_train_steps >= 1
    every_n_epochs_triggered = self._every_n_epochs >= 1
    train_time_interval_triggered = self._train_time_interval is not None
    if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
        raise MisconfigurationException(
            f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
            f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
            "should be mutually exclusive."
        )

    if self.monitor is None:
        # -1: save all epochs, 0: nothing is saved, 1: save last epoch
        if self.save_top_k not in (-1, 0, 1):
            raise MisconfigurationException(
                f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
                " configuration. No quantity for top_k to track."
            )

        if self.save_top_k == -1 and self.save_last:
            rank_zero_info(
                "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
                " will duplicate the last checkpoint saved."
            )
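# Illustrative configurations (assumption, not from the original source) for the mutual-exclusion
# check above; assumes PyTorch Lightning's public ModelCheckpoint API.
from datetime import timedelta

from pytorch_lightning.callbacks import ModelCheckpoint

ModelCheckpoint(every_n_train_steps=500)                      # valid: step-based trigger only
ModelCheckpoint(train_time_interval=timedelta(minutes=30))    # valid: time-based trigger only
# ModelCheckpoint(every_n_train_steps=500, every_n_epochs=2)  # raises MisconfigurationException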
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
    if trainer.current_epoch == self.swa_start:
        # move average model to the requested device
        self._average_model = self._average_model.to(self._device or pl_module.device)

        optimizer = trainer.optimizers[0]
        if self._swa_lrs is None:
            self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
        if isinstance(self._swa_lrs, float):
            self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

        for lr, group in zip(self._swa_lrs, optimizer.param_groups):
            group["initial_lr"] = lr

        self._swa_scheduler = SWALR(
            optimizer,
            swa_lr=self._swa_lrs,
            anneal_epochs=self._annealing_epochs,
            anneal_strategy=self._annealing_strategy,
            last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
        )
        default_scheduler_cfg = _get_default_scheduler_config()
        assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
        default_scheduler_cfg["scheduler"] = self._swa_scheduler

        if trainer.lr_schedulers:
            scheduler_cfg = trainer.lr_schedulers[0]
            if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
                rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
            rank_zero_info(
                f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
                f" for `{self._swa_scheduler.__class__.__name__}`"
            )
            trainer.lr_schedulers[0] = default_scheduler_cfg
        else:
            trainer.lr_schedulers.append(default_scheduler_cfg)

        self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

    if self.swa_start <= trainer.current_epoch <= self.swa_end:
        self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

    # Note: No > here in case the callback is saved with the model and training continues
    if trainer.current_epoch == self.swa_end + 1:
        # Transfer weights from average model to pl_module
        self.transfer_weights(self._average_model, pl_module)

        # Reset BatchNorm for update
        self.reset_batch_norm_and_save_state(pl_module)

        # There is no need to perform either backward or optimizer.step as we are
        # performing only one pass over the train data-loader to compute activation statistics.
        # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
        trainer.num_training_batches += 1
        trainer.fit_loop._skip_backward = True
        self._accumulate_grad_batches = trainer.accumulate_grad_batches
        trainer.accumulate_grad_batches = trainer.num_training_batches
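# Illustrative usage sketch (assumption, not from the original source): enabling the callback whose
# epoch hook is shown above; assumes PyTorch Lightning's StochasticWeightAveraging API.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8, annealing_epochs=5)],
)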
def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
    if max_time is None:
        return
    if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
        rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
        return
    timer = Timer(duration=max_time, interval="step")
    self.trainer.callbacks.append(timer)
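# Illustrative usage sketch (assumption, not from the original source): the two paths handled above,
# assuming PyTorch Lightning's Trainer(max_time=...) argument and Timer callback.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Timer

pl.Trainer(max_time={"hours": 12})                    # connector appends a step-interval Timer
pl.Trainer(max_time="00:12:00:00",                    # max_time is ignored here because an
           callbacks=[Timer(duration={"hours": 1})])  # explicit Timer is already present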
def print_weight_summary(self):
    string = 'Summary Teacher:\n'
    for j in range(len(self.models)):
        weight_sum = 0
        for param in self.models[j].parameters():
            weight_sum += param[0].sum()
        string += f' Teacher Level {j}: WeightSum == {weight_sum}\n'
    rank_zero_info(string)
def setup(self):
    rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

    if not XLA_AVAILABLE:
        raise MisconfigurationException('No TPU devices found.')

    # COLAB_GPU is an env var available by default in Colab environments.
    self.start_method = 'fork' if self.trainer.on_colab_kaggle else 'spawn'
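# Illustrative usage sketch (assumption, not from the original source): this accelerator setup runs
# when TPU cores are requested with the era-appropriate flag on a host with torch_xla installed.
import pytorch_lightning as pl

trainer = pl.Trainer(tpu_cores=8)  # raises MisconfigurationException if no XLA/TPU devices are found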
def on_validation_end(self, trainer: pl.Trainer, pl_module):
    save_json(pl_module.metrics, pl_module.metrics_save_path)
    rank_zero_info("***** Validation results *****")
    metrics = trainer.callback_metrics
    # Log results
    for key in sorted(metrics):
        if key not in ["log", "progress_bar", "preds"]:
            rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
def on_init_start(
    self,
    limit_train_batches,
    limit_val_batches,
    limit_test_batches,
    limit_predict_batches,
    val_check_interval,
    overfit_batches,
    fast_dev_run,
):
    if not isinstance(fast_dev_run, (bool, int)):
        raise MisconfigurationException(
            f'fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0'
        )

    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
        raise MisconfigurationException(
            f'fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0.'
        )

    self.trainer.fast_dev_run = fast_dev_run
    fast_dev_run = int(fast_dev_run)

    # set fast_dev_run=True when it is 1, used while logging
    if fast_dev_run == 1:
        self.trainer.fast_dev_run = True

    if fast_dev_run:
        limit_train_batches = fast_dev_run
        limit_val_batches = fast_dev_run
        limit_test_batches = fast_dev_run
        limit_predict_batches = fast_dev_run
        self.trainer.fit_loop.max_steps = fast_dev_run
        self.trainer.num_sanity_val_steps = 0
        self.trainer.fit_loop.max_epochs = 1
        val_check_interval = 1.0
        self.trainer.check_val_every_n_epoch = 1
        self.trainer.logger = DummyLogger()

        rank_zero_info(
            'Running in fast_dev_run mode: will run a full train,'
            f' val, test and prediction loop using {fast_dev_run} batch(es).'
        )

    self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
    self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
    self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
    self.trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, 'limit_predict_batches')
    self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
    self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
    self.determine_data_use_amount(self.trainer.overfit_batches)
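# Illustrative usage sketch (assumption, not from the original source): the `fast_dev_run` values
# handled above, assuming PyTorch Lightning's public Trainer flag.
import pytorch_lightning as pl

pl.Trainer(fast_dev_run=True)   # 1 batch per loop, 1 epoch, logger swapped for a DummyLogger
pl.Trainer(fast_dev_run=7)      # same, but 7 batches per loop
# pl.Trainer(fast_dev_run=-1)   # raises MisconfigurationException: must be a bool or an int >= 0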
def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
    current = metrics.get(self.monitor)
    epoch = metrics.get("epoch")
    step = metrics.get("step")

    if self.check_monitor_top_k(current):
        self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
    elif self.verbose:
        rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
def _load_backbone(self, repo_or_dir: str, model_name: str, model_path: Optional[str] = None):
    if model_path:
        return BarlowTwins.load_from_checkpoint(model_path).backbone

    rank_zero_info("No model provided, loading torch hub pre-trained model")
    return torch.hub.load(repo_or_dir, model_name)
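# Illustrative call sketch (assumption, not from the original source): torch.hub.load resolves a
# GitHub "owner/repo" string plus an entrypoint name; the repo and entrypoint below are examples,
# not necessarily the ones used by the surrounding project.
import torch

backbone = torch.hub.load("pytorch/vision", "resnet50", pretrained=False)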