Example #1
    def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
        """
        :param Union[T, FlyteIdlEntity] pb_object:
        """
        struct = Struct()
        v = pb_object

        # This section converts an existing proto object (or a subclass of) to the right type expected by this instance
        # of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType). This makes it
        # a bit tricky to figure out the right version of the underlying raw proto class to use to populate the final
        # struct.
        # If the provided object has to_flyte_idl(), call it to produce a raw proto.
        if isinstance(pb_object, FlyteIdlEntity):
            v = pb_object.to_flyte_idl()

        # A check to ensure the raw proto (v) is of the correct expected type. This also performs one final attempt to
        # convert it to the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class
        # is initialized with one.
        expected_type = type(self).pb_type
        if expected_type != type(v) and expected_type != type(pb_object):
            if isinstance(type(self).pb_type, FlyteType):
                v = expected_type.from_flyte_idl(v).to_flyte_idl()
            else:
                raise _user_exceptions.FlyteTypeException(
                    received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
                )

        struct.update(_MessageToDict(v))
        super().__init__(scalar=_literals.Scalar(generic=struct))
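A minimal standalone sketch of the proto-to-Struct path used above (not from the source; it assumes only the stock google.protobuf runtime, and FileDescriptorProto is an arbitrary illustrative message): _MessageToDict yields a plain dict, which Struct.update accepts directly.

from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct

# Any message whose JSON mapping is an object works; this one is purely illustrative.
msg = FileDescriptorProto(name="example.proto", package="demo")
struct = Struct()
struct.update(MessageToDict(msg))  # Struct now holds {'name': 'example.proto', 'package': 'demo'}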
Example #2
    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        if not (lv and lv.scalar and lv.scalar.generic):
            raise AssertionError("Can only convert a generic literal to a Protobuf")

        pb_obj = expected_python_type()
        dictionary = _MessageToDict(lv.scalar.generic)
        pb_obj = _ParseDict(dictionary, pb_obj)
        return pb_obj
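The reverse direction can be exercised in isolation as well; a hedged sketch (plain protobuf only, no FlyteContext or Literal types) of the generic-Struct-to-typed-proto step:

from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct

generic = Struct()
generic.update({"name": "example.proto"})  # stands in for lv.scalar.generic

pb_obj = FileDescriptorProto()
ParseDict(MessageToDict(generic), pb_obj)  # Struct -> dict -> typed proto
assert pb_obj.name == "example.proto"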
Example #3
    def __init__(
        self,
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        discoverable,
        timeout,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param task_function: Function containing user code.  This will be executed via the SDK's engine.
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_exec_path = _os.path.abspath(_entrypoint.__file__)
        if spark_exec_path.endswith('.pyc'):
            spark_exec_path = spark_exec_path[:-1]

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            application_file="local://" + spark_exec_path,
            executor_path=_sys.executable,
        ).to_flyte_idl()
        super(SdkSparkTask, self).__init__(
            task_function,
            task_type,
            discovery_version,
            retries,
            interruptible,
            deprecated,
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            discoverable,
            timeout,
            environment,
            _MessageToDict(spark_job),
        )
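One detail worth keeping in mind when reading the resulting custom dict: MessageToDict camelCases field names by default, so a field declared as spark_conf would surface under a key like sparkConf. A small sketch (illustrative message only) of the two naming modes:

from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf.json_format import MessageToDict

field = FieldDescriptorProto(name="x", json_name="x")
print(MessageToDict(field))                                    # {'name': 'x', 'jsonName': 'x'}
print(MessageToDict(field, preserving_proto_field_name=True))  # {'name': 'x', 'json_name': 'x'}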
Example #4
 def __init__(
         self,
         task_function,
         task_type,
         discovery_version,
         retries,
         deprecated,
         discoverable,
         timeout,
         image,
         num_ps,
         replicas,
         command,
         args,
         volumeClaimName,
         environment
 ):
     """
     :param task_function: Function containing user code.  This will be executed via the SDK's engine.
     :param Text task_type: string describing the task type
     :param Text discovery_version: string describing the version for task discovery purposes
     :param int retries: Number of retries to attempt
     :param Text deprecated:
     :param bool discoverable:
     :param datetime.timedelta timeout:
     :param Text image: Docker image to use for the TFJob pods
     :param int num_ps: number of parameter-server replicas
     :param int replicas: number of worker replicas
     :param list[Text] command: command to run in each pod
     :param list[Text] args: arguments to pass to the command
     :param Text volumeClaimName: name of the persistent volume claim to mount
     :param dict[Text, Text] environment: [optional] environment variables to set when executing this task.
     """
     tfjob = _tfjob_model.TFJob(
         image=image,
         num_ps=num_ps,
         replicas=replicas,
         command=command,
         args=args,
         volumeClaimName=volumeClaimName,
     ).to_flyte_idl()
     super(SdkTFJobTask, self).__init__(
         task_function,
         task_type,
         discovery_version,
         retries,
         deprecated,
         "",
         "",
         "",
         "",
         "",
         "",
         "",
         "",
         discoverable,
         timeout,
         environment,
         _MessageToDict(tfjob),
     )
Example #5
 def __init__(
     self,
     task_function,
     task_type,
     cache_version,
     retries,
     interruptible,
     deprecated,
     cache,
     timeout,
     workers_count,
     ps_replicas_count,
     chief_replicas_count,
     per_replica_storage_request,
     per_replica_cpu_request,
     per_replica_gpu_request,
     per_replica_memory_request,
     per_replica_storage_limit,
     per_replica_cpu_limit,
     per_replica_gpu_limit,
     per_replica_memory_limit,
     environment,
 ):
     tensorflow_job = _task_models.TensorFlowJob(
         workers_count=workers_count,
         ps_replicas_count=ps_replicas_count,
         chief_replicas_count=chief_replicas_count).to_flyte_idl()
     super(SdkTensorFlowTask, self).__init__(
         task_function=task_function,
         task_type=task_type,
         discovery_version=cache_version,
         retries=retries,
         interruptible=interruptible,
         deprecated=deprecated,
         storage_request=per_replica_storage_request,
         cpu_request=per_replica_cpu_request,
         gpu_request=per_replica_gpu_request,
         memory_request=per_replica_memory_request,
         storage_limit=per_replica_storage_limit,
         cpu_limit=per_replica_cpu_limit,
         gpu_limit=per_replica_gpu_limit,
         memory_limit=per_replica_memory_limit,
         discoverable=cache,
         timeout=timeout,
         environment=environment,
         custom=_MessageToDict(tensorflow_job),
     )
Example #6
 def to_python_std(self):
     """
     :returns: The protobuf object as defined by the user.
     :rtype: T
     """
     pb_obj = type(self).pb_type()
     try:
         dictionary = _MessageToDict(self.scalar.generic)
         pb_obj = _ParseDict(dictionary, pb_obj)
     except Error as err:
         raise _user_exceptions.FlyteTypeException(
             received_type="generic",
             expected_type=type(self).pb_type,
             received_value=_base64.b64encode(self.scalar.generic),
             additional_msg=f"Can not deserialize. Error: {err.__str__()}",
         )
     return pb_obj
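For reference, the Error caught above is presumably google.protobuf.json_format's base exception; a minimal sketch (illustrative message and field name) of the failure mode that triggers it:

from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.json_format import ParseDict, ParseError

try:
    ParseDict({"notAField": 1}, FileDescriptorProto())  # unknown field -> ParseError
except ParseError as err:
    print(f"Cannot deserialize. Error: {err}")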
Example #7
 def __init__(
     self,
     hive_job,
     metadata,
 ):
     """
     :param _qubole.QuboleHiveJob hive_job: Hive job spec
     :param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as
         whether or not outputs are discoverable, timeouts, and retries.
     """
     super(SdkHiveJob, self).__init__(
         _constants.SdkTaskType.HIVE_JOB,
         metadata,
         # Individual hive tasks never take anything or return anything. They just run a query that's already
         # got the location set.
         _interface_model.TypedInterface({}, {}),
         _MessageToDict(hive_job),
     )
Example #8
 def __init__(
         self,
         task_function,
         task_type,
         discovery_version,
         retries,
         interruptible,
         deprecated,
         discoverable,
         timeout,
         workers_count,
         per_replica_storage_request,
         per_replica_cpu_request,
         per_replica_gpu_request,
         per_replica_memory_request,
         per_replica_storage_limit,
         per_replica_cpu_limit,
         per_replica_gpu_limit,
         per_replica_memory_limit,
         environment
 ):
     pytorch_job = _task_models.PyTorchJob(
         workers_count=workers_count
     ).to_flyte_idl()
     super(SdkPyTorchTask, self).__init__(
         task_function=task_function,
         task_type=task_type,
         discovery_version=discovery_version,
         retries=retries,
         interruptible=interruptible,
         deprecated=deprecated,
         storage_request=per_replica_storage_request,
         cpu_request=per_replica_cpu_request,
         gpu_request=per_replica_gpu_request,
         memory_request=per_replica_memory_request,
         storage_limit=per_replica_storage_limit,
         cpu_limit=per_replica_cpu_limit,
         gpu_limit=per_replica_gpu_limit,
         memory_limit=per_replica_memory_limit,
         discoverable=discoverable,
         timeout=timeout,
         environment=environment,
         custom=_MessageToDict(pytorch_job)
     )
Example #9
    def with_overrides(self,
                       new_spark_conf: typing.Dict[str, str] = None,
                       new_hadoop_conf: typing.Dict[str, str] = None):
        """
        Creates a copy of this task with the Spark and Hadoop configuration overridden.
        """
        tk = _copy.deepcopy(self)
        tk._spark_job = self._spark_job.with_overrides(new_spark_conf,
                                                       new_hadoop_conf)
        tk._custom = _MessageToDict(tk._spark_job.to_flyte_idl())

        salt = _hashlib.md5(
            _json.dumps(tk.custom,
                        sort_keys=True).encode("utf-8")).hexdigest()
        tk._id._name = "{}-{}".format(self._id.name, salt)
        # We are overriding the platform name creation to prevent problems in dynamic workflows.
        tk.assign_name(tk._id._name)

        return tk
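The salting idiom above reads well in isolation: hashing the sorted JSON dump of the custom dict yields a deterministic suffix, so identical overrides always produce the same task name. A standalone sketch with hypothetical values:

import hashlib
import json

custom = {"sparkConf": {"spark.executor.instances": "4"}}  # hypothetical override payload
salt = hashlib.md5(json.dumps(custom, sort_keys=True).encode("utf-8")).hexdigest()
print("{}-{}".format("my_spark_task", salt))  # same conf -> same salt -> same name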
Example #10
    @classmethod
    def promote_from_model(cls, literal_model):
        """
        Creates an object of this type from the model primitive defining it.
        :param flytekit.models.literals.Literal literal_model:
        :rtype: Protobuf
        """
        pb_obj = cls.pb_type()
        try:
            dictionary = _MessageToDict(literal_model.scalar.generic)
            pb_obj = _ParseDict(dictionary, pb_obj)
        except Error as err:
            raise _user_exceptions.FlyteTypeException(
                received_type="generic",
                expected_type=cls.pb_type,
                received_value=_base64.b64encode(literal_model.scalar.generic),
                additional_msg=f"Can not deserialize. Error: {err.__str__()}",
            )

        return cls(pb_obj)
Example #11
    def reconcile_partial_pod_spec_and_task(self, pod_spec,
                                            primary_container_name):
        """
        Assigns the custom field to the reconciled primary container and pod spec definition.
        :param generated_pb2.PodSpec pod_spec:
        :param Text primary_container_name:
        :rtype: SdkSidecarTask
        """

        # First, insert a placeholder primary container if it is not defined in the pod spec.
        containers = pod_spec.containers
        primary_exists = False
        for container in containers:
            if container.name == primary_container_name:
                primary_exists = True
                break
        if not primary_exists:
            containers.extend(
                [_k8s_pb2.Container(name=primary_container_name)])

        final_containers = []
        for container in containers:
            # In the case of the primary container, we overwrite specific container attributes with the default values
            # used in an SDK runnable task.
            if container.name == primary_container_name:
                container.image = self._container.image
                # clear existing commands
                del container.command[:]
                container.command.extend(self._container.command)
                # also clear existing args
                del container.args[:]
                container.args.extend(self._container.args)

                resource_requirements = _k8s_pb2.ResourceRequirements()
                for resource in self._container.resources.limits:
                    resource_requirements.limits[
                        _core_task.Resources.ResourceName.Name(
                            resource.name).lower()].CopyFrom(
                                _resource_pb2.Quantity(string=resource.value))
                for resource in self._container.resources.requests:
                    resource_requirements.requests[
                        _core_task.Resources.ResourceName.Name(
                            resource.name).lower()].CopyFrom(
                                _resource_pb2.Quantity(string=resource.value))
                if resource_requirements.ByteSize():
                    # Important! Only copy over resource requirements if they are non-empty.
                    container.resources.CopyFrom(resource_requirements)

                del container.env[:]
                container.env.extend([
                    _k8s_pb2.EnvVar(name=key, value=val)
                    for key, val in _six.iteritems(self._container.env)
                ])

            final_containers.append(container)

        del pod_spec.containers[:]
        pod_spec.containers.extend(final_containers)

        sidecar_job_plugin = _task_models.SidecarJob(
            pod_spec=pod_spec,
            primary_container_name=primary_container_name,
        ).to_flyte_idl()

        return self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin))
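The del field[:] followed by extend(...) pattern seen throughout this example is the standard way to replace a protobuf repeated field wholesale, since direct assignment raises. A tiny sketch (illustrative message):

from google.protobuf.descriptor_pb2 import FileDescriptorProto

fd = FileDescriptorProto(dependency=["a.proto", "b.proto"])
# fd.dependency = ["c.proto"] would raise AttributeError; clear then extend instead:
del fd.dependency[:]
fd.dependency.extend(["c.proto"])
assert list(fd.dependency) == ["c.proto"]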
Example #12
 def to_literal(self, ctx: FlyteContext, python_val: T,
                python_type: Type[T], expected: LiteralType) -> Literal:
     struct = Struct()
     struct.update(_MessageToDict(python_val))
     return Literal(scalar=Scalar(generic=struct))
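Taken together with Example #2, this gives a lossless round trip for any message whose JSON mapping is an object; a hedged end-to-end sketch (plain protobuf only):

from google.protobuf.descriptor_pb2 import FileDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct

original = FileDescriptorProto(name="round_trip.proto")
struct = Struct()
struct.update(MessageToDict(original))                              # the to_literal direction
restored = ParseDict(MessageToDict(struct), FileDescriptorProto())  # the to_python_value direction
assert restored == original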
Example #13
    def __init__(
        self,
        statement,
        output_schema,
        routing_group=None,
        catalog=None,
        schema=None,
        task_inputs=None,
        interruptible=False,
        discoverable=False,
        discovery_version=None,
        retries=1,
        timeout=None,
        deprecated=None,
    ):
        """
        :param Text statement: Presto query specification
        :param flytekit.common.types.schema.Schema output_schema: Schema that represents that data queried from Presto
        :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
        :param Text catalog: The catalog to set for the given Presto query
        :param Text schema: The schema to set for the given Presto query
        :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the Presto task
        :param bool interruptible: Whether or not task is interruptible
        :param bool discoverable:
        :param Text discovery_version: String describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param datetime.timedelta timeout:
        :param Text deprecated: This string can be used to mark the task as deprecated.  Consumers of the task will
            receive deprecation warnings.
        """

        # Set as class fields which are used down below to configure implicit
        # parameters
        self._routing_group = routing_group or ""
        self._catalog = catalog or ""
        self._schema = schema or ""

        metadata = _task_model.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        )

        presto_query = _presto_models.PrestoQuery(
            routing_group=routing_group or "",
            catalog=catalog or "",
            schema=schema or "",
            statement=statement,
        )

        # Here we set the routing_group, catalog, and schema as implicit
        # parameters for caching purposes
        i = _interface.TypedInterface(
            {
                "__implicit_routing_group": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The routing group set as an implicit input",
                ),
                "__implicit_catalog": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The catalog set as an implicit input",
                ),
                "__implicit_schema": _interface_model.Variable(
                    type=_types.LiteralType(simple=_types.SimpleType.STRING),
                    description="The schema set as an implicit input",
                ),
            },
            {
                # Set the schema for the Presto query as an output
                "results": _interface_model.Variable(
                    type=_types.LiteralType(schema=output_schema.schema_type),
                    description="The schema for the Presto query",
                )
            },
        )

        super(SdkPrestoTask, self).__init__(
            _constants.SdkTaskType.PRESTO_TASK,
            metadata,
            i,
            _MessageToDict(presto_query.to_flyte_idl()),
        )

        # Set user provided inputs
        if task_inputs is not None:
            task_inputs(self)
Example #14
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        spark_type,
        main_class,
        main_application_file,
        spark_conf,
        hadoop_conf,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the task
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text spark_type: Type of Spark Job: Scala/Java
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text main_application_file: Main application file
        :param dict[Text,Text] spark_conf:
        :param dict[Text,Text] hadoop_conf:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        spark_job = _task_models.SparkJob(
            spark_conf=spark_conf,
            hadoop_conf=hadoop_conf,
            spark_type=spark_type,
            application_file=main_application_file,
            main_class=main_class,
            executor_path=_sys.executable,
        ).to_flyte_idl()

        super(SdkGenericSparkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__, 'spark'), timeout,
                _literal_models.RetryStrategy(retries), interruptible,
                discovery_version, deprecated),
            _interface.TypedInterface({}, {}),
            _MessageToDict(spark_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)
Example #15
    def __init__(
        self,
        task_type,
        discovery_version,
        retries,
        interruptible,
        task_inputs,
        deprecated,
        discoverable,
        timeout,
        jar_file,
        main_class,
        args,
        flink_properties,
        environment,
    ):
        """
        :param Text task_type: string describing the task type
        :param Text discovery_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param bool interruptible: Whether or not task is interruptible
        :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the task
        :param Text deprecated:
        :param bool discoverable:
        :param datetime.timedelta timeout:
        :param Text main_class: Main class to execute for Scala/Java jobs
        :param Text jar_file: fat jar file
        :param list[Text] args: arguments to pass to the Flink job
        :param dict[Text,Text] flink_properties:
        :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
        """

        flink_job = _task_models.FlinkJob(flink_properties=flink_properties,
                                          jar_file=jar_file,
                                          main_class=main_class,
                                          args=args).to_flyte_idl()

        super(SdkGenericFlinkTask, self).__init__(
            task_type,
            _task_models.TaskMetadata(
                discoverable,
                _task_models.RuntimeMetadata(
                    _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    __version__,
                    "flink",
                ),
                timeout,
                _literal_models.RetryStrategy(retries),
                interruptible,
                discovery_version,
                deprecated,
            ),
            _interface.TypedInterface({}, {}),
            _MessageToDict(flink_job),
        )

        # Add Inputs
        if task_inputs is not None:
            task_inputs(self)

        # Container after the Inputs have been updated.
        self._container = self._get_container_definition(
            environment=environment)