def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
    """
    :param Union[T, FlyteIdlEntity] pb_object: the protobuf message (or FlyteIdlEntity wrapper) to store.
    """
    struct = Struct()
    v = pb_object

    # This section converts an existing proto object (or a subclass of one) to the type expected by this
    # instance of GenericProto. GenericProto can be used with any protobuf type (not restricted to FlyteType),
    # which makes it a bit tricky to figure out the right underlying raw proto class to use to populate the
    # final struct.
    # If the provided object has to_flyte_idl(), call it to produce a raw proto.
    if isinstance(pb_object, FlyteIdlEntity):
        v = pb_object.to_flyte_idl()

    # Ensure the raw proto (v) is of the expected type. This also performs one final attempt to convert it to
    # the correct type by leveraging from_flyte_idl (implemented by all FlyteTypes) in case this class is
    # initialized with one.
    expected_type = type(self).pb_type
    if expected_type != type(v) and expected_type != type(pb_object):
        if isinstance(type(self).pb_type, FlyteType):
            v = expected_type.from_flyte_idl(v).to_flyte_idl()
        else:
            raise _user_exceptions.FlyteTypeException(
                received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
            )

    struct.update(_MessageToDict(v))
    super().__init__(scalar=_literals.Scalar(generic=struct))
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
    if not (lv and lv.scalar and lv.scalar.generic):
        raise AssertionError("Can only convert a generic literal to a Protobuf")

    pb_obj = expected_python_type()
    dictionary = _MessageToDict(lv.scalar.generic)
    pb_obj = _ParseDict(dictionary, pb_obj)
    return pb_obj
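# Illustrative sketch only (not part of the surrounding code): the encode/decode pair above is the standard
# protobuf JSON mapping routed through a Struct. FieldDescriptorProto is used here purely as a stand-in
# protobuf message type.
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.struct_pb2 import Struct

# Encode: message -> dict -> Struct (the write path above).
msg = FieldDescriptorProto(name="my_field", number=7)
struct = Struct()
struct.update(MessageToDict(msg))

# Decode: Struct -> dict -> message (the read path in to_python_value).
decoded = ParseDict(MessageToDict(struct), FieldDescriptorProto())
assert decoded.name == "my_field" and decoded.number == 7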
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    interruptible,
    deprecated,
    discoverable,
    timeout,
    spark_conf,
    hadoop_conf,
    environment,
):
    """
    :param task_function: Function containing user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not the task is interruptible
    :param Text deprecated: This string can be used to mark the task as deprecated.
    :param bool discoverable: Whether or not outputs are discoverable (cacheable).
    :param datetime.timedelta timeout:
    :param dict[Text,Text] spark_conf: Spark configuration to apply to the job.
    :param dict[Text,Text] hadoop_conf: Hadoop configuration to apply to the job.
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    spark_exec_path = _os.path.abspath(_entrypoint.__file__)
    if spark_exec_path.endswith('.pyc'):
        spark_exec_path = spark_exec_path[:-1]
    spark_job = _task_models.SparkJob(
        spark_conf=spark_conf,
        hadoop_conf=hadoop_conf,
        application_file="local://" + spark_exec_path,
        executor_path=_sys.executable,
    ).to_flyte_idl()
    super(SdkSparkTask, self).__init__(
        task_function,
        task_type,
        discovery_version,
        retries,
        interruptible,
        deprecated,
        # The eight empty strings are the storage/cpu/gpu/memory requests and limits, left unset here.
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        discoverable,
        timeout,
        environment,
        _MessageToDict(spark_job),
    )
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    deprecated,
    discoverable,
    timeout,
    image,
    num_ps,
    replicas,
    command,
    args,
    volumeClaimName,
    environment,
):
    """
    :param task_function: Function containing user code. This will be executed via the SDK's engine.
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param Text deprecated: This string can be used to mark the task as deprecated.
    :param bool discoverable: Whether or not outputs are discoverable (cacheable).
    :param datetime.timedelta timeout:
    :param Text image: Container image to use for the TFJob replicas.
    :param int num_ps: Number of parameter-server replicas.
    :param int replicas: Number of worker replicas.
    :param list[Text] command: Command to run in each replica.
    :param list[Text] args: Arguments passed to the command.
    :param Text volumeClaimName: Name of the persistent volume claim to mount.
    :param dict[Text, Text] environment: [optional] environment variables to set when executing this task.
    """
    tfjob = _tfjob_model.TFJob(
        image=image,
        num_ps=num_ps,
        replicas=replicas,
        command=command,
        args=args,
        volumeClaimName=volumeClaimName,
    ).to_flyte_idl()
    super(SdkTFJobTask, self).__init__(
        task_function,
        task_type,
        discovery_version,
        retries,
        deprecated,
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        discoverable,
        timeout,
        environment,
        _MessageToDict(tfjob),
    )
def __init__(
    self,
    task_function,
    task_type,
    cache_version,
    retries,
    interruptible,
    deprecated,
    cache,
    timeout,
    workers_count,
    ps_replicas_count,
    chief_replicas_count,
    per_replica_storage_request,
    per_replica_cpu_request,
    per_replica_gpu_request,
    per_replica_memory_request,
    per_replica_storage_limit,
    per_replica_cpu_limit,
    per_replica_gpu_limit,
    per_replica_memory_limit,
    environment,
):
    tensorflow_job = _task_models.TensorFlowJob(
        workers_count=workers_count,
        ps_replicas_count=ps_replicas_count,
        chief_replicas_count=chief_replicas_count,
    ).to_flyte_idl()
    super(SdkTensorFlowTask, self).__init__(
        task_function=task_function,
        task_type=task_type,
        discovery_version=cache_version,
        retries=retries,
        interruptible=interruptible,
        deprecated=deprecated,
        storage_request=per_replica_storage_request,
        cpu_request=per_replica_cpu_request,
        gpu_request=per_replica_gpu_request,
        memory_request=per_replica_memory_request,
        storage_limit=per_replica_storage_limit,
        cpu_limit=per_replica_cpu_limit,
        gpu_limit=per_replica_gpu_limit,
        memory_limit=per_replica_memory_limit,
        discoverable=cache,
        timeout=timeout,
        environment=environment,
        custom=_MessageToDict(tensorflow_job),
    )
def to_python_std(self):
    """
    :returns: The protobuf object as defined by the user.
    :rtype: T
    """
    pb_obj = type(self).pb_type()
    try:
        dictionary = _MessageToDict(self.scalar.generic)
        pb_obj = _ParseDict(dictionary, pb_obj)
    except Error as err:
        raise _user_exceptions.FlyteTypeException(
            received_type="generic",
            expected_type=type(self).pb_type,
            received_value=_base64.b64encode(self.scalar.generic.SerializeToString()),
            additional_msg=f"Cannot deserialize. Error: {err}",
        )
    return pb_obj
def __init__(
    self,
    hive_job,
    metadata,
):
    """
    :param _qubole.QuboleHiveJob hive_job: Hive job spec
    :param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as
        whether or not outputs are discoverable, timeouts, and retries.
    """
    super(SdkHiveJob, self).__init__(
        _constants.SdkTaskType.HIVE_JOB,
        metadata,
        # Individual hive tasks never take anything or return anything. They just run a query whose output
        # location is already set.
        _interface_model.TypedInterface({}, {}),
        _MessageToDict(hive_job),
    )
def __init__(
    self,
    task_function,
    task_type,
    discovery_version,
    retries,
    interruptible,
    deprecated,
    discoverable,
    timeout,
    workers_count,
    per_replica_storage_request,
    per_replica_cpu_request,
    per_replica_gpu_request,
    per_replica_memory_request,
    per_replica_storage_limit,
    per_replica_cpu_limit,
    per_replica_gpu_limit,
    per_replica_memory_limit,
    environment,
):
    pytorch_job = _task_models.PyTorchJob(
        workers_count=workers_count,
    ).to_flyte_idl()
    super(SdkPyTorchTask, self).__init__(
        task_function=task_function,
        task_type=task_type,
        discovery_version=discovery_version,
        retries=retries,
        interruptible=interruptible,
        deprecated=deprecated,
        storage_request=per_replica_storage_request,
        cpu_request=per_replica_cpu_request,
        gpu_request=per_replica_gpu_request,
        memory_request=per_replica_memory_request,
        storage_limit=per_replica_storage_limit,
        cpu_limit=per_replica_cpu_limit,
        gpu_limit=per_replica_gpu_limit,
        memory_limit=per_replica_memory_limit,
        discoverable=discoverable,
        timeout=timeout,
        environment=environment,
        custom=_MessageToDict(pytorch_job),
    )
def with_overrides(
    self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None
):
    """
    Creates a copy of this task with the modified Spark and Hadoop configuration applied.
    """
    tk = _copy.deepcopy(self)
    tk._spark_job = self._spark_job.with_overrides(new_spark_conf, new_hadoop_conf)
    tk._custom = _MessageToDict(tk._spark_job.to_flyte_idl())

    # Salt the task name with a digest of the new custom field so identical overrides map to the same name.
    salt = _hashlib.md5(_json.dumps(tk.custom, sort_keys=True).encode("utf-8")).hexdigest()
    tk._id._name = "{}-{}".format(self._id.name, salt)
    # We are overriding the platform name creation to prevent problems in dynamic workflows.
    tk.assign_name(tk._id._name)
    return tk
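# Illustrative sketch only (not part of the surrounding code): the salt above is an md5 digest of the
# JSON-serialized custom field, so identical overrides always produce the same task name. The dict and the
# task name below are hypothetical stand-ins for _MessageToDict(spark_job) output and self._id.name.
import hashlib
import json

custom = {"sparkConf": {"spark.executor.instances": "4"}}
salt = hashlib.md5(json.dumps(custom, sort_keys=True).encode("utf-8")).hexdigest()
new_name = "{}-{}".format("my_spark_task", salt)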
def promote_from_model(cls, literal_model):
    """
    Creates an object of this type from the model primitive defining it.
    :param flytekit.models.literals.Literal literal_model:
    :rtype: Protobuf
    """
    pb_obj = cls.pb_type()
    try:
        dictionary = _MessageToDict(literal_model.scalar.generic)
        pb_obj = _ParseDict(dictionary, pb_obj)
    except Error as err:
        raise _user_exceptions.FlyteTypeException(
            received_type="generic",
            expected_type=cls.pb_type,
            received_value=_base64.b64encode(literal_model.scalar.generic.SerializeToString()),
            additional_msg=f"Cannot deserialize. Error: {err}",
        )
    return cls(pb_obj)
def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name):
    """
    Assigns the custom field as the reconciled primary container and pod spec definition.
    :param generated_pb2.PodSpec pod_spec:
    :param Text primary_container_name:
    :rtype: SdkSidecarTask
    """

    # First, insert a placeholder primary container if it is not defined in the pod spec.
    containers = pod_spec.containers
    primary_exists = False
    for container in containers:
        if container.name == primary_container_name:
            primary_exists = True
            break

    if not primary_exists:
        containers.extend([_k8s_pb2.Container(name=primary_container_name)])

    final_containers = []
    for container in containers:
        # In the case of the primary container, we overwrite specific container attributes with the default
        # values used in an SDK runnable task.
        if container.name == primary_container_name:
            container.image = self._container.image
            # clear existing commands
            del container.command[:]
            container.command.extend(self._container.command)
            # also clear existing args
            del container.args[:]
            container.args.extend(self._container.args)

            resource_requirements = _k8s_pb2.ResourceRequirements()
            for resource in self._container.resources.limits:
                resource_requirements.limits[
                    _core_task.Resources.ResourceName.Name(resource.name).lower()
                ].CopyFrom(_resource_pb2.Quantity(string=resource.value))
            for resource in self._container.resources.requests:
                resource_requirements.requests[
                    _core_task.Resources.ResourceName.Name(resource.name).lower()
                ].CopyFrom(_resource_pb2.Quantity(string=resource.value))
            if resource_requirements.ByteSize():
                # Important! Only copy over resource requirements if they are non-empty.
                container.resources.CopyFrom(resource_requirements)

            del container.env[:]
            container.env.extend(
                [_k8s_pb2.EnvVar(name=key, value=val) for key, val in _six.iteritems(self._container.env)]
            )

        final_containers.append(container)

    del pod_spec.containers[:]
    pod_spec.containers.extend(final_containers)

    sidecar_job_plugin = _task_models.SidecarJob(
        pod_spec=pod_spec,
        primary_container_name=primary_container_name,
    ).to_flyte_idl()

    return self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin))
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
    struct = Struct()
    struct.update(_MessageToDict(python_val))
    return Literal(scalar=Scalar(generic=struct))
def __init__(
    self,
    statement,
    output_schema,
    routing_group=None,
    catalog=None,
    schema=None,
    task_inputs=None,
    interruptible=False,
    discoverable=False,
    discovery_version=None,
    retries=1,
    timeout=None,
    deprecated=None,
):
    """
    :param Text statement: Presto query specification
    :param flytekit.common.types.schema.Schema output_schema: Schema that represents the data queried from Presto
    :param Text routing_group: The routing group that a Presto query should be sent to for the given environment
    :param Text catalog: The catalog to set for the given Presto query
    :param Text schema: The schema to set for the given Presto query
    :param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the
        Presto task
    :param bool interruptible: Whether or not the task is interruptible
    :param bool discoverable:
    :param Text discovery_version: String describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param datetime.timedelta timeout:
    :param Text deprecated: This string can be used to mark the task as deprecated. Consumers of the task will
        receive deprecation warnings.
    """

    # Set as class fields which are used down below to configure implicit parameters
    self._routing_group = routing_group or ""
    self._catalog = catalog or ""
    self._schema = schema or ""

    metadata = _task_model.TaskMetadata(
        discoverable,
        # This needs to have the proper version reflected in it
        _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"),
        timeout or _datetime.timedelta(seconds=0),
        _literals.RetryStrategy(retries),
        interruptible,
        discovery_version,
        deprecated,
    )

    presto_query = _presto_models.PrestoQuery(
        routing_group=routing_group or "",
        catalog=catalog or "",
        schema=schema or "",
        statement=statement,
    )

    # Here we set the routing_group, catalog, and schema as implicit parameters for caching purposes
    i = _interface.TypedInterface(
        {
            "__implicit_routing_group": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The routing group set as an implicit input",
            ),
            "__implicit_catalog": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The catalog set as an implicit input",
            ),
            "__implicit_schema": _interface_model.Variable(
                type=_types.LiteralType(simple=_types.SimpleType.STRING),
                description="The schema set as an implicit input",
            ),
        },
        {
            # Set the schema for the Presto query as an output
            "results": _interface_model.Variable(
                type=_types.LiteralType(schema=output_schema.schema_type),
                description="The schema for the Presto query",
            )
        },
    )

    super(SdkPrestoTask, self).__init__(
        _constants.SdkTaskType.PRESTO_TASK,
        metadata,
        i,
        _MessageToDict(presto_query.to_flyte_idl()),
    )

    # Set user-provided inputs
    if task_inputs is not None:
        task_inputs(self)
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    spark_type,
    main_class,
    main_application_file,
    spark_conf,
    hadoop_conf,
    environment,
):
    """
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not the task is interruptible
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text spark_type: Type of Spark Job: Scala/Java
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param Text main_application_file: Main application file
    :param dict[Text,Text] spark_conf:
    :param dict[Text,Text] hadoop_conf:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    spark_job = _task_models.SparkJob(
        spark_conf=spark_conf,
        hadoop_conf=hadoop_conf,
        spark_type=spark_type,
        application_file=main_application_file,
        main_class=main_class,
        executor_path=_sys.executable,
    ).to_flyte_idl()

    super(SdkGenericSparkTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                'spark',
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        ),
        _interface.TypedInterface({}, {}),
        _MessageToDict(spark_job),
    )

    # Add Inputs
    if task_inputs is not None:
        task_inputs(self)

    # Container after the Inputs have been updated.
    self._container = self._get_container_definition(environment=environment)
def __init__(
    self,
    task_type,
    discovery_version,
    retries,
    interruptible,
    task_inputs,
    deprecated,
    discoverable,
    timeout,
    jar_file,
    main_class,
    args,
    flink_properties,
    environment,
):
    """
    :param Text task_type: string describing the task type
    :param Text discovery_version: string describing the version for task discovery purposes
    :param int retries: Number of retries to attempt
    :param bool interruptible: Whether or not the task is interruptible
    :param Text deprecated:
    :param bool discoverable:
    :param datetime.timedelta timeout:
    :param Text jar_file: fat jar file
    :param Text main_class: Main class to execute for Scala/Java jobs
    :param list[Text] args: arguments to pass to the main class
    :param dict[Text,Text] flink_properties:
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    """
    flink_job = _task_models.FlinkJob(
        flink_properties=flink_properties,
        jar_file=jar_file,
        main_class=main_class,
        args=args,
    ).to_flyte_idl()

    super(SdkGenericFlinkTask, self).__init__(
        task_type,
        _task_models.TaskMetadata(
            discoverable,
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                __version__,
                "flink",
            ),
            timeout,
            _literal_models.RetryStrategy(retries),
            interruptible,
            discovery_version,
            deprecated,
        ),
        _interface.TypedInterface({}, {}),
        _MessageToDict(flink_job),
    )

    # Add Inputs
    if task_inputs is not None:
        task_inputs(self)

    # Container after the Inputs have been updated.
    self._container = self._get_container_definition(environment=environment)