Example #1
0
class RemoteDataContext(_OutputDataContext):
    """
    Output data context whose proxy is selected by cloud provider.

    The provider-to-proxy mapping is built once, when the class body is executed, so every
    instance created for a given provider shares the same proxy object.
    """

    _CLOUD_PROVIDER_TO_PROXIES = {
        _constants.CloudProvider.AWS: _s3proxy.AwsS3Proxy(),
        _constants.CloudProvider.GCP: _gcs_proxy.GCSProxy(),
    }

    def __init__(self, cloud_provider=None):
        """
        :param Optional[Text] cloud_provider: From flytekit.common.constants.CloudProvider enum.  When falsy,
            the value is read from the platform configuration instead.
        :raises: FlyteAssertion when no proxy is registered for the resolved cloud provider.
        """
        # Any falsy value (None, empty string) falls back to the configured provider.
        if not cloud_provider:
            cloud_provider = _platform_config.CLOUD_PROVIDER.get()

        # Looked up through type(self) so that a subclass overriding the mapping is honoured.
        supported = type(self)._CLOUD_PROVIDER_TO_PROXIES
        selected_proxy = supported.get(cloud_provider)
        if selected_proxy is None:
            message = "Configured cloud provider is not supported for data I/O.  Received: {}, expected one of: {}".format(
                cloud_provider,
                list(supported.keys())
            )
            raise _user_exception.FlyteAssertion(message)

        super(RemoteDataContext, self).__init__(selected_proxy)
Example #2
0
class Data(object):
    """
    Class-level helpers for checking, downloading and uploading data.

    The proxy that handles a path is chosen by matching the start of the path against the
    keys of ``_DATA_PROXIES``; paths that match no key use the default proxy of the output
    data context.
    """

    # TODO: More proxies for more environments.
    _DATA_PROXIES = {
        "s3:/": _s3proxy.AwsS3Proxy(),
        "gs:/": _gcs_proxy.GCSProxy(),
        "http://": _http_data_proxy.HttpFileProxy(),
        "https://": _http_data_proxy.HttpFileProxy(),
    }

    @classmethod
    def _load_data_proxy_by_path(cls, path):
        """
        Pick the proxy registered for the prefix that ``path`` starts with.

        :param Text path:
        :rtype: flytekit.interfaces.data.common.DataProxy
        """
        for prefix, candidate in _six.iteritems(cls._DATA_PROXIES):
            if not path.startswith(prefix):
                continue
            return candidate
        # No registered prefix matched, so defer to the output data context.
        return _OutputDataContext.get_default_proxy()

    @classmethod
    def data_exists(cls, path):
        """
        Ask the proxy responsible for ``path`` whether it exists; the check is timed.

        :param Text path:
        :rtype: bool: whether the file exists or not
        """
        timer_label = "Check file exists {}".format(path)
        with _common_utils.PerformanceTimer(timer_label):
            return cls._load_data_proxy_by_path(path).exists(path)

    @classmethod
    def get_data(cls, remote_path, local_path, is_multipart=False):
        """
        Download ``remote_path`` to ``local_path``.  Any exception raised while doing so is
        re-raised as a FlyteAssertion carrying the original error text.

        :param Text remote_path:
        :param Text local_path:
        :param bool is_multipart: when true, a directory download is performed instead of a single download
        """
        try:
            timer_label = "Copying ({} -> {})".format(remote_path, local_path)
            with _common_utils.PerformanceTimer(timer_label):
                proxy = cls._load_data_proxy_by_path(remote_path)
                fetch = proxy.download_directory if is_multipart else proxy.download
                fetch(remote_path, local_path)
        except Exception as error:
            message = (
                "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n"
                "Original exception: {error_string}"
            ).format(
                remote_path=remote_path,
                local_path=local_path,
                is_multipart=is_multipart,
                error_string=_six.text_type(error)
            )
            raise _user_exception.FlyteAssertion(message)

    @classmethod
    def put_data(cls, local_path, remote_path, is_multipart=False):
        """
        Upload ``local_path`` to ``remote_path``.  Any exception raised while doing so is
        re-raised as a FlyteAssertion carrying the original error text.

        :param Text local_path:
        :param Text remote_path:
        :param bool is_multipart: when true, a directory upload is performed instead of a single upload
        """
        try:
            timer_label = "Writing ({} -> {})".format(local_path, remote_path)
            with _common_utils.PerformanceTimer(timer_label):
                # The proxy is selected from the destination, not the local source.
                proxy = cls._load_data_proxy_by_path(remote_path)
                send = proxy.upload_directory if is_multipart else proxy.upload
                send(local_path, remote_path)
        except Exception as error:
            message = (
                "Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
                "Original exception: {error_string}"
            ).format(
                remote_path=remote_path,
                local_path=local_path,
                is_multipart=is_multipart,
                error_string=_six.text_type(error)
            )
            raise _user_exception.FlyteAssertion(message)

    @classmethod
    def get_remote_path(cls):
        """
        Request a random path from the currently active proxy.

        :rtype: Text
        """
        active_proxy = _OutputDataContext.get_active_proxy()
        return active_proxy.get_random_path()

    @classmethod
    def get_remote_directory(cls):
        """
        Request a random directory from the currently active proxy.

        :rtype: Text
        """
        active_proxy = _OutputDataContext.get_active_proxy()
        return active_proxy.get_random_directory()
Example #3
0
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions

    Builds the execution parameters, a file access provider for the configured cloud provider and the
    serialization settings, then calls ``_dispatch_execute`` inside the resulting nested contexts.

    :param PythonTask task_def: passed through to ``_dispatch_execute``
    :param str inputs: passed through to ``_dispatch_execute``
    :param str output_prefix: passed through to ``_dispatch_execute``
    :param str raw_output_data_prefix: handed to the S3 or GCS proxy constructor when the provider is AWS or GCP
    :raises Exception: when the configured cloud provider is not AWS, GCP or LOCAL
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    # The internal logging level wins; the SDK level is only a fallback.
    configured_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    root_logger = _logging.getLogger()
    root_logger.setLevel(configured_level)

    ctx = FlyteContext.current_context()

    # Create directories
    workspace = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {workspace}")
    pathlib.Path(workspace).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    execution_id = _identifier.WorkflowExecutionIdentifier(
        project=_internal_config.EXECUTION_PROJECT.get(),
        domain=_internal_config.EXECUTION_DOMAIN.get(),
        name=_internal_config.EXECUTION_NAME.get(),
    )
    started_at = _datetime.datetime.utcnow()
    # Stats metric path will be:
    # registration_project.registration_domain.app.module.task_name.user_stats
    # and it will be tagged with execution-level values for project/domain/wf/lp
    stats_scope = "{}.{}.{}.user_stats".format(
        _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
        _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
        _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
    )
    stats_tags = {
        "exec_project": _internal_config.EXECUTION_PROJECT.get(),
        "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
        "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
        "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
        "api_version": _api_version,
    }
    execution_parameters = ExecutionParameters(
        execution_id=execution_id,
        execution_date=started_at,
        stats=_get_stats(stats_scope, tags=stats_tags),
        logging=_logging,
        tmp_dir=workspace,
    )

    # Decide which remote proxy (if any) backs file access for this provider.
    if cloud_provider == _constants.CloudProvider.AWS:
        remote_proxy_factory = _s3proxy.AwsS3Proxy
    elif cloud_provider == _constants.CloudProvider.GCP:
        remote_proxy_factory = _gcs_proxy.GCSProxy
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        remote_proxy_factory = None
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    access_kwargs = {"local_sandbox_dir": _sdk_config.LOCAL_SANDBOX.get()}
    if remote_proxy_factory is not None:
        access_kwargs["remote_proxy"] = remote_proxy_factory(raw_output_data_prefix)
    file_access = _data_proxy.FileAccessProvider(**access_kwargs)

    with ctx.new_file_access_context(file_access_provider=file_access) as file_ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with file_ctx.new_serialization_settings(serialization_settings=serialization_settings) as settings_ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with settings_ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as execution_ctx:
                _dispatch_execute(execution_ctx, task_def, inputs, output_prefix)
Example #4
0
def gcs_proxy():
    """
    Build a GCSProxy using the constructor's default arguments.

    :rtype: GCSProxy
    """
    proxy = _gcs_proxy.GCSProxy()
    return proxy
Example #5
0
def test_random_path(mock_update_cmd_config_and_execute, gsutil_parallelism,
                     gcs_proxy):
    """A proxy constructed with a raw output prefix returns random paths that start with that prefix."""
    raw_prefix = "gcs://stuff"
    proxy = _gcs_proxy.GCSProxy(raw_prefix)
    random_path = proxy.get_random_path()
    assert random_path.startswith(raw_prefix)
Example #6
0
def test_raw_prefix_property(mock_update_cmd_config_and_execute,
                             gsutil_parallelism, gcs_proxy):
    """The prefix given to the constructor is exposed as ``raw_output_data_prefix_override``."""
    raw_prefix = "gcs://stuff"
    proxy = _gcs_proxy.GCSProxy(raw_prefix)
    assert proxy.raw_output_data_prefix_override == raw_prefix