예제 #1
0
    def __init__(
        self,
        name: str,
        notebook_path: str,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        **kwargs,
    ):
        """
        Set up a notebook-backed task whose task type is derived from the plugin matching ``task_config``.

        :param name: Name forwarded to the parent constructor.
        :param notebook_path: Path to the notebook. It is converted to an absolute path and must exist.
        :param task_config: Plugin configuration; its type is used to look up the python-task plugin.
        :param inputs: Optional mapping of input names to types, passed to the task interface.
        :param outputs: Optional mapping of output names to types. The two implicit notebook outputs are
            always added on top of it. The mapping passed in by the caller is left untouched.
        :param kwargs: Additional keyword arguments forwarded to the parent constructor.
        :raises ValueError: If the resolved notebook path does not exist.
        """
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._plugin = plugin_class(task_config=task_config, task_function=_dummy_task_func)
        task_type = f"nb-{self._plugin.task_type}"
        self._notebook_path = os.path.abspath(notebook_path)

        if not os.path.exists(self._notebook_path):
            raise ValueError(
                f"Illegal notebook path passed in {self._notebook_path}")

        # ``outputs`` defaults to None, so it cannot be updated in place. Building a fresh dict also
        # keeps the implicit notebook outputs from leaking into a dictionary owned by the caller.
        all_outputs = dict(outputs) if outputs is not None else {}
        all_outputs.update({
            self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
            self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
        })
        super().__init__(
            name,
            task_config,
            task_type=task_type,
            interface=Interface(inputs=inputs, outputs=all_outputs),
            **kwargs,
        )
예제 #2
0
                    This mode indicates we are actually in a remote execute environment (within sagemaker in this case)
                """
                dist_ctx = DistributedTrainingContext.from_env()
            else:
                dist_ctx = DistributedTrainingContext.local_execute()
            return user_params.builder().add_attr("DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()

        return user_params

    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        Decide whether the outputs of this execution should be kept.

        Outside of a distributed run the return value is passed through untouched. In a distributed run the
        ``should_persist_output`` predicate of the task config is evaluated against the current distributed
        training context, so that the user can choose which execution produces the outputs.

        :param user_params: Execution parameters of the task (not used by this method).
        :param rval: The value returned by the user's task function.
        :return: ``rval`` unchanged, when outputs should be persisted.
        :raises IgnoreOutputs: In a distributed run when the persistence predicate is not satisfied.
        """
        if not self._is_distributed():
            return rval

        logging.info("Distributed context detected!")
        dist_ctx = flytekit.current_context().distributed_training_context
        if self.task_config.should_persist_output(dist_ctx):
            return rval

        logging.info("output persistence predicate not met, Flytekit will ignore outputs")
        raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. Ignoring outputs - {dist_ctx}")


# Register the SageMaker custom training plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(SagemakerTrainingJobConfig, SagemakerCustomTrainingTask)
예제 #3
0
                for resource in sdk_default_container.resources.limits:
                    resource_requirements.limits[
                        _core_task.Resources.ResourceName.Name(resource.name).lower()
                    ].CopyFrom(Quantity(string=resource.value))
                for resource in sdk_default_container.resources.requests:
                    resource_requirements.requests[
                        _core_task.Resources.ResourceName.Name(resource.name).lower()
                    ].CopyFrom(Quantity(string=resource.value))
                if resource_requirements.ByteSize():
                    # Important! Only copy over resource requirements if they are non-empty.
                    container.resources.CopyFrom(resource_requirements)

                del container.env[:]
                container.env.extend([EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()])

            final_containers.append(container)

        del self.task_config._pod_spec.containers[:]
        self.task_config._pod_spec.containers.extend(final_containers)

        pod_job_plugin = _task_models.SidecarJob(
            pod_spec=self.task_config.pod_spec, primary_container_name=self.task_config.primary_container_name,
        ).to_flyte_idl()
        return MessageToDict(pod_job_plugin)

    def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, None]:
        """
        Refuse local execution: this method raises unconditionally and never returns.

        :param ctx: The current Flyte context (not used).
        :param kwargs: Task inputs (not used).
        :raises FlyteUserException: Always.
        """
        message = "Local execute is not currently supported for pod tasks"
        raise _user_exceptions.FlyteUserException(message)


# Register PodFunctionTask in flytekit's python-task plugin registry under the Pod type
TaskPlugins.register_pythontask_plugin(Pod, PodFunctionTask)
예제 #4
0
    num_workers: int
    per_replica_requests: Optional[Resources] = None
    per_replica_limits: Optional[Resources] = None


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
    """
    Function task configured by a ``PyTorch`` config object.

    Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator)
        defined by the code within the _task_function to k8s cluster.
    """

    _PYTORCH_TASK_TYPE = "pytorch"

    def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
        """
        :param task_config: PyTorch configuration; its per-replica requests and limits are forwarded
            to the parent constructor as ``requests`` and ``limits``.
        :param task_function: The user function wrapped by this task.
        :param kwargs: Additional keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            task_config=task_config,
            task_function=task_function,
            task_type=self._PYTORCH_TASK_TYPE,
            requests=task_config.per_replica_requests,
            limits=task_config.per_replica_limits,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Build the plugin-specific payload for this task.

        :param settings: Serialization settings (not used by this method).
        :return: Dictionary form of a PyTorchJob whose worker count is ``task_config.num_workers``.
        """
        pytorch_job = _task_model.PyTorchJob(workers_count=self.task_config.num_workers)
        job_idl = pytorch_job.to_flyte_idl()
        return MessageToDict(job_idl)


# Register the PyTorch plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)
예제 #5
0
            application_file="local://" + settings.entrypoint_settings.path,
            executor_path=os.path.join(settings.flytekit_virtualenv_root,
                                       "bin/python"),
            main_class="",
            spark_type=SparkType.PYTHON,
        )
        return MessageToDict(job.to_flyte_idl())

    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) a Spark session and expose it to the user code through the execution parameters.

        When the task is not running in task-execution mode (i.e. local / notebook execution), the
        ``spark_conf`` of the task config is applied to the session builder and ``spark.master`` is forced to
        ``local``. In task-execution mode the session is created without any extra configuration from here.

        :param user_params: Execution parameters of the current run; ``execution_id`` is used in the app name.
        :return: A copy of ``user_params`` with the session attached under ``SPARK_SESSION``.
        """
        import pyspark as _pyspark

        ctx = FlyteContext.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(
            f"FlyteSpark: {user_params.execution_id}")

        # Compare the *instance's* mode. ``execution_state.Mode`` is the enum class itself, which never
        # equals an enum member, so the previous check treated every run as a local one.
        # NOTE(review): relies on ExecutionState exposing its current mode as ``.mode`` - confirm.
        in_task_execution = (ctx.execution_state
                             and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION)
        if not in_task_execution:
            # Local execution of this task: add the system spark-conf for local/notebook based execution.
            # The conf must be handed to the builder; a SparkConf that is only constructed has no effect.
            spark_conf = _pyspark.SparkConf()
            for key, value in (self.task_config.spark_conf or {}).items():
                spark_conf.set(key, value)
            # Set last so that local runs always use a local master.
            spark_conf.set("spark.master", "local")
            sess_builder = sess_builder.config(conf=spark_conf)

        sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", sess).build()


# Inject the Spark plugin into flytekit's dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)

class _Dynamic:
    """Empty marker type; it carries no state and defines no behavior of its own."""


class DynamicWorkflowTask(PythonFunctionTask[_Dynamic]):
    """
    Function task configured by ``_Dynamic`` and created with the ``dynamic-task`` task type.
    """

    def __init__(
        self,
        task_config: _Dynamic,
        dynamic_workflow_function: Callable,
        metadata: Optional[TaskMetadata] = None,
        **kwargs,
    ):
        """
        :param task_config: The ``_Dynamic`` config instance for this task.
        :param dynamic_workflow_function: User function, handed to the parent as ``task_function``.
        :param metadata: Optional task metadata forwarded to the parent constructor.
        :param kwargs: Additional keyword arguments forwarded to the parent constructor.
        """
        # Initialised before the parent constructor runs, as in the original ordering.
        self._wf = None

        super().__init__(
            task_type="dynamic-task",
            task_config=task_config,
            task_function=dynamic_workflow_function,
            metadata=metadata,
            **kwargs,
        )


# Register DynamicWorkflowTask as the python-task plugin for the _Dynamic config type
TaskPlugins.register_pythontask_plugin(_Dynamic, DynamicWorkflowTask)
# ``dynamic`` is ``task.task`` with a ``_Dynamic`` task_config and the DYNAMIC execution mode pre-bound
dynamic = functools.partial(
    task.task, task_config=_Dynamic(), execution_mode=PythonFunctionTask.ExecutionBehavior.DYNAMIC
)
예제 #7
0
    per_replica_limits: Optional[Resources] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
    """
    Function task configured by a ``TfJob`` config object.

    Plugin that submits a TFJob (see https://github.com/kubeflow/tf-operator)
        defined by the code within the _task_function to k8s cluster.
    """

    _TF_JOB_TASK_TYPE = "tensorflow"

    def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
        """
        :param task_config: TfJob configuration; its per-replica requests and limits are forwarded
            to the parent constructor as ``requests`` and ``limits``.
        :param task_function: The user function wrapped by this task.
        :param kwargs: Additional keyword arguments forwarded to the parent constructor.
        """
        super().__init__(
            task_config=task_config,
            task_function=task_function,
            task_type=self._TF_JOB_TASK_TYPE,
            requests=task_config.per_replica_requests,
            limits=task_config.per_replica_limits,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Build the plugin-specific payload for this task.

        :param settings: Serialization settings (not used by this method).
        :return: Dictionary form of a TensorFlowJob populated with the worker, parameter-server and
            chief replica counts taken from the task config.
        """
        tf_job = _task_model.TensorFlowJob(
            workers_count=self.task_config.num_workers,
            ps_replicas_count=self.task_config.num_ps_replicas,
            chief_replicas_count=self.task_config.num_chief_replicas,
        )
        job_idl = tf_job.to_flyte_idl()
        return MessageToDict(job_idl)


# Register the TensorFlow plugin (TfJob -> TensorflowFunctionTask) into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(TfJob, TensorflowFunctionTask)