def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
    """Logging a metric whose name MLFlow cannot accept must emit a RuntimeWarning.

    The brackets in ``"[some_metric]"`` are the offending characters.
    """
    mlf_logger = MLFlowLogger("test", save_dir=tmpdir)
    with pytest.warns(RuntimeWarning, match="special characters in metric name"):
        mlf_logger.log_metrics({"[some_metric]": 10})
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
    """Check that MLFlowLogger forwards params/metrics to the mlflow client correctly.

    Also checks that, when no experiment named ``"test"`` exists yet, the
    logger creates one with the requested artifact location.
    """
    # NOTE(review): another test with this exact name is defined later in this
    # file; if both live in one module, the later definition shadows this one
    # and pytest only collects that one.

    # Patched clock returns 1; the expected timestamp of 1000 below is
    # presumably that value converted to milliseconds by the logger.
    time.return_value = 1

    mlf_logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location")
    # Pretend the experiment does not exist so the logger has to create it.
    mlf_logger._mlflow_client.get_experiment_by_name.return_value = None

    mlf_logger.log_hyperparams({"test": "test_param"})
    mlf_logger.experiment.log_param.assert_called_once_with(mlf_logger.run_id, "test", "test_param")

    mlf_logger.log_metrics({"some_metric": 10})
    mlf_logger.experiment.log_metric.assert_called_once_with(mlf_logger.run_id, "some_metric", 10, 1000, None)

    mlf_logger._mlflow_client.create_experiment.assert_called_once_with(
        name="test",
        artifact_location="my_artifact_location",
    )
def test_mlflow_logger(tmpdir):
    """Smoke test: MLFlowLogger can be created, reused, and drive a short training run."""
    tutils.reset_seed()
    model = LightningTestModel(tutils.get_default_hparams())

    store_dir = os.path.join(tmpdir, 'mlruns')
    uri = f'file:{os.sep * 2}{store_dir}'
    logger = MLFlowLogger('test', tracking_uri=uri)

    # A second logger for an experiment that already exists in the same store
    # must still be able to obtain a run id.
    second_logger = MLFlowLogger('test', tracking_uri=uri)
    _ = second_logger.run_id

    # Logging a non-numeric metric value must not raise.
    logger.log_metrics({'acc': 'test'})

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )
    assert trainer.fit(model) == 1, 'Training failed'
def _make_callbacks(trial):
    """Return a fresh ``[EarlyStopping, PyTorchLightningPruningCallback]`` list for one fit.

    A new list is built for every ``Trainer`` on purpose. ``EarlyStopping``
    with ``patience`` tracks its best score and wait counter across epochs, so
    one instance shared by several cross-validation fits would let a later
    split be stopped based on an earlier split's history.
    """
    early_stopper = EarlyStopping(
        monitor='val_loss', min_delta=0.005, patience=3, mode='min')
    pruner = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    return [early_stopper, pruner]


def _apply_params(args, params):
    """Copy every sampled hyperparameter in *params* onto the *args* namespace."""
    for key, value in params.items():
        args.__dict__[str(key)] = value


def objective(trial, args):
    """Optuna objective: train ``LitLSTM`` with hyperparameters sampled from *trial*.

    Two validation modes are supported, selected by ``args.val_mode``:

    * any value containing ``'split'``: cross-validation. The last character
      of ``args.val_mode`` is parsed as the number of splits. Each split is
      trained and tested, and the mean ``val_loss`` is logged as
      ``mean_val_loss`` to a per-dataset MLflow experiment.
    * ``'full'``: a single train/test run. Predictions and targets are saved
      to disk.

    Args:
        trial: Optuna trial used to sample hyperparameters and to drive pruning.
        args: Namespace of run configuration. Sampled hyperparameters are
            written onto it in place.

    Returns:
        In split mode, the mean of ``model.metrics['val_loss']`` over all
        splits (via ``torch.stack(...).mean()``). In full mode,
        ``model.metrics['val_loss']`` of the single run.

    Raises:
        ValueError: if ``args.val_mode`` is neither ``'full'`` nor contains
            ``'split'``. Previously the function silently returned ``None``.
    """
    params = get_trial_params(trial)
    # These entries are sampled as exponents; convert them to powers of two.
    params['hidden_size'] = 2**params['hidden_size']
    params['acc_grads'] = 2**params['acc_grads']
    if args.model_type == 'attnlstm':
        params['attn_width'] = trial.suggest_int("attn_width", 3, 64)

    if 'split' in args.val_mode:
        dataset_hour = args.data.split('_')[-1]
        experiment_name = f'Optuna_{dataset_hour}h_{args.val_mode[-1]}_split'
        logger = MLFlowLogger(experiment_name=experiment_name)
        print(experiment_name)

        n_splits = int(args.val_mode[-1])
        val_losses = []
        for split_id in range(n_splits):
            print(f"Split {split_id} Trial {trial.number}")
            # NOTE(review): split_id is reset to 0 before the model is built and
            # its hyperparameters are logged, and only set to the real split
            # right before `_get_data`. Presumably this keeps the logged params
            # identical across splits; confirm intent.
            args.__dict__["split_id"] = 0
            _apply_params(args, params)
            model = LitLSTM(args)
            # NOTE(review): every split reports epochs to the same trial via
            # the pruning callback, so step numbers repeat across splits.
            # Confirm pruning behaves as intended under CV.
            trainer = Trainer(
                logger=logger,
                callbacks=_make_callbacks(trial),
                **get_trainer_params(args),
            )
            logger.log_hyperparams(model.args)
            args.__dict__["split_id"] = split_id
            model._get_data(args, data_mode='init')
            trainer.fit(model)
            trainer.test(model, test_dataloaders=model.test_dataloader())
            val_losses.append(model.metrics['val_loss'])

        # Log the mean validation loss for later retrieval of the best model.
        mean_val_loss = torch.stack(val_losses).mean()
        logger.log_metrics({"mean_val_loss": mean_val_loss}, step=0)
        logger.finalize()
        return mean_val_loss

    if args.val_mode == 'full':
        logger = MLFlowLogger(experiment_name='Optuna_full')
        _apply_params(args, params)
        model = LitLSTM(args)
        trainer = Trainer(
            logger=logger,
            callbacks=_make_callbacks(trial),
            **get_trainer_params(args),
        )
        logger.log_hyperparams(model.args)
        trainer.fit(model)
        trainer.test(model, test_dataloaders=model.test_dataloader())
        model.save_preds_and_targets(to_disk=True)
        logger.finalize()
        return model.metrics['val_loss']

    raise ValueError(
        f"Unsupported val_mode {args.val_mode!r}: expected 'full' or a mode containing 'split'"
    )
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
    """Verify that MLFlowLogger routes params, metrics and experiment creation to mlflow.

    The experiment lookup is mocked to return ``None``, so the logger must
    create the ``'test'`` experiment with the configured artifact location.
    """
    # NOTE(review): an earlier test in this file has this exact name; if both
    # live in one module, this definition shadows the earlier one.

    # Patched clock returns 1; the expected timestamp of 1000 below is
    # presumably that value converted to milliseconds by the logger.
    time.return_value = 1

    mlf = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location')
    mlf._mlflow_client.get_experiment_by_name.return_value = None

    mlf.log_hyperparams({'test': 'test_param'})
    mlf.experiment.log_param.assert_called_once_with(mlf.run_id, 'test', 'test_param')

    mlf.log_metrics({'some_metric': 10})
    mlf.experiment.log_metric.assert_called_once_with(mlf.run_id, 'some_metric', 10, 1000, None)

    mlf._mlflow_client.create_experiment.assert_called_once_with(
        name='test',
        artifact_location='my_artifact_location',
    )