def test_fail(self, mock_hook):
    """An unknown ``operation`` must make ``execute`` raise ``ValueError``.

    Besides the exception, the test checks that neither model operation on the
    mocked hook was invoked. A rejected operation must not issue a create or
    get request as a side effect.
    """
    task = MLEngineManageModelOperator(
        task_id="task-id",
        project_id=TEST_PROJECT_ID,
        model=TEST_MODEL,
        operation="invalid",
        gcp_conn_id=TEST_GCP_CONN_ID,
        delegate_to=TEST_DELEGATE_TO,
    )

    with self.assertRaises(ValueError):
        task.execute(None)

    # "invalid" matches neither supported operation, so no model call may
    # have been dispatched to the hook before the error surfaced.
    mock_hook.return_value.create_model.assert_not_called()
    mock_hook.return_value.get_model.assert_not_called()
    def test_success_create_model(self, mock_hook):
        """``operation="create"`` forwards the model body to ``create_model``.

        Checks three things:

        * the hook is built with the operator's connection/delegation settings;
        * ``create_model`` receives the project id and the model;
        * ``execute`` hands back whatever ``create_model`` returned.

        The last check is the same contract the ``get`` test asserts for
        ``get_model``.
        """
        task = MLEngineManageModelOperator(task_id="task-id",
                                           project_id=TEST_PROJECT_ID,
                                           model=TEST_MODEL,
                                           operation="create",
                                           gcp_conn_id=TEST_GCP_CONN_ID,
                                           delegate_to=TEST_DELEGATE_TO)

        result = task.execute(None)

        mock_hook.assert_called_once_with(delegate_to=TEST_DELEGATE_TO,
                                          gcp_conn_id=TEST_GCP_CONN_ID)
        mock_hook.return_value.create_model.assert_called_once_with(
            project_id=TEST_PROJECT_ID, model=TEST_MODEL)
        # Pin the return value so that a regression which drops the result of
        # create_model (instead of returning it) is caught.
        self.assertEqual(mock_hook.return_value.create_model.return_value, result)
    def test_success_get_model(self, mock_hook):
        """``operation="get"`` looks the model up by name through the hook.

        Checks three things:

        * the hook is built from the operator's connection/delegation settings;
        * ``get_model`` is asked for ``TEST_MODEL_NAME`` in ``TEST_PROJECT_ID``;
        * the value ``get_model`` produces is what ``execute`` returns.
        """
        hook_instance = mock_hook.return_value
        operator = MLEngineManageModelOperator(
            task_id="task-id",
            project_id=TEST_PROJECT_ID,
            model=TEST_MODEL,
            operation="get",
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
        )

        returned = operator.execute(None)

        mock_hook.assert_called_once_with(
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
        )
        hook_instance.get_model.assert_called_once_with(
            model_name=TEST_MODEL_NAME,
            project_id=TEST_PROJECT_ID,
        )
        self.assertEqual(returned, hook_instance.get_model.return_value)
예제 #4
0
    def test_success_get_model(self, mock_hook):
        """``operation="get"`` fetches the model while honouring impersonation.

        The connection, delegation and impersonation settings given to the
        operator must be forwarded unchanged to the hook constructor.
        ``execute`` must return what the hook's ``get_model`` returned.
        """
        hook_kwargs = {
            "delegate_to": TEST_DELEGATE_TO,
            "gcp_conn_id": TEST_GCP_CONN_ID,
            "impersonation_chain": TEST_IMPERSONATION_CHAIN,
        }
        task = MLEngineManageModelOperator(
            task_id="task-id",
            project_id=TEST_PROJECT_ID,
            model=TEST_MODEL,
            operation="get",
            **hook_kwargs
        )

        result = task.execute(None)

        # The exact settings handed to the operator must reach the hook.
        mock_hook.assert_called_once_with(**hook_kwargs)
        get_model = mock_hook.return_value.get_model
        get_model.assert_called_once_with(
            project_id=TEST_PROJECT_ID,
            model_name=TEST_MODEL_NAME,
        )
        assert result == get_model.return_value