Code example #1
def test_wandb_log_model(wandb, tmpdir):
    """Test that the logger creates the folders and files in the right place."""

    wandb.run = None  # no pre-existing run, so the logger will start one via wandb.init()
    model = BoringModel()

    # test log_model=True
    logger = WandbLogger(log_model=True)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.init().log_artifact.assert_called_once()

    # test log_model='all'
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model="all")
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert wandb.init().log_artifact.call_count == 2

    # test log_model=False
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model=False)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert not wandb.init().log_artifact.called

    # test correct metadata
    import pytorch_lightning.loggers.wandb as pl_wandb

    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    wandb.Artifact.reset_mock()
    logger = pl_wandb.WandbLogger(log_model=True)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.Artifact.assert_called_once_with(
        name="model-1",
        type="model",
        metadata={
            "score": None,
            "original_filename": "epoch=1-step=5-v3.ckpt",
            "ModelCheckpoint": {
                "monitor": None,
                "mode": "min",
                "save_last": None,
                "save_top_k": 1,
                "save_weights_only": False,
                "_every_n_train_steps": 0,
            },
        },
    )
Code example #2
def test_wandb_log_model(wandb, tmpdir):
    """ Test that the logger creates the folders and files in the right place. """

    wandb.run = None  # no pre-existing run, so the logger will start one via wandb.init()
    model = BoringModel()

    # test log_model=True
    logger = WandbLogger(log_model=True)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir,
                      logger=logger,
                      max_epochs=2,
                      limit_train_batches=3,
                      limit_val_batches=3)
    trainer.fit(model)
    wandb.init().log_artifact.assert_called_once()

    # test log_model='all'
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model='all')
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir,
                      logger=logger,
                      max_epochs=2,
                      limit_train_batches=3,
                      limit_val_batches=3)
    trainer.fit(model)
    assert wandb.init().log_artifact.call_count == 2

    # test log_model=False
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model=False)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir,
                      logger=logger,
                      max_epochs=2,
                      limit_train_batches=3,
                      limit_val_batches=3)
    trainer.fit(model)
    assert not wandb.init().log_artifact.called

    # test correct metadata
    import pytorch_lightning.loggers.wandb as pl_wandb
    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    wandb.Artifact.reset_mock()
    logger = pl_wandb.WandbLogger(log_model=True)
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    trainer = Trainer(default_root_dir=tmpdir,
                      logger=logger,
                      max_epochs=2,
                      limit_train_batches=3,
                      limit_val_batches=3)
    trainer.fit(model)
    wandb.Artifact.assert_called_once_with(name='model-1',
                                           type='model',
                                           metadata={
                                               'score': None,
                                               'original_filename':
                                               'epoch=1-step=5-v3.ckpt',
                                               'ModelCheckpoint': {
                                                   'monitor': None,
                                                   'mode': 'min',
                                                   'save_last': None,
                                                   'save_top_k': None,
                                                   'save_weights_only': False,
                                                   '_every_n_train_steps': 0,
                                                   '_every_n_val_epochs': 1
                                               }
                                           })
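
Both examples receive wandb as a function argument rather than importing the real library, so the module used by WandbLogger is mocked out for the test. Below is a minimal sketch of the surrounding setup, assuming the standard unittest.mock.patch approach; the patch target and the BoringModel import path are assumptions and do not appear in the examples above.

from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel  # assumed location of the test helper


@mock.patch("pytorch_lightning.loggers.wandb.wandb")  # assumed patch target
def test_wandb_log_model(wandb, tmpdir):
    # `wandb` here is the MagicMock injected by mock.patch, so calls such as
    # wandb.init().log_artifact can be asserted on without contacting the
    # real Weights & Biases service; tmpdir is the usual pytest fixture.
    ...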