Ejemplo n.º 1
0
            application_file="local://" + settings.entrypoint_settings.path,
            executor_path=os.path.join(settings.flytekit_virtualenv_root,
                                       "bin/python"),
            main_class="",
            spark_type=SparkType.PYTHON,
        )
        return MessageToDict(job.to_flyte_idl())

    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) the Spark session for this task and attach it to the execution parameters.

        When the task is NOT running as a platform task execution (e.g. a local or notebook run),
        the user-supplied ``spark_conf`` is applied to the session builder together with
        ``spark.master=local``.

        :param user_params: execution parameters handed to the task by flytekit.
        :return: a copy of ``user_params`` with the session stored under the ``SPARK_SESSION`` attribute.
        """
        import pyspark as _pyspark

        builder = _pyspark.sql.SparkSession.builder.appName(
            f"FlyteSpark: {user_params.execution_id}")

        ctx = FlyteContext.current_context()
        # Compare the instance's ``mode`` against the enum member. Reading ``.Mode`` on the
        # instance yields the enum *class*, which never equals a member, so the previous check
        # treated every execution (including real task executions) as local.
        is_task_execution = bool(
            ctx.execution_state
            and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION)
        if not is_task_execution:
            # Local/notebook execution: apply the user's spark conf first, then the system
            # setting, so ``spark.master=local`` deterministically takes precedence.
            spark_conf = _pyspark.SparkConf()
            for k, v in self.task_config.spark_conf.items():
                spark_conf.set(k, v)
            spark_conf.set("spark.master", "local")
            # The conf has to be handed to the builder; a free-standing SparkConf has no effect.
            builder = builder.config(conf=spark_conf)

        sess = builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", sess).build()


# Inject the Spark plugin (PysparkFunctionTask for the Spark task config) into
# flytekit's dynamic plugin loading system.
TaskPlugins.register_pythontask_plugin(
    Spark,
    PysparkFunctionTask,
)
Ejemplo n.º 2
0
                "DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()

        return user_params

    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        Decide whether the value returned by the user function should be persisted.

        Non-distributed executions always return ``rval`` unchanged. In a distributed execution,
        several nodes may each produce partial outputs and only user code knows which of them
        should be kept, so the task config's ``should_persist_output`` predicate is consulted
        with the current distributed training context.

        :param user_params: execution parameters of the current run (not used here).
        :param rval: the value returned by the user function.
        :return: ``rval`` when it should be persisted.
        :raises IgnoreOutputs: in a distributed execution whose persistence predicate is not met,
            signalling that this node's outputs must be ignored.
        """
        if not self._is_distributed():
            return rval

        logger.info("Distributed context detected!")
        training_ctx = flytekit.current_context().distributed_training_context
        if self.task_config.should_persist_output(training_ctx):
            return rval

        logger.info("output persistence predicate not met, Flytekit will ignore outputs")
        raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. Ignoring outputs - {training_ctx}")


# Register the SageMaker custom training plugin (SagemakerCustomTrainingTask for the
# SagemakerTrainingJobConfig task config) into the flytekit core plugin system.
TaskPlugins.register_pythontask_plugin(SagemakerTrainingJobConfig,
                                       SagemakerCustomTrainingTask)
Ejemplo n.º 3
0
    """

    num_workers: int


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
    """
    Task plugin that submits the code of the wrapped task function to a Kubernetes cluster
    as a PyTorchJob (see https://github.com/kubeflow/pytorch-operator).
    """

    _PYTORCH_TASK_TYPE = "pytorch"

    def __init__(self, task_config: PyTorch, task_function: Callable,
                 **kwargs):
        # Delegate to PythonFunctionTask, tagging the task with the "pytorch" task type.
        super().__init__(
            task_config=task_config,
            task_function=task_function,
            task_type=self._PYTORCH_TASK_TYPE,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the PyTorchJob description (number of workers) converted to a plain dict."""
        worker_count = self.task_config.num_workers
        pytorch_job = PyTorchJob(workers_count=worker_count)
        return MessageToDict(pytorch_job.to_flyte_idl())


# Register the PyTorch plugin (PyTorchFunctionTask for the PyTorch task config)
# into the flytekit core plugin system.
TaskPlugins.register_pythontask_plugin(
    PyTorch,
    PyTorchFunctionTask,
)
Ejemplo n.º 4
0
    per_replica_limits: Optional[Resources] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
    """
    Plugin that submits a TFJob (see https://github.com/kubeflow/tf-operator)
    defined by the code within the _task_function to a k8s cluster.

    The ``per_replica_requests`` / ``per_replica_limits`` of the ``TfJob`` config are used as the
    task's ``requests`` / ``limits``.
    """

    _TF_JOB_TASK_TYPE = "tensorflow"

    def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
        # ``requests``/``limits`` may also be supplied by the caller through **kwargs. Passing
        # them explicitly AND via **kwargs raises "got multiple values for keyword argument",
        # so take them out of kwargs and merge: the per-replica values from the TfJob config
        # win whenever they are set, otherwise the caller's values (or None) are used.
        requests = kwargs.pop("requests", None)
        limits = kwargs.pop("limits", None)
        if task_config.per_replica_requests is not None:
            requests = task_config.per_replica_requests
        if task_config.per_replica_limits is not None:
            limits = task_config.per_replica_limits
        super().__init__(task_type=self._TF_JOB_TASK_TYPE,
                         task_config=task_config,
                         task_function=task_function,
                         requests=requests,
                         limits=limits,
                         **kwargs)

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the TensorFlowJob description (worker, PS and chief replica counts) as a plain dict."""
        job = _task_model.TensorFlowJob(
            workers_count=self.task_config.num_workers,
            ps_replicas_count=self.task_config.num_ps_replicas,
            chief_replicas_count=self.task_config.num_chief_replicas,
        )
        return MessageToDict(job.to_flyte_idl())


# Register the Tensorflow plugin (TensorflowFunctionTask for the TfJob task config)
# into the flytekit core plugin system.
TaskPlugins.register_pythontask_plugin(
    TfJob,
    TensorflowFunctionTask,
)
Ejemplo n.º 5
0
            final_containers.append(container)

        self.task_config._pod_spec.containers = final_containers

        return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)

    def get_k8s_pod(self, settings: SerializationSettings) -> _task_models.K8sPod:
        """
        Build the K8sPod model for this task: the serialized pod spec plus the labels and
        annotations configured on the task config.
        """
        # Serialize the spec first, matching the original evaluation order.
        serialized_spec = self._serialize_pod_spec(settings)
        pod_metadata = _task_models.K8sObjectMetadata(
            labels=self.task_config.labels,
            annotations=self.task_config.annotations,
        )
        return _task_models.K8sPod(pod_spec=serialized_spec, metadata=pod_metadata)

    def get_container(self, settings: SerializationSettings) -> Union[_task_models.Container, None]:
        """
        Return no standalone container model for this task.

        This task describes its containers through the pod spec built for ``get_k8s_pod``, and
        this method always returns ``None``. The return annotation is ``Union[..., None]`` so
        that it matches what the method actually returns.
        """
        return None

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """Return the task-template config, which records the name of the primary container."""
        primary_name = self.task_config.primary_container_name
        return {_PRIMARY_CONTAINER_NAME_FIELD: primary_name}

    def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, None]:
        """
        Run the pod task locally via the parent implementation.

        A warning is logged first because the local environment may differ from the pod
        environment the task is defined for.
        """
        warning_text = (
            "Running pod task locally. Local environment may not match pod environment which may cause issues."
        )
        logger.warning(warning_text)
        return super().local_execute(ctx=ctx, **kwargs)


# Register the Pod plugin (PodFunctionTask for the Pod task config) into the
# flytekit plugin system.
TaskPlugins.register_pythontask_plugin(
    Pod,
    PodFunctionTask,
)
Ejemplo n.º 6
0
    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """
        Return the task-template config whose parameters are used to create the AWS Batch job
        definition. Currently this only carries the configured ``platformCapabilities``.

        Job definition parameters are described at
        https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
        """
        capabilities = self._task_config.platformCapabilities
        return {"platformCapabilities": capabilities}

    def get_command(self, settings: SerializationSettings) -> List[str]:
        """
        Build the ``pyflyte-execute`` command line used to run this task on AWS Batch.

        The ``{{...}}`` tokens are template placeholders left for FlytePropeller to fill in.
        The command ends with ``--`` followed by the task resolver's loader arguments.
        """
        # As of FlytePropeller v0.16.28, the AWS array batch plugin supports running a single job,
        # and this task calls that plugin to execute on the AWS Batch service. For a single job,
        # FlytePropeller always reads outputs from ``outputPrefix/0``; see
        # https://github.com/flyteorg/flyteplugins/blob/0dd93c23ed2edeca65d58e89b0edb613f88120e0/go/tasks/plugins/array/catalog.go#L501.
        single_job_output_prefix = "{{.outputPrefix}}/0"

        command = ["pyflyte-execute"]
        command += ["--inputs", "{{.input}}"]
        command += ["--output-prefix", single_job_output_prefix]
        command += ["--raw-output-data-prefix", "{{.rawOutputDataPrefix}}"]
        # Read the resolver location before asking it for loader args, as the original did.
        command += ["--resolver", self.task_resolver.location]
        command.append("--")
        command.extend(self.task_resolver.loader_args(settings, self))
        return command


# Inject the AWS Batch plugin (AWSBatchFunctionTask for the AWSBatchConfig task config)
# into flytekit's dynamic plugin loading system.
TaskPlugins.register_pythontask_plugin(
    AWSBatchConfig,
    AWSBatchFunctionTask,
)