Example #1
    def get_logger(self, cfg: DictConfig,
                   save_dir: Path) -> pl_loggers.WandbLogger:
        """Returns the Weights and Biases (wandb) logger object (really an wandb Run object)
        The run object corresponds to a single execution of the script and is returned from `wandb.init()`.

        Args:
            run_id: Unique run id. If run id exists, will continue logging to that run.
            cfg: The entire config got from hydra, for purposes of logging the config of each run in wandb.
            save_dir: Root dir to save wandb log files

        Returns:
            wandb.wandb_sdk.wandb_run.Run: wandb run object. Can be used for logging.
        """
        # Some argument names to wandb are different from the attribute names of the class.
        # Pop the offending attributes before passing to init func.
        args_dict = asdict_filtered(self)
        run_name = args_dict.pop("run_name")
        run_id = args_dict.pop("run_id")

        # If `self.save_hyperparameters()` is called in the LightningModule, it will save the cfg
        # passed as an argument, hence the line below is left commented out.
        # cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        wb_logger = pl_loggers.WandbLogger(name=run_name,
                                           id=run_id,
                                           save_dir=str(save_dir),
                                           **args_dict)

        return wb_logger
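Note: every example in this listing splats `**asdict_filtered(self)` into a constructor, but the helper itself is not shown. A minimal sketch of what it might look like, assuming it turns a config dataclass into a dict of kwargs and drops bookkeeping fields (the `remove_keys` default below is an assumption, not taken from the source):

from dataclasses import asdict


def asdict_filtered(obj, remove_keys=("name",)) -> dict:
    """Hypothetical helper: convert a config dataclass to a dict of kwargs,
    dropping fields (e.g. a 'name' used only to select the config group)
    that the target constructor would not accept."""
    args = asdict(obj)
    for key in remove_keys:
        args.pop(key, None)
    return args

With a helper like this, each config class only needs to declare fields whose names match the target constructor's keyword arguments, and the `get_*` factories stay one-liners.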
Example #2
    def get_trainer(self, pl_logger: LightningLoggerBase,
                    callbacks: List[Callback],
                    default_root_dir: str) -> pl.Trainer:
        trainer = pl.Trainer(
            logger=pl_logger,
            callbacks=callbacks,
            default_root_dir=default_root_dir,
            **asdict_filtered(self),
        )
        return trainer
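For context, the `self` in these factories is presumably a small config dataclass whose fields map one-to-one onto constructor keyword arguments. A hypothetical sketch of what such classes could look like for Examples #1 and #2; all field names and defaults here are illustrative and assume a PyTorch Lightning 1.x `Trainer`:

from dataclasses import dataclass
from typing import Optional


@dataclass
class WandbConf:
    # 'name' is assumed to be stripped by asdict_filtered(); 'run_name' and
    # 'run_id' are popped inside get_logger(); the rest are forwarded to
    # WandbLogger (and on to wandb.init()) as **kwargs.
    name: str = "wandb"
    run_name: Optional[str] = None
    run_id: Optional[str] = None
    project: str = "my-project"
    entity: Optional[str] = None


@dataclass
class TrainerConf:
    # Every field except 'name' ends up as a pl.Trainer keyword argument.
    name: str = "trainer"
    max_epochs: int = 100
    gpus: int = 1
    precision: int = 16
    log_every_n_steps: int = 50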
Example #3
    def get_optimizer(self, model_params) -> torch.optim.Optimizer:
        return torch.optim.SGD(params=model_params, **asdict_filtered(self))
Example #4
    def get_scheduler(
            self, optimizer: Optimizer
    ) -> torch.optim.lr_scheduler.ReduceLROnPlateau:
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **asdict_filtered(self))
Example #5
    def get_scheduler(self,
                      optimizer: Optimizer) -> torch.optim.lr_scheduler.StepLR:
        return torch.optim.lr_scheduler.StepLR(optimizer,
                                               **asdict_filtered(self))
Example #6
    def get_scheduler(
            self, optimizer: Optimizer) -> torch.optim.lr_scheduler.CyclicLR:
        return torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                 cycle_momentum=False,
                                                 **asdict_filtered(self))
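The optimizer and scheduler factories above are presumably consumed from a LightningModule's `configure_optimizers` hook. A hedged sketch of that hook, assuming the module is handed one optimizer config and one scheduler config (the attribute names and the monitored metric are illustrative):

import pytorch_lightning as pl


class SegModule(pl.LightningModule):
    """Hypothetical module: opt_conf / sched_conf are config objects exposing
    the get_optimizer() / get_scheduler() factories shown above."""

    def __init__(self, opt_conf, sched_conf):
        super().__init__()
        self.opt_conf = opt_conf
        self.sched_conf = sched_conf

    def configure_optimizers(self):
        optimizer = self.opt_conf.get_optimizer(self.parameters())
        scheduler = self.sched_conf.get_scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # 'monitor' is required when the scheduler is ReduceLROnPlateau.
                "monitor": "val_loss",  # illustrative metric name
            },
        }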
Example #7
    def get_model(self) -> torch.nn.Module:
        return DeepLab(**asdict_filtered(self))
Example #8
    def get_callback(self) -> Callback:
        return LearningRateMonitor(**asdict_filtered(self))
Example #9
    def get_callback(self, exp_dir: str, cfg: DictConfig) -> Callback:
        return LogMedia(exp_dir=exp_dir, cfg=cfg, **asdict_filtered(self))
Example #10
    def get_callback(self, logs_dir) -> Callback:
        return ModelCheckpoint(dirpath=logs_dir, **asdict_filtered(self))
Example #11
    def get_callback(self) -> Callback:
        return EarlyStopping(**asdict_filtered(self))
Example #12
    def get_datamodule(self) -> LaPaDataModule:
        return LaPaDataModule(**asdict_filtered(self))
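Putting the pieces together, a rough sketch of how these factories might be wired into one training run. Only the `get_*` methods come from the examples above; the `confs` container, its attribute names, and the directory layout are assumptions made for illustration:

from pathlib import Path
from typing import List

import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import Callback


def train(cfg: DictConfig, exp_dir: Path, confs,
          model: pl.LightningModule) -> None:
    # Each get_callback() returns a ready-to-use Lightning Callback.
    callbacks: List[Callback] = [
        confs.lr_monitor.get_callback(),
        confs.log_media.get_callback(str(exp_dir), cfg),
        confs.checkpoint.get_callback(str(exp_dir / "checkpoints")),
        confs.early_stop.get_callback(),
    ]

    # Data, logger and trainer all follow the same config-factory pattern.
    datamodule = confs.dataset.get_datamodule()
    pl_logger = confs.logger.get_logger(cfg, exp_dir)
    trainer = confs.trainer.get_trainer(pl_logger, callbacks, str(exp_dir))

    trainer.fit(model, datamodule=datamodule)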