Ejemplo n.º 1
0
 def test_configure_optimizers(self, model: pl.LightningModule):
     r"""Tests that ``model.configure_optimizers()`` runs and returns the required
     outputs.

     Two return shapes are accepted:

     * a single ``torch.optim.Optimizer``; or
     * a 2-tuple ``([optimizers...], [schedulers...])`` where every element of
       the first list is an optimizer and every element of the second list is
       a learning-rate scheduler.

     Returns:
         Whatever ``configure_optimizers()`` returned, so callers can reuse it.
     """
     model = expose_distributed_model(model)
     optim = model.configure_optimizers()

     # torch>=2.0 made ``LRScheduler`` the public base class and built-in
     # schedulers no longer subclass the legacy ``_LRScheduler``; fall back to
     # ``_LRScheduler`` on older torch versions that lack ``LRScheduler``.
     scheduler_base = getattr(
         torch.optim.lr_scheduler,
         "LRScheduler",
         torch.optim.lr_scheduler._LRScheduler,
     )

     is_optimizer = isinstance(optim, torch.optim.Optimizer)
     is_optim_schedule_tuple = (
         isinstance(optim, tuple)
         and len(optim) == 2
         and isinstance(optim[0], list)
         and all(isinstance(x, torch.optim.Optimizer) for x in optim[0])
         and isinstance(optim[1], list)
         # the schedulers live in the second list, not the first
         and all(isinstance(x, scheduler_base) for x in optim[1])
     )
     assert is_optimizer or is_optim_schedule_tuple, (
         f"unexpected configure_optimizers() return value: {optim!r}"
     )
     return optim
def get_optimizers_only(module: pl.LightningModule) -> IterOptimizer:
    """Return the optimizers from ``module.configure_optimizers()``, dropping schedulers.

    Supported return shapes of ``configure_optimizers``:

    * a single ``Optimizer`` -> ``[optimizer]``
    * ``None`` or an empty sequence -> ``[]``
    * a dict with an ``"optimizer"`` key -> ``[dict["optimizer"]]``
    * a non-empty sequence whose items are all such dicts -> their
      ``"optimizer"`` values, in order
    * a sequence whose second item is an ``Optimizer`` (a sequence of
      optimizers) -> the sequence itself, unchanged
    * a 2+ element sequence whose second item is not an ``Optimizer``
      (e.g. ``(optimizers, schedulers)`` or ``(optimizer, scheduler)``)
      -> its first item, wrapped in a list if it is a single ``Optimizer``
    * a one-element sequence holding an ``Optimizer`` -> unchanged

    Raises:
        TypeError: if a one-element sequence does not hold an ``Optimizer``.
        KeyError: if a dict form lacks the ``"optimizer"`` key.
    """
    optimizers = module.configure_optimizers()

    if optimizers is None:
        # no optimizer configured
        return []
    if isinstance(optimizers, Optimizer):
        return [optimizers]
    if isinstance(optimizers, dict):
        # {"optimizer": ..., "lr_scheduler": ...}
        return [optimizers["optimizer"]]
    if len(optimizers) == 0:
        return []
    if all(isinstance(entry, dict) for entry in optimizers):
        # sequence of {"optimizer": ..., ...} dicts
        return [entry["optimizer"] for entry in optimizers]

    if len(optimizers) > 1:
        # either list of optimizers or list of optimizers + schedulers
        if isinstance(optimizers[1], Optimizer):
            return optimizers
        if isinstance(optimizers[0], Optimizer):
            return [optimizers[0]]
        return optimizers[0]

    # list with single optimizer
    if not isinstance(optimizers[0], Optimizer):
        raise TypeError(
            "configure_optimizers() returned a one-element sequence whose item "
            f"is not an Optimizer: {type(optimizers[0]).__name__}"
        )
    return optimizers
Ejemplo n.º 3
0
def init_wandb_logger(project_config: dict,
                      run_config: dict,
                      lit_model: pl.LightningModule,
                      datamodule: pl.LightningDataModule,
                      log_path: str = "logs/") -> pl.loggers.WandbLogger:
    """Initialize Weights&Biases logger.

    Builds a ``WandbLogger`` from the ``loggers.wandb`` section of
    ``project_config`` and the optional ``wandb`` / ``resume_training``
    sections of ``run_config``, calls ``watch`` on the model, and logs:

    * model and optimizer class names,
    * dataset split sizes,
    * the ``trainer``, ``model`` and ``dataset`` sections of ``run_config``.

    Args:
        project_config: Project-level config. Must contain
            ``["loggers"]["wandb"]`` with ``project``, ``entity``,
            ``log_model`` and ``offline`` keys.
        run_config: Run-level config. Must contain ``trainer``, ``model`` and
            ``dataset`` sections. ``wandb`` and ``resume_training`` sections
            are optional and may be present but empty (``None``).
        lit_model: The LightningModule. If it has a ``model`` attribute, that
            inner object is watched and its class name is logged instead.
        datamodule: Datamodule whose optional ``data_train`` / ``data_val`` /
            ``data_test`` attributes give the logged split sizes (0 when an
            attribute is missing or ``None``).
        log_path: Passed to the logger as ``save_dir``; created if missing.

    Returns:
        The configured ``WandbLogger``.
    """
    # with this line wandb will throw an error if the run to be resumed does not exist yet
    # instead of auto-creating a new run
    # NOTE(review): this mutates the process-wide environment and is set even
    # when no run is being resumed - confirm that is intended.
    os.environ["WANDB_RESUME"] = "must"

    # ``or {}`` also covers sections that are present but empty (None),
    # which ``.get(key, {})`` would return as None and then crash on.
    resume_cfg = run_config.get("resume_training") or {}
    wandb_cfg = run_config.get("wandb") or {}
    logger_cfg = project_config["loggers"]["wandb"]

    resume_from_checkpoint = resume_cfg.get("resume_from_checkpoint", None)
    wandb_run_id = resume_cfg.get("wandb_run_id", None)

    def _is_set(value) -> bool:
        # Configs may spell "unset" as None, False or the string "None".
        return value is not None and value is not False and value != "None"

    # resume run only if ckpt was set in the run config
    run_id = (wandb_run_id
              if _is_set(resume_from_checkpoint) and _is_set(wandb_run_id)
              else None)

    # idempotent, and avoids the exists()/makedirs() race
    os.makedirs(log_path, exist_ok=True)

    wandb_logger = WandbLogger(
        project=logger_cfg["project"],
        entity=logger_cfg["entity"],
        log_model=logger_cfg["log_model"],
        offline=logger_cfg["offline"],
        group=wandb_cfg.get("group", None),
        job_type=wandb_cfg.get("job_type", "train"),
        tags=wandb_cfg.get("tags", []),
        notes=wandb_cfg.get("notes", ""),
        id=run_id,
        save_dir=log_path,
        save_code=False)

    # Watch (and name) the wrapped network when the LightningModule exposes
    # one via ``.model``; otherwise use the LightningModule itself.
    watched = getattr(lit_model, "model", lit_model)
    wandb_logger.watch(watched, log=None)

    def _split_size(attr_name: str) -> int:
        # 0 when the datamodule lacks the attribute or it is None
        split = getattr(datamodule, attr_name, None)
        return len(split) if split is not None else 0

    wandb_logger.log_hyperparams({
        "model": watched.__class__.__name__,
        # NOTE: calls configure_optimizers() again only to read the class name
        # of its return value; for tuple/list/dict returns this logs the
        # container type (e.g. "tuple"), not the optimizer class.
        "optimizer": lit_model.configure_optimizers().__class__.__name__,
        "train_size": _split_size("data_train"),
        "val_size": _split_size("data_val"),
        "test_size": _split_size("data_test"),
    })
    wandb_logger.log_hyperparams(run_config["trainer"])
    wandb_logger.log_hyperparams(run_config["model"])
    wandb_logger.log_hyperparams(run_config["dataset"])

    return wandb_logger