Example #1
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        discoverable,
        timeout,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param task_function: Function containing the user code. This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated: [optional] string that marks this task as deprecated
        :param bool discoverable: whether the task's outputs can be cached and re-used via discovery
        :param datetime.timedelta timeout: maximum duration the task is allowed to run
        :param dict[Text,Text] spark_conf: Spark configuration key/value pairs for the job
        :param dict[Text,Text] hadoop_conf: Hadoop configuration key/value pairs for the job
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_exec_path = _os.path.abspath(_entrypoint.__file__)
        if spark_exec_path.endswith('.pyc'):
            spark_exec_path = spark_exec_path[:-1]

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
        ).to_flyte_idl()
        super(SdkSparkTask, self).__init__(
            task_function,
            task_type,
            discovery_version,
            retries,
            interruptible,
            deprecated,
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            discoverable,
            timeout,
            environment,
            _MessageToDict(spark_job),
        )
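
For reference, a minimal sketch of how this constructor might be called directly; the task_type string, Spark settings, and timeout are illustrative assumptions, and in practice the SDK's Spark task decorator usually builds this object for you rather than user code calling it by hand.

import datetime

def my_spark_fn(*args, **kwargs):
    # Placeholder user code; the exact signature the SDK engine invokes it
    # with is not shown in this listing.
    pass

task = SdkSparkTask(
    task_function=my_spark_fn,
    task_type="spark",
    discovery_version="1",
    retries=2,
    interruptible=False,
    deprecated="",
    discoverable=False,
    timeout=datetime.timedelta(minutes=30),
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    environment={"MY_ENV_VAR": "value"},
)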
Example #2
    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Build the SparkJob model from the task's config and serialize it
        # into the dict stored in the task template's custom field.
        job = _task_model.SparkJob(
            spark_conf=self.task_config.spark_conf,
            hadoop_conf=self.task_config.hadoop_conf,
            application_file="local://" + settings.entrypoint_settings.path,
            executor_path=settings.python_interpreter,
            main_class="",
            spark_type=SparkType.PYTHON,
        )
        return MessageToDict(job.to_flyte_idl())
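
For context, roughly the shape of the dict this method returns, assuming the field names of the SparkJob message in flyteidl; MessageToDict renders keys in camelCase and omits empty fields (such as the empty main_class here), and the paths and values below are purely illustrative.

custom = {
    "mainApplicationFile": "local:///opt/venv/bin/entrypoint.py",
    "executorPath": "/usr/bin/python3",
    "sparkConf": {"spark.executor.instances": "2"},
    "hadoopConf": {"fs.s3a.endpoint": "http://minio:9000"},
}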
Example #3
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated: [optional] string that marks this task as deprecated
        :param bool discoverable: whether the task's outputs can be cached and re-used via discovery
        :param datetime.timedelta timeout: maximum duration the task is allowed to run
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf: Spark configuration key/value pairs for the job
        :param dict[Text,Text] hadoop_conf: Hadoop configuration key/value pairs for the job
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__, 'spark'), timeout,
                _literal_models.RetryStrategy(retries), interruptible,
                discovery_version, deprecated),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)
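
A minimal sketch of constructing a Scala job with this constructor; the SparkType enum member, class name, jar path, and configuration values are illustrative assumptions.

import datetime

scala_task = SdkGenericSparkTask(
    task_type="spark",
    discovery_version="1",
    retries=2,
    interruptible=False,
    task_inputs=None,               # no inputs decorator applied here
    deprecated="",
    discoverable=False,
    timeout=datetime.timedelta(hours=1),
    spark_type=SparkType.SCALA,     # assumed enum member; see Example #2
    main_class="org.example.MyJob",
    main_application_file="local:///opt/jars/my-job.jar",
    spark_conf={"spark.executor.instances": "4"},
    hadoop_conf={},
    environment={},
)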
Example #4
    def __init__(
        self,
        notebook_path,
        inputs,
        outputs,
        spark_conf,
        discovery_version,
        retries,
        deprecated,
        discoverable,
        timeout,
        environment=None,
    ):

        spark_exec_path = _os.path.abspath(_entrypoint.__file__)
        if spark_exec_path.endswith(".pyc"):
            spark_exec_path = spark_exec_path[:-1]

        if spark_conf is None:
            # Parse spark_conf from notebook if not set at task_level.
            with open(notebook_path) as json_file:
                data = _json.load(json_file)
                for p in data["cells"]:
                    meta = p["metadata"]
                    if "tags" in meta:
                        if "conf" in meta["tags"]:
                            sc_str = " ".join(p["source"])
                            ldict = {}
                            exec(sc_str, globals(), ldict)
                            spark_conf = ldict["spark_conf"]

        # Build the Spark job regardless of whether spark_conf was passed in
        # or parsed from the notebook above; keeping this inside the `if`
        # block would leave spark_job undefined when spark_conf is provided.
        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            main_class="",
            spark_type=_spark_type.PYTHON,
            hadoop_conf={},
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkNotebookSparkTask, self).__init__(
            notebook_path,
            inputs,
            outputs,
            _constants.SdkTaskType.SPARK_TASK,
            discovery_version,
            retries,
            deprecated,
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            discoverable,
            timeout,
            environment,
            _json_format.MessageToDict(spark_job),
        )