def test_success(self, mock_hook):
    """Running the operator builds the ML Engine hook and sets the default version once.

    ``mock_hook`` is the patched hook class; its ``return_value`` stands in for the
    hook instance the operator creates during ``execute``.
    """
    operator = MLEngineSetDefaultVersionOperator(
        task_id="task-id",
        project_id=TEST_PROJECT_ID,
        model_name=TEST_MODEL_NAME,
        version_name=TEST_VERSION_NAME,
        gcp_conn_id=TEST_GCP_CONN_ID,
        delegate_to=TEST_DELEGATE_TO,
    )

    # No Airflow context is needed for this operator path.
    operator.execute(None)

    # The hook must be constructed exactly once with the connection settings...
    mock_hook.assert_called_once_with(delegate_to=TEST_DELEGATE_TO, gcp_conn_id=TEST_GCP_CONN_ID)

    # ...and the resulting hook instance must receive a single set_default_version call.
    hook_instance = mock_hook.return_value
    hook_instance.set_default_version.assert_called_once_with(
        project_id=TEST_PROJECT_ID,
        model_name=TEST_MODEL_NAME,
        version_name=TEST_VERSION_NAME,
    )
def test_missing_version_name(self):
    """Constructing the operator with ``version_name=None`` must raise AirflowException."""
    operator_kwargs = dict(
        task_id="task-id",
        project_id=TEST_PROJECT_ID,
        model_name=TEST_MODEL_NAME,
        version_name=None,  # the invalid value under test
        gcp_conn_id=TEST_GCP_CONN_ID,
        delegate_to=TEST_DELEGATE_TO,
    )

    # Only the constructor call itself is expected to raise.
    with pytest.raises(AirflowException):
        MLEngineSetDefaultVersionOperator(**operator_kwargs)
version={ "name": "v2", "description": "Second version", "deployment_uri": SAVED_MODEL_PATH, "runtime_version": "1.15", "machineType": "mls1-c1-m2", "framework": "TENSORFLOW", "pythonVersion": "3.7", }, ) # [END howto_operator_gcp_mlengine_create_version2] # [START howto_operator_gcp_mlengine_default_version] set_defaults_version = MLEngineSetDefaultVersionOperator( task_id="set-default-version", project_id=PROJECT_ID, model_name=MODEL_NAME, version_name="v2", ) # [END howto_operator_gcp_mlengine_default_version] # [START howto_operator_gcp_mlengine_list_versions] list_version = MLEngineListVersionsOperator( task_id="list-version", project_id=PROJECT_ID, model_name=MODEL_NAME, ) # [END howto_operator_gcp_mlengine_list_versions] # [START howto_operator_gcp_mlengine_print_versions] list_version_result = BashOperator( bash_command=f"echo {list_version.output}",