def test_mlflow_logger():
    """
    verify that basic functionality of mlflow logger works
    """
    reset_seed()

    # Skip the test silently if mlflow / MLFlowLogger is not available.
    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        return

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(root_dir, "mlruns")

    # Log to a local file-based MLflow tracking directory.
    logger = MLFlowLogger(experiment_name="test", tracking_uri=f"file://{mlflow_dir}")
    logger.log_hyperparams(hparams)
    logger.save()

    trainer_options = dict(max_nb_epochs=1,
                           train_percent_check=0.01,
                           logger=logger)

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, "Training failed"

    # Clean up: move the mlruns directory aside so it does not interfere with other tests.
    n = RANDOM_FILE_PATHS.pop()
    shutil.move(mlflow_dir, mlflow_dir + f'_{n}')


def test_mlflow_pickle():
    """
    verify that pickling trainer with mlflow logger works
    """
    reset_seed()

    # Skip the test silently if mlflow / MLFlowLogger is not available.
    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        return

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(root_dir, "mlruns")

    # Log to a local file-based MLflow tracking directory.
    logger = MLFlowLogger(experiment_name="test", tracking_uri=f"file://{mlflow_dir}")
    logger.log_hyperparams(hparams)
    logger.save()

    trainer_options = dict(max_nb_epochs=1, logger=logger)

    trainer = Trainer(**trainer_options)
    # Round-trip the trainer through pickle and check that the restored logger still logs.
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})

    # Clean up: move the mlruns directory aside so it does not interfere with other tests.
    n = RANDOM_FILE_PATHS.pop()
    shutil.move(mlflow_dir, mlflow_dir + f'_{n}')