def test_mlflow_logger():
    """Verify that basic functionality of the MLflow logger works.

    Builds an ``MLFlowLogger`` backed by a ``file://`` tracking store in an
    ``mlruns`` directory next to this file. It logs the hyperparameters,
    saves, and runs a one-epoch training on 1% of the training data. The
    ``mlruns`` directory is always moved aside afterwards (renamed with a
    suffix popped from ``RANDOM_FILE_PATHS``), even if training fails, so
    later tests do not see a leftover store.
    """
    reset_seed()

    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        # The MLflow logger cannot be imported, so the test just returns.
        # A runner will report it as passing rather than skipped.
        # TODO(review): consider pytest.skip / pytest.importorskip here.
        return

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(root_dir, "mlruns")

    try:
        logger = MLFlowLogger("test", f"file://{mlflow_dir}")
        logger.log_hyperparams(hparams)
        logger.save()

        trainer_options = dict(
            max_nb_epochs=1,
            train_percent_check=0.01,
            logger=logger,
        )
        trainer = Trainer(**trainer_options)
        result = trainer.fit(model)

        assert result == 1, "Training failed"
    finally:
        # Move the tracking store aside even when training or the assertion
        # fails, so a stale ``mlruns`` dir cannot leak into later tests.
        # The existence check keeps an earlier failure (before the store was
        # written) from being masked by a FileNotFoundError raised here.
        if os.path.exists(mlflow_dir):
            n = RANDOM_FILE_PATHS.pop()
            shutil.move(mlflow_dir, mlflow_dir + f'_{n}')
def test_mlflow_pickle():
    """Verify that a Trainer using the MLflow logger survives pickling.

    Builds an ``MLFlowLogger`` backed by a ``file://`` tracking store in an
    ``mlruns`` directory next to this file, and attaches it to a ``Trainer``.
    It round-trips the trainer through ``pickle.dumps``/``pickle.loads``,
    then logs a metric through the unpickled trainer's logger. The ``mlruns``
    directory is always moved aside afterwards (renamed with a suffix popped
    from ``RANDOM_FILE_PATHS``), even on failure.
    """
    reset_seed()

    try:
        from pytorch_lightning.logging import MLFlowLogger
    except ModuleNotFoundError:
        # The MLflow logger cannot be imported, so the test just returns.
        # A runner will report it as passing rather than skipped.
        # TODO(review): consider pytest.skip / pytest.importorskip here.
        return

    hparams = get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    mlflow_dir = os.path.join(root_dir, "mlruns")

    try:
        logger = MLFlowLogger("test", f"file://{mlflow_dir}")
        logger.log_hyperparams(hparams)
        logger.save()

        trainer_options = dict(max_nb_epochs=1, logger=logger)
        trainer = Trainer(**trainer_options)

        # Round-trip through pickle. The bytes are produced in this test, so
        # unpickling them is safe. The unpickled logger must still work.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)
        trainer2.logger.log_metrics({"acc": 1.0})
    finally:
        # Move the tracking store aside even when pickling or logging fails,
        # so a stale ``mlruns`` dir cannot leak into later tests. The
        # existence check keeps an earlier failure (before the store was
        # written) from being masked by a FileNotFoundError raised here.
        if os.path.exists(mlflow_dir):
            n = RANDOM_FILE_PATHS.pop()
            shutil.move(mlflow_dir, mlflow_dir + f'_{n}')