def test_success(self, mock_hook):
    """Run the operator against a mocked hook and verify calls and result.

    ``mock_hook`` is injected by a ``mock.patch`` decorator outside this view
    (presumably patching ``MLEngineHook`` — confirm). The test checks that:

    * the hook is constructed once with the operator's connection settings,
    * ``list_versions`` is called once with the project and model name, and
    * ``execute`` returns the hook's ``list_versions`` result unchanged.
    """
    task = MLEngineListVersionsOperator(
        task_id="task-id",
        project_id=TEST_PROJECT_ID,
        model_name=TEST_MODEL_NAME,
        gcp_conn_id=TEST_GCP_CONN_ID,
        delegate_to=TEST_DELEGATE_TO,
    )

    result = task.execute(None)

    mock_hook.assert_called_once_with(delegate_to=TEST_DELEGATE_TO, gcp_conn_id=TEST_GCP_CONN_ID)
    mock_hook.return_value.list_versions.assert_called_once_with(
        project_id=TEST_PROJECT_ID,
        model_name=TEST_MODEL_NAME,
    )
    # The return value is what Airflow pushes to XCom. The example DAG reads it
    # via ``list_version.output``, so pin that it is passed through as-is.
    assert result == mock_hook.return_value.list_versions.return_value
def test_missing_model_name(self):
    """Building the operator with ``model_name=None`` must raise ``AirflowException``."""
    # Every argument is valid except the model name, so the expected error
    # can only come from that missing value.
    operator_kwargs = dict(
        task_id="task-id",
        project_id=TEST_PROJECT_ID,
        model_name=None,
        gcp_conn_id=TEST_GCP_CONN_ID,
        delegate_to=TEST_DELEGATE_TO,
    )
    with pytest.raises(AirflowException):
        MLEngineListVersionsOperator(**operator_kwargs)
) # [END howto_operator_gcp_mlengine_create_version2] # [START howto_operator_gcp_mlengine_default_version] set_defaults_version = MLEngineSetDefaultVersionOperator( task_id="set-default-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v2", ) # [END howto_operator_gcp_mlengine_default_version] # [START howto_operator_gcp_mlengine_list_versions] list_version = MLEngineListVersionsOperator( task_id="list-version", project_id=PROJECT_ID, model_name=MODEL_NAME, ) # [END howto_operator_gcp_mlengine_list_versions] # [START howto_operator_gcp_mlengine_print_versions] list_version_result = BashOperator( bash_command=f"echo {list_version.output}", task_id="list-version-result", ) # [END howto_operator_gcp_mlengine_print_versions] # [START howto_operator_gcp_mlengine_get_prediction] prediction = MLEngineStartBatchPredictionJobOperator( task_id="prediction", project_id=PROJECT_ID,