class RemoteDataContext(_OutputDataContext):
    # Maps each supported cloud provider to its data proxy implementation.
    _CLOUD_PROVIDER_TO_PROXIES = {
        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(),
        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(),
    }

    def __init__(self, cloud_provider=None):
        """
        :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum
        """
        cloud_provider = cloud_provider or _platform_config.CLOUD_PROVIDER.get()
        proxy = type(self)._CLOUD_PROVIDER_TO_PROXIES.get(cloud_provider, None)
        if proxy is None:
            raise _user_exception.FlyteAssertion(
                "Configured cloud provider is not supported for data I/O. Received: {}, expected one of: {}".format(
                    cloud_provider, list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys())
                )
            )
        super(RemoteDataContext, self).__init__(proxy)
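# --- Illustrative sketch (not part of the original module): how RemoteDataContext
# selects a proxy. Assumes the same module-level imports (_constants,
# _platform_config, _user_exception) used by the class above.
def _example_remote_data_context():
    # Explicit provider: the matching proxy is looked up in _CLOUD_PROVIDER_TO_PROXIES.
    gcp_ctx = RemoteDataContext(cloud_provider=_constants.CloudProvider.GCP)

    # No provider: falls back to the platform-configured _platform_config.CLOUD_PROVIDER.
    default_ctx = RemoteDataContext()

    # An unrecognized provider raises FlyteAssertion listing the supported providers.
    try:
        RemoteDataContext(cloud_provider="azure")
    except _user_exception.FlyteAssertion as err:
        print(err)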
class Data(object):
    # TODO: More proxies for more environments.
    _DATA_PROXIES = {
        "s3:/": _s3proxy.AwsS3Proxy(),
        "gs:/": _gcs_proxy.GCSProxy(),
        "http://": _http_data_proxy.HttpFileProxy(),
        "https://": _http_data_proxy.HttpFileProxy(),
    }

    @classmethod
    def _load_data_proxy_by_path(cls, path):
        """
        :param Text path:
        :rtype: flytekit.interfaces.data.common.DataProxy
        """
        for k, v in _six.iteritems(cls._DATA_PROXIES):
            if path.startswith(k):
                return v
        return _OutputDataContext.get_default_proxy()

    @classmethod
    def data_exists(cls, path):
        """
        :param Text path:
        :rtype: bool: Whether the object at the given path exists.
        """
        with _common_utils.PerformanceTimer("Check file exists {}".format(path)):
            proxy = cls._load_data_proxy_by_path(path)
            return proxy.exists(path)

    @classmethod
    def get_data(cls, remote_path, local_path, is_multipart=False):
        """
        :param Text remote_path:
        :param Text local_path:
        :param bool is_multipart:
        """
        try:
            with _common_utils.PerformanceTimer("Copying ({} -> {})".format(remote_path, local_path)):
                proxy = cls._load_data_proxy_by_path(remote_path)
                if is_multipart:
                    proxy.download_directory(remote_path, local_path)
                else:
                    proxy.download(remote_path, local_path)
        except Exception as ex:
            raise _user_exception.FlyteAssertion(
                "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
                "Original exception: {error_string}".format(
                    remote_path=remote_path,
                    local_path=local_path,
                    is_multipart=is_multipart,
                    error_string=_six.text_type(ex),
                )
            )

    @classmethod
    def put_data(cls, local_path, remote_path, is_multipart=False):
        """
        :param Text local_path:
        :param Text remote_path:
        :param bool is_multipart:
        """
        try:
            with _common_utils.PerformanceTimer("Writing ({} -> {})".format(local_path, remote_path)):
                proxy = cls._load_data_proxy_by_path(remote_path)
                if is_multipart:
                    proxy.upload_directory(local_path, remote_path)
                else:
                    proxy.upload(local_path, remote_path)
        except Exception as ex:
            raise _user_exception.FlyteAssertion(
                "Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
                "Original exception: {error_string}".format(
                    remote_path=remote_path,
                    local_path=local_path,
                    is_multipart=is_multipart,
                    error_string=_six.text_type(ex),
                )
            )

    @classmethod
    def get_remote_path(cls):
        """
        :rtype: Text
        """
        return _OutputDataContext.get_active_proxy().get_random_path()

    @classmethod
    def get_remote_directory(cls):
        """
        :rtype: Text
        """
        return _OutputDataContext.get_active_proxy().get_random_directory()
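# --- Illustrative sketch (not part of the original module): Data dispatches on
# the path prefix, so the same classmethods cover s3://, gs://, http(s)://, and,
# via the default proxy, local paths. Bucket names and paths are hypothetical.
def _example_data_usage():
    remote_file = "s3://my-bucket/inputs/data.csv"  # hypothetical object
    if Data.data_exists(remote_file):
        # Single-file download; a directory would pass is_multipart=True.
        Data.get_data(remote_file, "/tmp/data.csv")

    # Upload an entire directory tree with the multipart flag.
    Data.put_data("/tmp/outputs", "gs://my-bucket/outputs", is_multipart=True)

    # Random scratch locations come from the active output proxy.
    scratch_file = Data.get_remote_path()
    scratch_dir = Data.get_remote_directory()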
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
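# --- Illustrative sketch (not part of the original module): roughly how a
# container entrypoint could invoke _handle_annotated_task. The URIs are
# hypothetical placeholders for the values passed on the command line.
def _example_invoke(task_def: PythonTask):
    _handle_annotated_task(
        task_def=task_def,
        inputs="s3://my-bucket/metadata/inputs.pb",  # serialized literal map of task inputs
        output_prefix="s3://my-bucket/metadata",  # prefix under which outputs are written
        raw_output_data_prefix="s3://my-bucket/raw",  # prefix for offloaded raw output data
    )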
def __init__(self):
    # Hard-codes the AWS S3 proxy rather than selecting one from the provider map.
    super(RemoteDataContext, self).__init__(_s3proxy.AwsS3Proxy())