Ejemplo n.º 1
0
    def test_whenLogConfigParamsAConfigDictWithSequence_givenAMLFlowCallback_thenLogParams(self):
        """Logging a config that contains a sequence value emits the expected client calls.

        The fixture ``self.settings_in_dictconfig_with_sequence`` is passed to
        ``MLFlowLogger.log_config_params``. The expected call list is derived from the
        same fixture by ``self._populate_calls_from_dict``.
        """
        config = self.settings_in_dictconfig_with_sequence
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            # ``return_value`` is the client instance the logger builds from the patched class.
            client_instance = client_class_mock.return_value
            client_instance.create_experiment = self.experiment_mock
            client_instance.create_run = self.run_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)
            logger_under_test.log_config_params(config)

            client_class_mock.assert_has_calls(self._populate_calls_from_dict(config))
Ejemplo n.º 2
0
    def test_whenOnTestFailure_givenAMLFlowCallback_thenHasFailureTerminatedStatus(self):
        """A direct call to ``_status_handling`` terminates the run with status ``FAILED``.

        No ``on_*`` callback is invoked beforehand. The patched client is expected to
        receive ``set_terminated(run_id, status="FAILED")``.
        """
        # pylint: disable=protected-access
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            client_instance.create_run = self.run_mock
            client_instance.create_experiment = self.experiment_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)
            logger_under_test._status_handling()

            expected_termination = call().set_terminated(self.a_run_id, status="FAILED")
            client_class_mock.assert_has_calls([expected_termination])
Ejemplo n.º 3
0
    def test_whenLogConfigParamsASimpleDict_givenAMLFlowCallback_thenLogParams(self):
        """Logging a flat config dict issues one ``log_param`` per key/value pair.

        Each entry of ``self.settings_in_dict`` is expected as
        ``log_param(run_id=..., key=..., value=...)`` on the patched client, in the dict's
        iteration order.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
            ml_flow_client_patch.return_value.create_experiment = self.experiment_mock
            ml_flow_client_patch.return_value.create_run = self.run_mock

            mlflow_logger = MLFlowLogger(self.a_experiment_name)
            mlflow_logger.log_config_params(self.settings_in_dict)

            # ``call()`` targets the client instance, i.e. the patched class's return value.
            ml_flow_client_calls = [
                call().log_param(run_id=self.a_run_id, key=key, value=value)
                for key, value in self.settings_in_dict.items()
            ]
            ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)
Ejemplo n.º 4
0
    def test_whenLogMetric_givenAMLFlowCallback_thenLogMetric(self):
        """Each ``log_metric(key, value)`` call is forwarded to the client with ``step=None``.

        Every entry of ``self.a_log`` is logged through ``MLFlowLogger.log_metric``. The
        patched client must then have received the matching ``log_metric`` calls, in order.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            client_instance.create_experiment = self.experiment_mock
            client_instance.create_run = self.run_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)

            logged_items = list(self.a_log.items())
            for metric_name, metric_value in logged_items:
                logger_under_test.log_metric(metric_name, metric_value)

            expected = [
                call().log_metric(run_id=self.a_run_id, key=metric_name, value=metric_value, step=None)
                for metric_name, metric_value in logged_items
            ]
            client_class_mock.assert_has_calls(expected)
Ejemplo n.º 5
0
    def test_whenOnTestSuccess_givenAMLFlowCallback_thenHasSuccessTerminatedStatus(self):
        """A train-then-test sequence ends the run with status ``FINISHED``.

        The logger is driven through ``on_train_end``, ``on_test_begin`` and
        ``on_test_end``. It must then call ``set_terminated(run_id, status="FINISHED")``
        on the client.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            client_instance.create_experiment = self.experiment_mock
            client_instance.create_run = self.run_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)
            logger_under_test.set_params({"epochs": self.epochs})
            logger_under_test.on_train_end(self.a_log)
            # The status is changed when testing starts, so on_test_begin must run first.
            logger_under_test.on_test_begin({})
            logger_under_test.on_test_end(self.a_log)

            expected_termination = call().set_terminated(self.a_run_id, status="FINISHED")
            client_class_mock.assert_has_calls([expected_termination])
Ejemplo n.º 6
0
    def test_whenNewExperiment_givenAMLFlowInstantiation_thenCreateNewExperiment(self):
        """Constructing the logger calls ``create_experiment(name, tracking_uri)`` on the client.

        The tracking URI is expected to be ``self.none_tracking_uri``. Only
        ``create_experiment`` is replaced here; ``create_run`` stays the default mock.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_class_mock.return_value.create_experiment = self.experiment_mock

            MLFlowLogger(self.a_experiment_name)

            expected_creation = call().create_experiment(self.a_experiment_name, self.none_tracking_uri)
            client_class_mock.assert_has_calls([expected_creation])
Ejemplo n.º 7
0
    def test_whenGitRepo_givenAMLFlowInstantiation_thenLogGitCommit(self, get_git_commit_patch):
        """Constructing the logger queries the git commit and records it as a run tag.

        ``get_git_commit_patch`` is presumably injected by a ``@patch`` decorator that is
        not visible in this excerpt (TODO confirm). It must be called with
        ``self.the_working_directory``. The client must then receive
        ``set_tag(run_id, mlflow_default_git_commit_tag, a_git_commit)``.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_class_mock.return_value.create_run = self.run_mock
            MLFlowLogger(self.a_experiment_name)

            get_git_commit_patch.assert_has_calls([call(self.the_working_directory)])

            expected_tag = call().set_tag(self.a_run_id, mlflow_default_git_commit_tag, a_git_commit)
            client_class_mock.assert_has_calls([expected_tag])
Ejemplo n.º 8
0
    def test_whenExperimentAlreadyCreated_givenAMLFlowInstantiation_thenGetExperiment(self):
        """Falls back to ``get_experiment_by_name`` when ``create_experiment`` raises.

        ``create_experiment`` is made to raise ``MlflowException``. The test expects
        ``create_experiment(name, tracking_uri)`` followed by
        ``get_experiment_by_name(name)`` on the client.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            failing_creation = MagicMock(side_effect=MlflowException(self.a_exception_message))
            client_instance.create_experiment = failing_creation
            client_instance.get_experiment_by_name = MagicMock(return_value=self.experiment_mock)

            MLFlowLogger(self.a_experiment_name)

            expected_sequence = [
                call().create_experiment(self.a_experiment_name, self.none_tracking_uri),
                call().get_experiment_by_name(self.a_experiment_name),
            ]
            client_class_mock.assert_has_calls(expected_sequence)
Ejemplo n.º 9
0
    def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenLogLastEpochNumber(self):
        """``on_train_end`` logs the configured epoch count as the ``last-epoch`` metric.

        After ``set_params({"epochs": self.epochs})`` and ``on_train_end``, the client must
        receive ``log_metric(run_id=..., key="last-epoch", value=self.epochs, step=None)``.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            client_instance.create_experiment = self.experiment_mock
            client_instance.create_run = self.run_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)
            logger_under_test.set_params({"epochs": self.epochs})
            logger_under_test.on_train_end(self.a_log)

            expected_metric = call().log_metric(
                run_id=self.a_run_id, key="last-epoch", value=self.epochs, step=None
            )
            client_class_mock.assert_has_calls([expected_metric])
Ejemplo n.º 10
0
    def test_whenCorrectSettings_givenAMLFlowInstantiation_thenMLflowClientIsProperlySet(self):
        """Construction creates the experiment and the run, and exposes their ids.

        Expects ``create_experiment(name, tracking_uri)`` then
        ``create_run(experiment_id=...)`` on the client. The logger's ``experiment_id`` and
        ``run_id`` attributes must equal the fixture ids.
        """
        with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_class_mock:
            client_instance = client_class_mock.return_value
            client_instance.create_experiment = self.experiment_mock
            client_instance.create_run = self.run_mock

            logger_under_test = MLFlowLogger(self.a_experiment_name)

            client_class_mock.assert_has_calls(
                [
                    call().create_experiment(self.a_experiment_name, self.none_tracking_uri),
                    call().create_run(experiment_id=self.a_experiment_id),
                ]
            )

            self.assertEqual(self.a_experiment_id, logger_under_test.experiment_id)
            self.assertEqual(self.a_run_id, logger_under_test.run_id)