def test_success(self, mock_hook):
        """Run the operator and verify how the hook was constructed and queried.

        ``mock_hook`` is presumably injected by a patch decorator outside this
        view; the test relies only on it recording the calls made to it.
        """
        operator_kwargs = {
            "task_id": "task-id",
            "project_id": TEST_PROJECT_ID,
            "model_name": TEST_MODEL_NAME,
            "gcp_conn_id": TEST_GCP_CONN_ID,
            "delegate_to": TEST_DELEGATE_TO,
        }
        operator = MLEngineListVersionsOperator(**operator_kwargs)

        operator.execute(None)

        # The hook must be built exactly once, with the connection settings
        # that were handed to the operator.
        mock_hook.assert_called_once_with(
            gcp_conn_id=TEST_GCP_CONN_ID,
            delegate_to=TEST_DELEGATE_TO,
        )

        # The listing call must target the same project and model.
        hook_instance = mock_hook.return_value
        hook_instance.list_versions.assert_called_once_with(
            model_name=TEST_MODEL_NAME,
            project_id=TEST_PROJECT_ID,
        )
Пример #2
0
 def test_missing_model_name(self):
     """Constructing the operator with ``model_name=None`` must raise AirflowException."""
     invalid_kwargs = dict(
         task_id="task-id",
         project_id=TEST_PROJECT_ID,
         model_name=None,  # the deliberately missing value under test
         gcp_conn_id=TEST_GCP_CONN_ID,
         delegate_to=TEST_DELEGATE_TO,
     )

     # Only the constructor call sits inside the raises-block, so nothing
     # else can satisfy the expectation by accident.
     with pytest.raises(AirflowException):
         MLEngineListVersionsOperator(**invalid_kwargs)
Пример #3
0
    )
    # [END howto_operator_gcp_mlengine_create_version2]

    # Task built from MLEngineSetDefaultVersionOperator, targeting version "v2"
    # of MODEL_NAME in PROJECT_ID.
    # NOTE(review): the [START]/[END] markers look like documentation-extraction
    # anchors -- keep them verbatim and keep extra commentary outside the pair.
    # [START howto_operator_gcp_mlengine_default_version]
    set_defaults_version = MLEngineSetDefaultVersionOperator(
        task_id="set-default-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        version_name="v2",
    )
    # [END howto_operator_gcp_mlengine_default_version]

    # Task built from MLEngineListVersionsOperator for MODEL_NAME in PROJECT_ID.
    # No connection arguments are passed here, so the operator's own defaults apply.
    # [START howto_operator_gcp_mlengine_list_versions]
    list_version = MLEngineListVersionsOperator(
        task_id="list-version",
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
    )
    # [END howto_operator_gcp_mlengine_list_versions]

    # Bash task that echoes whatever `list_version.output` formats to.
    # NOTE(review): the f-string is evaluated when this file is parsed, not when
    # the task runs; presumably `.output` renders to an XCom template string
    # that is resolved at run time -- confirm against the Airflow version in use.
    # [START howto_operator_gcp_mlengine_print_versions]
    list_version_result = BashOperator(
        bash_command=f"echo {list_version.output}",
        task_id="list-version-result",
    )
    # [END howto_operator_gcp_mlengine_print_versions]

    # [START howto_operator_gcp_mlengine_get_prediction]
    prediction = MLEngineStartBatchPredictionJobOperator(
        task_id="prediction",
        project_id=PROJECT_ID,