Example #1
0
    def __init__(
        self,
        name: str,
        notebook_path: str,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        **kwargs,
    ):
        """
        Create a task that executes a notebook.

        :param name: Name of the task.
        :param notebook_path: Path to the notebook. It is resolved to an absolute path and must exist.
        :param task_config: Plugin configuration; its type selects the plugin class via
            ``TaskPlugins.find_pythontask_plugin``.
        :param inputs: Mapping of input names to Python types for the task interface.
        :param outputs: Mapping of output names to Python types. When given, the implicit notebook outputs
            are added to a *copy* of this mapping; the caller's dict is never modified. When None, no
            outputs (implicit or otherwise) are declared.
        :raises ValueError: If the resolved notebook path does not exist.
        """
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        # The plugin is instantiated around a placeholder function; its task type is used to derive
        # this task's type below.
        self._plugin = plugin_class(task_config=task_config, task_function=_dummy_task_func)
        task_type = f"nb-{self._plugin.task_type}"
        self._notebook_path = os.path.abspath(notebook_path)

        if not os.path.exists(self._notebook_path):
            raise ValueError(
                f"Illegal notebook path passed in {self._notebook_path}")

        # `outputs` defaults to None (so it cannot be updated in place) and may be a dict the caller
        # still owns, so build a fresh mapping instead of calling `outputs.update(...)`.
        if outputs is not None:
            outputs = {
                **outputs,
                self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
                self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
            }
        super().__init__(name,
                         task_config,
                         task_type=task_type,
                         interface=Interface(inputs=inputs, outputs=outputs),
                         **kwargs)
Example #2
0
    def __init__(
        self,
        name: str,
        notebook_path: str,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        **kwargs,
    ):
        """
        Create a task that executes a notebook.

        :param name: Name of the task; also used to derive the name of the internal plugin task.
        :param notebook_path: Path to the notebook. It is resolved to an absolute path and must exist.
        :param task_config: Plugin configuration; its type selects the plugin class via
            ``TaskPlugins.find_pythontask_plugin``.
        :param inputs: Mapping of input names to Python types for the task interface.
        :param outputs: Mapping of output names to Python types. When non-empty, the implicit notebook
            outputs are added to a *copy* of it; the caller's dict is never modified. When None or empty,
            it is passed through unchanged.
        :raises ValueError: If the resolved notebook path does not exist.
        """
        # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
        # to run pre- and post- execute functions using the corresponding task plugin.
        # The internal task is renamed below so that each generated task has a unique name, avoiding
        # duplicate task name errors.
        # This seems like a hack. We should use a plugin_class that doesn't require a fake function to work.
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._config_task_instance = plugin_class(
            task_config=task_config, task_function=_dummy_task_func)
        # Rename the internal task so that there are no conflicts at serialization time. Technically these internal
        # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
        # at serialization time.
        self._config_task_instance._name = f"{PAPERMILL_TASK_PREFIX}.{name}"
        task_type = f"nb-{self._config_task_instance.task_type}"
        self._notebook_path = os.path.abspath(notebook_path)

        if not os.path.exists(self._notebook_path):
            raise ValueError(
                f"Illegal notebook path passed in {self._notebook_path}")

        if outputs:
            # Build a new mapping rather than calling `outputs.update(...)`: the dict belongs to the caller
            # and may be reused by them (e.g. to declare several tasks).
            outputs = {
                **outputs,
                self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
                self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
            }
        super().__init__(name,
                         task_config,
                         task_type=task_type,
                         interface=Interface(inputs=inputs, outputs=outputs),
                         **kwargs)
Example #3
0
            application_file="local://" + settings.entrypoint_settings.path,
            executor_path=os.path.join(settings.flytekit_virtualenv_root,
                                       "bin/python"),
            main_class="",
            spark_type=SparkType.PYTHON,
        )
        return MessageToDict(job.to_flyte_idl())

    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) a Spark session and attach it to the execution parameters as ``SPARK_SESSION``.

        When not running as a remote task execution (e.g. local or notebook runs), the task's
        ``spark_conf`` is applied to the session together with ``spark.master=local``.

        :param user_params: Execution parameters of the current run; ``execution_id`` names the Spark app.
        :returns: A copy of ``user_params`` with the ``SPARK_SESSION`` attribute added.
        """
        import pyspark as _pyspark

        ctx = FlyteContext.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(
            f"FlyteSpark: {user_params.execution_id}")

        # Compare the *instance's* current mode. `execution_state.Mode` is the enum class itself and can
        # never equal one of its members, which made every run look local.
        if not (ctx.execution_state and ctx.execution_state.mode
                == ExecutionState.Mode.TASK_EXECUTION):
            # Local/notebook based execution: add system spark-conf. The conf must be handed to the
            # session builder, otherwise it never reaches the session.
            spark_conf = _pyspark.SparkConf()
            for k, v in (self.task_config.spark_conf or {}).items():
                spark_conf.set(k, v)
            # Set last so the local master deterministically overrides any user-provided value.
            spark_conf.set("spark.master", "local")
            sess_builder = sess_builder.config(conf=spark_conf)

        sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", sess).build()


# Inject the Spark plugin into flytekit's dynamic plugin loading system: a `Spark` task_config type is
# registered against PysparkFunctionTask.
TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
Example #4
0
                "DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()

        return user_params

    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        Decide whether the outputs of this execution should be persisted.

        Non-distributed runs return ``rval`` unchanged. In a distributed run several nodes may each produce
        partial outputs and only the user's code knows which one is authoritative, so the task config's
        ``should_persist_output`` predicate is evaluated against the current distributed training context.
        If the predicate is not met, ``IgnoreOutputs`` is raised instead of returning a value.

        :param user_params: Execution parameters (unused here).
        :param rval: The value produced by the task function.
        :returns: ``rval`` when its outputs should be persisted.
        :raises IgnoreOutputs: When running distributed and the persistence predicate is not met.
        """
        if not self._is_distributed():
            return rval

        logger.info("Distributed context detected!")
        training_ctx = flytekit.current_context().distributed_training_context
        if self.task_config.should_persist_output(training_ctx):
            return rval

        logger.info(
            "output persistence predicate not met, Flytekit will ignore outputs"
        )
        raise IgnoreOutputs(
            f"Distributed context - Persistence predicate not met. Ignoring outputs - {training_ctx}"
        )


# Register the SageMaker custom training plugin into the flytekit core plugin system: the
# SagemakerTrainingJobConfig task_config type is registered against SagemakerCustomTrainingTask.
TaskPlugins.register_pythontask_plugin(SagemakerTrainingJobConfig,
                                       SagemakerCustomTrainingTask)
Example #5
0
    """

    num_workers: int


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
    """
    Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator) defined by the code
    within the _task_function to a k8s cluster.

    Tasks are created with task type ``"pytorch"``; the worker count from the ``PyTorch`` config is
    serialized into the task's custom section.
    """

    _PYTORCH_TASK_TYPE = "pytorch"

    def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
        super().__init__(
            task_config=task_config,
            task_function=task_function,
            task_type=self._PYTORCH_TASK_TYPE,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the PyTorchJob definition (worker count taken from the task config) as a plain dict."""
        pytorch_job = PyTorchJob(workers_count=self.task_config.num_workers)
        return MessageToDict(pytorch_job.to_flyte_idl())


# Register the PyTorch plugin into the flytekit core plugin system: the `PyTorch` task_config type is
# registered against PyTorchFunctionTask.
TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)
Example #6
0
    per_replica_limits: Optional[Resources] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
    """
    Plugin that submits a TFJob (see https://github.com/kubeflow/tf-operator) defined by the code within the
    _task_function to a k8s cluster.

    Resource requests/limits come from ``TfJob.per_replica_requests`` / ``TfJob.per_replica_limits`` when
    those are set; otherwise any ``requests`` / ``limits`` keyword arguments passed to the constructor are used.
    """

    _TF_JOB_TASK_TYPE = "tensorflow"

    def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
        # `requests`/`limits` are forwarded explicitly below, so they must be taken out of **kwargs first;
        # otherwise a caller supplying them triggers "got multiple values for keyword argument".
        requests = kwargs.pop("requests", None)
        limits = kwargs.pop("limits", None)
        # The per-replica settings in the TfJob config take precedence (matching the previous behavior).
        if task_config.per_replica_requests is not None:
            requests = task_config.per_replica_requests
        if task_config.per_replica_limits is not None:
            limits = task_config.per_replica_limits
        super().__init__(task_type=self._TF_JOB_TASK_TYPE,
                         task_config=task_config,
                         task_function=task_function,
                         requests=requests,
                         limits=limits,
                         **kwargs)

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the TensorFlowJob definition (worker, PS and chief replica counts) as a plain dict."""
        job = _task_model.TensorFlowJob(
            workers_count=self.task_config.num_workers,
            ps_replicas_count=self.task_config.num_ps_replicas,
            chief_replicas_count=self.task_config.num_chief_replicas,
        )
        return MessageToDict(job.to_flyte_idl())


# Register the Tensorflow plugin into the flytekit core plugin system: the `TfJob` task_config type is
# registered against TensorflowFunctionTask.
TaskPlugins.register_pythontask_plugin(TfJob, TensorflowFunctionTask)
Example #7
0
            final_containers.append(container)

        self.task_config._pod_spec.containers = final_containers

        return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)

    def get_k8s_pod(self, settings: SerializationSettings) -> _task_models.K8sPod:
        """Build the K8sPod model from the serialized pod spec plus the task config's labels and annotations."""
        # Serialize the spec first, as the original expression did, since `_serialize_pod_spec` writes
        # back into the task config's pod spec.
        serialized_spec = self._serialize_pod_spec(settings)
        pod_metadata = _task_models.K8sObjectMetadata(
            labels=self.task_config.labels,
            annotations=self.task_config.annotations,
        )
        return _task_models.K8sPod(pod_spec=serialized_spec, metadata=pod_metadata)

    def get_container(self, settings: SerializationSettings) -> Union[_task_models.Container, None]:
        """
        Return ``None``: this task emits no standalone container.

        The return annotation includes ``None`` because that is the only value this method returns;
        presumably the workload is described through ``get_k8s_pod`` instead.
        """
        return None

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """Return the task config map holding the task config's primary container name."""
        config = {}
        config[_PRIMARY_CONTAINER_NAME_FIELD] = self.task_config.primary_container_name
        return config

    def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, None]:
        """Warn that local runs may differ from the pod environment, then defer to the parent implementation."""
        mismatch_warning = (
            "Running pod task locally. Local environment may not match pod environment which may cause issues."
        )
        logger.warning(mismatch_warning)
        outcome = super().local_execute(ctx=ctx, **kwargs)
        return outcome


# Register the Pod plugin into the flytekit core plugin system: the `Pod` task_config type is registered
# against PodFunctionTask.
TaskPlugins.register_pythontask_plugin(Pod, PodFunctionTask)
Example #8
0
    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """
        Return the task template config from which the AWS Batch job definition is created.

        Job definition parameters are described at
        https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
        """
        capabilities = self._task_config.platformCapabilities
        return {"platformCapabilities": capabilities}

    def get_command(self, settings: SerializationSettings) -> List[str]:
        """
        Build the ``pyflyte-execute`` command line for this task.

        The output prefix always ends in ``/0`` because the task runs as a single AWS Batch job through
        the array plugin, which reads the result from the first array index.
        """
        command = ["pyflyte-execute", "--inputs", "{{.input}}"]
        # As of FlytePropeller v0.16.28, the aws array batch plugin supports running a single job, and this
        # task is executed on the AWS Batch service through that plugin. For a single job, FlytePropeller
        # always reads the output from outputPrefix/0. For more detail, see
        # https://github.com/flyteorg/flyteplugins/blob/0dd93c23ed2edeca65d58e89b0edb613f88120e0/go/tasks/plugins/array/catalog.go#L501.
        command += ["--output-prefix", "{{.outputPrefix}}/0"]
        command += ["--raw-output-data-prefix", "{{.rawOutputDataPrefix}}"]
        command += ["--resolver", self.task_resolver.location, "--"]
        command.extend(self.task_resolver.loader_args(settings, self))
        return command


# Inject the AWS batch plugin into flytekit's dynamic plugin loading system: the AWSBatchConfig
# task_config type is registered against AWSBatchFunctionTask.
TaskPlugins.register_pythontask_plugin(AWSBatchConfig, AWSBatchFunctionTask)