def test_wandb_log_model(wandb, tmpdir):
    """Test that ``WandbLogger`` logs checkpoints as W&B artifacts according to ``log_model``.

    Scenarios covered, each fitting ``BoringModel`` for 2 epochs of 3 train/val batches:

    * ``log_model=True``  -> ``log_artifact`` is called exactly once.
    * ``log_model="all"`` -> ``log_artifact`` is called twice.
    * ``log_model=False`` -> ``log_artifact`` is never called.
    * ``log_model=True`` with the wandb>=0.10.22 flag forced on -> ``wandb.Artifact`` is built
      with the expected name, type and checkpoint metadata.

    Args:
        wandb: mock standing in for the ``wandb`` module (presumably injected by a
            ``mock.patch`` decorator outside this view — TODO confirm).
        tmpdir: pytest temporary directory used as every trainer's ``default_root_dir``.
    """
    from unittest import mock

    import pytorch_lightning.loggers.wandb as pl_wandb

    wandb.run = None
    model = BoringModel()

    def fit_with(logger):
        # Give the mocked experiment a fixed id/project so the artifact name is deterministic.
        logger.experiment.id = "1"
        logger.experiment.project_name.return_value = "project"
        trainer = Trainer(
            default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3
        )
        trainer.fit(model)

    def reset_artifact_mocks():
        # Order matters: calling ``wandb.init()`` records a call on ``wandb.init`` itself,
        # so reset ``log_artifact`` first and clear ``wandb.init`` afterwards.
        wandb.init().log_artifact.reset_mock()
        wandb.init.reset_mock()

    # log_model=True
    fit_with(WandbLogger(log_model=True))
    wandb.init().log_artifact.assert_called_once()

    # log_model='all'
    reset_artifact_mocks()
    fit_with(WandbLogger(log_model="all"))
    assert wandb.init().log_artifact.call_count == 2

    # log_model=False
    reset_artifact_mocks()
    fit_with(WandbLogger(log_model=False))
    assert not wandb.init().log_artifact.called

    # Correct metadata. The version flag is patched only for this run and restored on exit
    # (even on failure), so it does not leak into other tests.
    with mock.patch.object(pl_wandb, "_WANDB_GREATER_EQUAL_0_10_22", True):
        reset_artifact_mocks()
        wandb.Artifact.reset_mock()
        fit_with(pl_wandb.WandbLogger(log_model=True))
        # The "-v3" suffix presumably comes from all four runs sharing ``tmpdir``.
        wandb.Artifact.assert_called_once_with(
            name="model-1",
            type="model",
            metadata={
                "score": None,
                "original_filename": "epoch=1-step=5-v3.ckpt",
                "ModelCheckpoint": {
                    "monitor": None,
                    "mode": "min",
                    "save_last": None,
                    "save_top_k": 1,
                    "save_weights_only": False,
                    "_every_n_train_steps": 0,
                },
            },
        )
def test_wandb_log_model(wandb, tmpdir):
    """
    Test that ``WandbLogger`` logs checkpoints as W&B artifacts according to ``log_model``.

    Scenarios covered, each fitting ``BoringModel`` for 2 epochs of 3 train/val batches:

    * ``log_model=True``  -> ``log_artifact`` is called exactly once.
    * ``log_model='all'`` -> ``log_artifact`` is called twice.
    * ``log_model=False`` -> ``log_artifact`` is never called.
    * ``log_model=True`` with the wandb>=0.10.22 flag forced on -> ``wandb.Artifact`` is built
      with the expected name, type and checkpoint metadata.

    Args:
        wandb: mock standing in for the ``wandb`` module (presumably injected by a
            ``mock.patch`` decorator outside this view — TODO confirm).
        tmpdir: pytest temporary directory used as every trainer's ``default_root_dir``.
    """
    # NOTE(review): this file defines ``test_wandb_log_model`` twice with different expected
    # ModelCheckpoint metadata; only one definition is collected by pytest. Confirm which
    # expectation matches the installed Lightning version and drop or rename the other.
    from unittest import mock

    import pytorch_lightning.loggers.wandb as pl_wandb

    wandb.run = None
    model = BoringModel()

    def fit_with(logger):
        # Give the mocked experiment a fixed id/project so the artifact name is deterministic.
        logger.experiment.id = '1'
        logger.experiment.project_name.return_value = 'project'
        trainer = Trainer(
            default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3
        )
        trainer.fit(model)

    def reset_artifact_mocks():
        # Order matters: calling ``wandb.init()`` records a call on ``wandb.init`` itself,
        # so reset ``log_artifact`` first and clear ``wandb.init`` afterwards.
        wandb.init().log_artifact.reset_mock()
        wandb.init.reset_mock()

    # log_model=True
    fit_with(WandbLogger(log_model=True))
    wandb.init().log_artifact.assert_called_once()

    # log_model='all'
    reset_artifact_mocks()
    fit_with(WandbLogger(log_model='all'))
    assert wandb.init().log_artifact.call_count == 2

    # log_model=False
    reset_artifact_mocks()
    fit_with(WandbLogger(log_model=False))
    assert not wandb.init().log_artifact.called

    # Correct metadata. The version flag is patched only for this run and restored on exit
    # (even on failure), so it does not leak into other tests.
    with mock.patch.object(pl_wandb, '_WANDB_GREATER_EQUAL_0_10_22', True):
        reset_artifact_mocks()
        wandb.Artifact.reset_mock()
        fit_with(pl_wandb.WandbLogger(log_model=True))
        # The "-v3" suffix presumably comes from all four runs sharing ``tmpdir``.
        wandb.Artifact.assert_called_once_with(
            name='model-1',
            type='model',
            metadata={
                'score': None,
                'original_filename': 'epoch=1-step=5-v3.ckpt',
                'ModelCheckpoint': {
                    'monitor': None,
                    'mode': 'min',
                    'save_last': None,
                    'save_top_k': None,
                    'save_weights_only': False,
                    '_every_n_train_steps': 0,
                    '_every_n_val_epochs': 1,
                },
            },
        )