def test_configure_optimizers(self, model: pl.LightningModule):
    r"""Tests that ``model.configure_optimizers()`` runs and returns the required outputs.

    Accepted return values are either a single ``torch.optim.Optimizer`` or a
    2-tuple ``([optimizers...], [lr_schedulers...])`` of two lists.

    Args:
        model: The Lightning module under test. It is passed through
            ``expose_distributed_model`` before ``configure_optimizers`` is called.

    Returns:
        Whatever ``configure_optimizers()`` returned, so callers can reuse it.
    """
    model = expose_distributed_model(model)
    optim = model.configure_optimizers()

    # torch>=2.0 made ``LRScheduler`` the public base class and the built-in
    # schedulers (StepLR, CosineAnnealingLR, ...) subclass it directly, so
    # ``isinstance(x, _LRScheduler)`` is False for them. Older torch versions
    # only have ``_LRScheduler``.
    scheduler_base = getattr(
        torch.optim.lr_scheduler,
        "LRScheduler",
        torch.optim.lr_scheduler._LRScheduler,
    )

    is_optimizer = isinstance(optim, torch.optim.Optimizer)
    is_optim_schedule_tuple = (
        isinstance(optim, tuple)
        and len(optim) == 2
        and isinstance(optim[0], list)
        and all(isinstance(x, torch.optim.Optimizer) for x in optim[0])
        and isinstance(optim[1], list)
        # The schedulers are the *second* element of the tuple.
        and all(isinstance(x, scheduler_base) for x in optim[1])
    )
    assert is_optimizer or is_optim_schedule_tuple, (
        f"unexpected configure_optimizers() return value: {optim!r}"
    )
    return optim
def get_optimizers_only(module: pl.LightningModule) -> IterOptimizer:
    """Return only the optimizers from ``module.configure_optimizers()``.

    ``configure_optimizers`` is called once. Any learning-rate schedulers in
    its result are dropped. The supported return shapes are:

    * ``None`` -> ``[]``
    * a single ``Optimizer`` -> ``[optimizer]``
    * a dict with an ``"optimizer"`` key -> ``[d["optimizer"]]``
    * a list/tuple of such dicts -> ``[d["optimizer"] for d in ...]``
    * a list/tuple of optimizers -> returned unchanged
    * ``(optimizers, schedulers)`` (two sequences) -> ``optimizers``
    * ``(optimizer, scheduler)`` -> ``[optimizer]``

    Args:
        module: Lightning module whose ``configure_optimizers`` is called.

    Returns:
        The optimizers, with any schedulers removed.
    """
    optimizers = module.configure_optimizers()
    # Lightning allows ``None`` to mean "train without an optimizer".
    if optimizers is None:
        return []
    if isinstance(optimizers, Optimizer):
        return [optimizers]
    # Dict form ({"optimizer": ..., "lr_scheduler": ...}). Positional indexing
    # below would raise KeyError on a dict, so handle it first.
    if isinstance(optimizers, dict):
        return [optimizers["optimizer"]]
    if optimizers and all(isinstance(item, dict) for item in optimizers):
        return [item["optimizer"] for item in optimizers]
    if len(optimizers) > 1:
        if isinstance(optimizers[1], Optimizer):
            # Plain sequence of optimizers.
            return optimizers
        if isinstance(optimizers[0], Optimizer):
            # (optimizer, scheduler) pair.
            return [optimizers[0]]
        # (optimizers, schedulers) pair of sequences.
        return optimizers[0]
    # Single-element sequence holding one optimizer.
    assert isinstance(optimizers[0], Optimizer)
    return optimizers
def _wandb_is_set(value) -> bool:
    """Return False for the "unset" markers used in run configs: None, False and "None"."""
    return not (value is None or value is False or value == "None")


def _wandb_first_optimizer(optimizers):
    """Best-effort pick of the first optimizer from a ``configure_optimizers()`` result.

    Descends into dicts with an ``"optimizer"`` key and into non-empty
    lists/tuples (e.g. ``([opt], [sched])`` or ``(opt, sched)``). Anything
    else is returned unchanged.
    """
    if isinstance(optimizers, dict) and "optimizer" in optimizers:
        return _wandb_first_optimizer(optimizers["optimizer"])
    if isinstance(optimizers, (list, tuple)) and optimizers:
        return _wandb_first_optimizer(optimizers[0])
    return optimizers


def _wandb_dataset_size(datamodule, attr: str) -> int:
    """Return ``len(datamodule.<attr>)``, or 0 when the attribute is missing or None."""
    dataset = getattr(datamodule, attr, None)
    return len(dataset) if dataset is not None else 0


def init_wandb_logger(project_config: dict, run_config: dict, lit_model: pl.LightningModule, datamodule: pl.LightningDataModule, log_path: str = "logs/") -> pl.loggers.WandbLogger:
    """Initialize Weights&Biases logger.

    Side effects: sets ``WANDB_RESUME=must`` in ``os.environ``, creates
    ``log_path`` if it does not exist, calls ``watch`` on the model and logs
    several hyperparameter groups.

    Args:
        project_config: Must contain ``["loggers"]["wandb"]`` with ``project``,
            ``entity``, ``log_model`` and ``offline`` keys.
        run_config: Must contain ``trainer``, ``model`` and ``dataset`` entries
            (each logged as hyperparameters). Optional ``wandb`` section
            (``group``, ``job_type``, ``tags``, ``notes``) and optional
            ``resume_training`` section (``resume_from_checkpoint``,
            ``wandb_run_id``).
        lit_model: The Lightning module. If it has a ``model`` attribute, that
            inner object is watched and its class name is logged.
        datamodule: Sizes of ``data_train`` / ``data_val`` / ``data_test`` are
            logged (0 when absent or None).
        log_path: Directory used as the logger's ``save_dir``.

    Returns:
        The configured ``WandbLogger``.
    """
    # With this line wandb will throw an error if the run to be resumed does
    # not exist yet, instead of auto-creating a new run.
    os.environ["WANDB_RESUME"] = "must"

    wandb_cfg = project_config["loggers"]["wandb"]
    # ``or {}`` also tolerates a section that is present but empty (None).
    run_wandb_cfg = run_config.get("wandb") or {}
    resume_cfg = run_config.get("resume_training") or {}
    resume_from_checkpoint = resume_cfg.get("resume_from_checkpoint", None)
    wandb_run_id = resume_cfg.get("wandb_run_id", None)
    # Resume the wandb run only if both a checkpoint and a run id were set.
    resume_run = _wandb_is_set(resume_from_checkpoint) and _wandb_is_set(wandb_run_id)

    os.makedirs(log_path, exist_ok=True)

    wandb_logger = WandbLogger(
        project=wandb_cfg["project"],
        entity=wandb_cfg["entity"],
        log_model=wandb_cfg["log_model"],
        offline=wandb_cfg["offline"],
        group=run_wandb_cfg.get("group", None),
        job_type=run_wandb_cfg.get("job_type", "train"),
        tags=run_wandb_cfg.get("tags", []),
        notes=run_wandb_cfg.get("notes", ""),
        id=wandb_run_id if resume_run else None,
        save_dir=log_path,
        save_code=False,
    )

    # Watch (and report) the wrapped network when the LightningModule has one.
    inner_model = lit_model.model if hasattr(lit_model, "model") else lit_model
    wandb_logger.watch(inner_model, log=None)

    wandb_logger.log_hyperparams({
        "model": inner_model.__class__.__name__,
        # configure_optimizers() may return tuples/lists/dicts; log the class of
        # the (first) optimizer rather than e.g. "tuple".
        "optimizer": _wandb_first_optimizer(lit_model.configure_optimizers()).__class__.__name__,
        "train_size": _wandb_dataset_size(datamodule, "data_train"),
        "val_size": _wandb_dataset_size(datamodule, "data_val"),
        "test_size": _wandb_dataset_size(datamodule, "data_test"),
    })
    wandb_logger.log_hyperparams(run_config["trainer"])
    wandb_logger.log_hyperparams(run_config["model"])
    wandb_logger.log_hyperparams(run_config["dataset"])
    return wandb_logger