def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenHasSuccessTerminatedStatus(self):
    """Ending training should terminate the MLflow run with status "FINISHED"."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as mlflow_client_class:
        # Calls made on the instance returned by MlflowClient() are recorded on the
        # class mock as ``call().<method>(...)``, which the assertion below relies on.
        client_instance = mlflow_client_class.return_value
        client_instance.create_experiment = self.experiment_mock
        client_instance.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger.set_params({"epochs": self.epochs})
        logger.on_train_end(self.a_log)

        expected_calls = [call().set_terminated(self.a_run_id, status="FINISHED")]
        mlflow_client_class.assert_has_calls(expected_calls)
def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenLogLastEpochNumber(self):
    """Ending training should log the configured epoch count as the 'last-epoch' metric."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as mlflow_client_class:
        # Stub out experiment/run creation on the client instance the logger obtains.
        client_instance = mlflow_client_class.return_value
        client_instance.create_experiment = self.experiment_mock
        client_instance.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger.set_params({"epochs": self.epochs})
        logger.on_train_end(self.a_log)

        # The metric is logged without a step, keyed to the run created above.
        last_epoch_call = call().log_metric(
            run_id=self.a_run_id,
            key='last-epoch',
            value=self.epochs,
            step=None,
        )
        mlflow_client_class.assert_has_calls([last_epoch_call])