def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
    """
    Run or compile the body of a dynamic task depending on the current execution mode.

    By the time this is called, ``local_execute`` is expected to have unwrapped Promises and Flyte
    literal wrappers, so ``kwargs`` hold Python-native values; the return value is expected to be
    Python-native as well.

    The user code of a dynamic task forms a workflow:

    * ``LOCAL_WORKFLOW_EXECUTION``: the function is run directly (through the user entry-point scope)
      inside a context whose execution state is switched to ``TASK_EXECUTION``.
    * ``TASK_EXECUTION``: the function is compiled into a workflow via ``compile_into_workflow``.
    * ``LOCAL_TASK_EXECUTION``: the function is run directly through the user entry-point scope.

    :param task_function: The user function wrapped by the dynamic task.
    :param kwargs: Python-native inputs for ``task_function``.
    :raises ValueError: If there is no execution state, or its mode is none of the above.
    """
    ctx = FlyteContextManager.current_context()
    state = ctx.execution_state
    mode = state.mode if state else None

    if mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        as_task_state = state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with FlyteContextManager.with_context(ctx.with_execution_state(as_task_state)):
            logger.info("Executing Dynamic workflow, using raw inputs")
            return exception_scopes.user_entry_point(task_function)(**kwargs)

    if mode == ExecutionState.Mode.TASK_EXECUTION:
        return self.compile_into_workflow(ctx, task_function, **kwargs)

    if mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
        return exception_scopes.user_entry_point(task_function)(**kwargs)

    raise ValueError(f"Invalid execution provided, execution state: {state}")
def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]):
    """
    This is a decorator used for testing.

    While the decorated test runs, ``target.execute`` is replaced with a fresh ``MagicMock``. The mock
    is passed to the test as its first positional argument. The original ``execute`` is restored
    afterwards, even if the test raises.

    :param target: A Python-declared task, workflow, or reference entity whose ``execute`` is mocked.
    :return: A decorator that wraps a test function as described above.
    :raises Exception: If ``target`` is not a ``PythonTask``, ``WorkflowBase`` or ``ReferenceEntity``.
    """
    if not isinstance(target, (PythonTask, WorkflowBase, ReferenceEntity)):
        raise Exception("Can only use mocks on tasks/workflows declared in Python.")

    logger.info(
        "When using this patch function on Flyte entities, please be aware weird issues may arise if also "
        "using mock.patch on internal Flyte classes like PythonFunctionWorkflow. See "
        "https://github.com/flyteorg/flyte/issues/854 for more information"
    )

    def wrapper(test_fn):
        def new_test(*args, **kwargs):
            logger.warning(f"Invoking mock method for target: '{target.name}'")
            m = MagicMock()
            saved = target.execute
            target.execute = m
            try:
                return test_fn(m, *args, **kwargs)
            finally:
                # Always restore the real execute so a failing test doesn't leak the mock into
                # later tests that use the same entity.
                target.execute = saved

        return new_test

    return wrapper
def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
    """
    Run or compile the body of a dynamic task depending on the current execution mode.

    When this is invoked, ``_local_execute`` is expected to have already unwrapped Promises and Flyte
    literal wrappers, so ``kwargs`` are Python-native values; the result is expected to be
    Python-native too.

    * ``LOCAL_WORKFLOW_EXECUTION``: the function is called directly with the raw inputs, inside a
      context whose execution state is switched to ``TASK_EXECUTION``.
    * ``TASK_EXECUTION``: the function is compiled into a workflow via ``compile_into_workflow``. If
      the execution state's ``additional_context`` has a truthy ``"dynamic_addl_distro"`` entry, the
      context is first given serialization settings with fast serialization enabled.

    Any other mode, or a missing execution state, falls through and returns ``None``.

    :param task_function: The user function wrapped by the dynamic task.
    :param kwargs: Python-native inputs for ``task_function``.
    """
    ctx = FlyteContextManager.current_context()
    state = ctx.execution_state
    mode = state.mode if state else None

    if mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        as_task_state = state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with FlyteContextManager.with_context(ctx.with_execution_state(as_task_state)):
            logger.info("Executing Dynamic workflow, using raw inputs")
            return task_function(**kwargs)

    if mode == ExecutionState.Mode.TASK_EXECUTION:
        addl_context = state.additional_context
        if addl_context and addl_context.get("dynamic_addl_distro"):
            builder = SerializationSettings.new_builder()
            fast_settings = builder.with_fast_serialization_settings(FastSerializationSettings(enabled=True)).build()
            ctx = ctx.with_serialization_settings(fast_settings)
        return self.compile_into_workflow(ctx, task_function, **kwargs)
def get_protocol(url: str):
    """
    Return the protocol prefix of ``url``, defaulting to ``"file"``.

    The protocol is everything before the first ``::`` or ``://`` separator. This mirrors fsspec's helper:
    https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391

    :param url: A path or URL, e.g. ``s3://bucket/key``, ``simplecache::s3://...`` or ``/local/path``.
    :return: The protocol prefix, or ``"file"`` when ``url`` contains no separator.
    """
    # maxsplit is passed by keyword: passing it positionally to re.split is deprecated since Python 3.13.
    parts = re.split(r"(\:\:|\://)", url, maxsplit=1)
    if len(parts) > 1:
        return parts[0]
    logger.info("Setting protocol to file")
    return "file"
def check_call(cmd_args, **kwargs):
    """
    Run a command, log its stdout, and raise if it exits with a non-zero status.

    Output is captured into temporary files instead of being inherited. Jupyter notebooks hijack I/O,
    so the child process cannot write directly to the notebook's stdout.

    :param cmd_args: The command, either as an argument list or as a string split with ``shlex``.
    :param kwargs: Extra keyword arguments forwarded to ``subprocess.Popen``.
    :return: 0 on success.
    :raises Exception: If the process exits with a non-zero code; the message contains the captured stderr.
    """
    if not isinstance(cmd_args, list):
        cmd_args = _schlex.split(cmd_args)

    with _tempfile.TemporaryFile() as std_out, _tempfile.TemporaryFile() as std_err:
        ret_code = _subprocess.Popen(cmd_args, stdout=std_out, stderr=std_err, **kwargs).wait()

        # Dump sub-process' std out into current std out. TemporaryFile is opened in binary mode, so
        # decode explicitly; otherwise the log shows a bytes repr (b'...') with escaped newlines.
        std_out.seek(0)
        out_str = std_out.read().decode("utf-8", errors="replace")
        logger.info("Output of command '{}':\n{}\n".format(cmd_args, out_str))

        if ret_code != 0:
            std_err.seek(0)
            err_str = std_err.read().decode("utf-8", errors="replace")
            logger.error("Error from command '{}':\n{}\n".format(cmd_args, err_str))
            raise Exception(
                "Called process exited with error code: {}. Stderr dump:\n\n{}".format(ret_code, err_str)
            )
    return 0
def _register():
    """
    Register ``FSSpecPersistence`` as the persistence plugin for local paths and every fsspec protocol.

    The plugin is registered for the ``"/"`` prefix and for ``"<protocol>://"`` for each key of
    ``known_implementations``. Every registration passes ``force=True``, so existing plugins for those
    prefixes are overridden.
    """
    logger.info(
        "Registering fsspec known implementations and overriding all default implementations for persistence."
    )
    DataPersistencePlugins.register_plugin("/", FSSpecPersistence, force=True)
    # Only the protocol names (keys) are needed; the implementation entries themselves are unused.
    for protocol in known_implementations:
        DataPersistencePlugins.register_plugin(f"{protocol}://", FSSpecPersistence, force=True)
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Log the wall-clock and process time spent inside the timed context.

    Elapsed times are measured against ``self._start_wall_time`` and ``self._start_process_time``.
    Nothing is returned, so exceptions raised inside the context propagate normally.
    """
    # Read both clocks first so neither measurement includes the cost of the other.
    wall_end, cpu_end = _time.perf_counter(), _time.process_time()
    wall_elapsed = wall_end - self._start_wall_time
    cpu_elapsed = cpu_end - self._start_process_time
    message = "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format(
        self._context_statement, wall_elapsed, cpu_elapsed
    )
    logger.info(message)
def _output_deck(task_name: str, new_user_params: ExecutionParameters):
    """
    Render the Flyte deck for a task and write it to an HTML file.

    During a real ``TASK_EXECUTION`` the file is written to the execution state's engine directory;
    otherwise it goes into a fresh random local directory.

    :param task_name: Name of the task; used only in the log message.
    :param new_user_params: Execution parameters passed to ``_get_deck`` to produce the HTML.
    """
    ctx = FlyteContext.current_context()
    if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
        output_dir = ctx.execution_state.engine_dir
    else:
        output_dir = ctx.file_access.get_random_local_directory()
    deck_path = os.path.join(output_dir, DECK_FILE_NAME)
    # Use an explicit encoding so the result doesn't depend on the platform locale (e.g. cp1252 on
    # Windows, which fails on characters outside that code page).
    with open(deck_path, "w", encoding="utf-8") as f:
        f.write(_get_deck(new_user_params))
    logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}")
def execute(self, **kwargs) -> typing.Any:
    """
    Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem

    If ``self.script_file`` is set, the script is (re)loaded from that file first. Each output location
    is interpolated with the inputs. The script is then interpolated with the inputs and outputs and run
    through the shell.

    :param kwargs: Task inputs. A dict ``env`` input is additionally exposed to the script as ``export_env``.
    :return: ``None`` when there are no file/directory outputs, the single ``FlyteFile``/``FlyteDirectory``
        when there is one, otherwise a tuple of them in output-location order.
    :raises subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    logger.info(f"Running shell script as type {self.task_type}")
    if self.script_file:
        with open(self.script_file) as f:
            self._script = f.read()

    # Treat a missing/empty output-location list uniformly in both places it is iterated below.
    output_locs = self._output_locs or []
    outputs: typing.Dict[str, str] = {
        v.var: self._interpolizer.interpolate(v.location, inputs=kwargs) for v in output_locs
    }

    if os.name == "nt":
        # cmd.exe: join the script's lines into a single "&&"-chained command.
        self._script = self._script.strip().replace("\n", "&&")

    if "env" in kwargs and isinstance(kwargs["env"], dict):
        kwargs["export_env"] = self.make_export_string_from_env_dict(kwargs["env"])

    gen_script = self._interpolizer.interpolate(self._script, inputs=kwargs, outputs=outputs)
    if self._debug:
        print("\n==============================================\n")
        print(gen_script)
        print("\n==============================================\n")

    try:
        subprocess.check_call(gen_script, shell=True)
    except subprocess.CalledProcessError as e:
        # NOTE: check_call does not capture output, so e.stderr / e.stdout are None here; the script's
        # output has already gone to this process's stdout/stderr.
        files = os.listdir(".")
        fstr = "\n-".join(files)
        logger.error(
            f"Failed to Execute Script, return-code {e.returncode} \n"
            f"StdErr: {e.stderr}\n"
            f"StdOut: {e.stdout}\n"
            f" Current directory contents: .\n-{fstr}"
        )
        raise

    final_outputs = []
    for v in output_locs:
        if issubclass(v.var_type, FlyteFile):
            final_outputs.append(FlyteFile(outputs[v.var]))
        if issubclass(v.var_type, FlyteDirectory):
            final_outputs.append(FlyteDirectory(outputs[v.var]))

    if len(final_outputs) == 1:
        return final_outputs[0]
    if len(final_outputs) > 1:
        return tuple(final_outputs)
    return None
def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask:
    """
    Rebuild an executable task from a serialized task template.

    :param loader_args: ``loader_args[0]`` is the location of the serialized ``TaskTemplate`` protobuf,
        fetched with ``ctx.file_access.get_data``. ``loader_args[1]`` is the executor class path, resolved
        with ``load_object_from_module``.
    :return: An ``ExecutableTemplateShimTask`` that pairs the loaded template with the executor class.
    """
    logger.info(f"Task template loader args: {loader_args}")
    ctx = FlyteContext.current_context()

    # Download the serialized template into the working directory, then parse it.
    local_template_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb")
    ctx.file_access.get_data(loader_args[0], local_template_path)
    template_idl = common_utils.load_proto_from_file(_tasks_pb2.TaskTemplate, local_template_path)
    template = _task_model.TaskTemplate.from_flyte_idl(template_idl)

    executor_cls = load_object_from_module(loader_args[1])
    return ExecutableTemplateShimTask(template, executor_cls)
def _finder(handler_map, df_type: Type, protocol: str, format: str): try: return handler_map[df_type][protocol][format] except KeyError: try: hh = handler_map[df_type][protocol][""] logger.info( f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" f" format {format}, using default instead.") return hh except KeyError: ... raise ValueError( f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}" )
def unwrap_literal_map_and_execute(
    self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
) -> Union[VoidPromise, _literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
    """
    Please see the implementation of the dispatch_execute function in the real task.

    Converts the input literals to Python values, runs ``self.execute``, and converts the results back
    into a ``LiteralMap`` keyed by the declared output names.

    :param ctx: Context used for the literal <-> Python conversions.
    :param input_literal_map: Task inputs as Flyte literals.
    :return: A ``LiteralMap`` with one literal per declared output.
    :raises AssertionError: If an output value is a tuple.
    """
    # Invoked before the task is executed: translate the input literals to Python native values.
    native_inputs = TypeEngine.literal_map_to_kwargs(ctx, input_literal_map, self.python_interface.inputs)
    logger.info(f"Invoking {self.name} with inputs: {native_inputs}")

    try:
        native_outputs = self.execute(**native_inputs)
    except Exception as e:
        logger.exception(f"Exception when executing {e}")
        raise e
    logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")

    output_names = list(self.python_interface.outputs.keys())
    output_count = len(output_names)
    if output_count == 0:
        outputs_by_name = {}
    elif output_count == 1:
        outputs_by_name = {output_names[0]: native_outputs}
    else:
        # Multiple outputs are returned positionally; pair them with the declared names.
        outputs_by_name = {output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

    # Build the LiteralMap by hand: task inputs and outputs violate the IDL assumption that every value
    # in a literal map has the same type.
    literals = {}
    for name, value in outputs_by_name.items():
        literal_type = self.interface.outputs[name].type
        py_type = self.python_interface.outputs[name]
        if isinstance(value, tuple):
            raise AssertionError(f"Output({name}) in task{self.name} received a tuple {value}, instead of {py_type}")
        literals[name] = TypeEngine.to_literal(ctx, value, py_type, literal_type)

    # After the execute has been successfully completed
    return _literal_models.LiteralMap(literals=literals)
def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
    """
    Performs local execution of a workflow. kwargs are expected to be Promises for the most part (unless,
    someone has hardcoded in my_wf(input_1=5) or something).

    :param ctx: The FlyteContext
    :param kwargs: parameters for the workflow itself
    :return: ``VoidPromise(self.name)`` if the workflow function returns nothing/a ``VoidPromise`` or
        declares no outputs; otherwise the output(s) produced by ``create_task_output``.
    """
    # Log the instance's current execution mode (``.mode``), not the ``Mode`` enum class.
    logger.info(f"Executing Workflow {self._name}, ctx{ctx.execution_state.mode}")

    # This is done to support the invariant that Workflow local executions always work with Promise objects
    # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
    for k, v in kwargs.items():
        if not isinstance(v, Promise):
            t = self._native_interface.inputs[k]
            # Replacing the value of an existing key does not change the dict's size, so this is safe
            # while iterating.
            kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))

    function_outputs = self.execute(**kwargs)
    if (
        isinstance(function_outputs, VoidPromise)
        or function_outputs is None
        or len(self.python_interface.outputs) == 0
    ):
        # The reason this is here is because a workflow function may return a task that doesn't return anything
        #   def wf():
        #       return t1()
        # or it may not return at all
        #   def wf():
        #       t1()
        # In the former case we get the task's VoidPromise, in the latter we get None
        return VoidPromise(self.name)

    # TODO: Can we refactor the task code to be similar to what's in this function?
    promises = _workflow_fn_outputs_to_promise(
        ctx, self._native_interface.outputs, self.interface.outputs, function_outputs
    )
    # TODO: With the native interface, create_task_output should be able to derive the typed interface, and it
    #   should be able to do the conversion of the output of the execute() call directly.
    return create_task_output(promises, self._native_interface)
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
    """
    Add a distributed-training context to the execution params for distributed SageMaker runs.

    Only applies when ``self._is_distributed()`` is true (per the original note, when there is more than
    one execution instance). Otherwise the run is treated as single-node and ``user_params`` is returned
    unchanged.

    :param user_params: The incoming execution parameters.
    :return: Parameters with a ``DISTRIBUTED_TRAINING_CONTEXT`` attribute added, or ``user_params`` as-is.
    """
    if not self._is_distributed():
        return user_params

    logger.info("Distributed context detected!")
    exec_state = FlyteContextManager.current_context().execution_state
    # TASK_EXECUTION means we really are running remotely (inside SageMaker), so the context comes
    # from the environment; otherwise a local stand-in is built.
    running_remotely = exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION
    if running_remotely:
        dist_ctx = DistributedTrainingContext.from_env()
    else:
        dist_ctx = DistributedTrainingContext.local_execute()
    return user_params.builder().add_attr("DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()
def register(
    registrable_entities: typing.List[RegistrableEntity],
    project: str,
    domain: str,
    version: str,
    client: SynchronousFlyteClient,
):
    """
    Register serialized task, workflow, and launch-plan specs with Flyte Admin.

    Entities that already exist are logged and skipped. Any other failure is logged as an error and
    re-raised, which stops registration of the remaining entities.

    :param registrable_entities: Admin protobuf specs (``TaskSpec``, ``WorkflowSpec`` or ``LaunchPlan``).
    :param project: Project to register the entities under.
    :param domain: Domain to register the entities under.
    :param version: Version applied to every entity's identifier.
    :param client: Client used to talk to Flyte Admin.
    :raises AssertionError: For an entity of an unsupported type.
    """
    # The incoming registrable entities are already in base protobuf form, not model form, so we use the
    # raw client's methods instead of the friendly client's methods by calling super
    for admin_entity in registrable_entities:
        try:
            if isinstance(admin_entity, _idl_admin_TaskSpec):
                ident, task_spec = hydrate_registration_parameters(
                    identifier_pb2.TASK, project, domain, version, admin_entity
                )
                logger.debug(f"Creating task {ident}")
                super(SynchronousFlyteClient, client).create_task(TaskCreateRequest(id=ident, spec=task_spec))
            elif isinstance(admin_entity, _idl_admin_WorkflowSpec):
                ident, wf_spec = hydrate_registration_parameters(
                    identifier_pb2.WORKFLOW, project, domain, version, admin_entity
                )
                logger.debug(f"Creating workflow {ident}")
                super(SynchronousFlyteClient, client).create_workflow(WorkflowCreateRequest(id=ident, spec=wf_spec))
            elif isinstance(admin_entity, _idl_admin_LaunchPlan):
                ident, admin_lp = hydrate_registration_parameters(
                    identifier_pb2.LAUNCH_PLAN, project, domain, version, admin_entity
                )
                logger.debug(f"Creating launch plan {ident}")
                super(SynchronousFlyteClient, client).create_launch_plan(
                    LaunchPlanCreateRequest(id=ident, spec=admin_lp.spec)
                )
            else:
                raise AssertionError(f"Unknown entity of type {type(admin_entity)}")
        except FlyteEntityAlreadyExistsException:
            logger.info(f"{admin_entity} already exists")
        except Exception as e:
            # A real registration failure: log it at error level and propagate the original exception.
            logger.error(f"Failed to register entity {admin_entity} with error {e}")
            raise
def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
    """
    Run or compile the body of a dynamic task depending on the current execution mode.

    When invoked, ``_local_execute`` is expected to have unwrapped Promises and Flyte literal wrappers,
    so ``kwargs`` hold Python-native values; the return value is expected to be Python-native as well.

    * ``LOCAL_WORKFLOW_EXECUTION``: the function is called directly with the raw inputs, inside a new
      execution context in ``TASK_EXECUTION`` mode.
    * ``TASK_EXECUTION``: the function is compiled into a workflow via ``compile_into_workflow``.

    Any other mode, or a missing execution state, returns ``None``.

    :param task_function: The user function wrapped by the dynamic task.
    :param kwargs: Python-native inputs for ``task_function``.
    """
    ctx = FlyteContext.current_context()
    state = ctx.execution_state
    if not state:
        return None

    if state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        with ctx.new_execution_context(ExecutionState.Mode.TASK_EXECUTION):
            logger.info("Executing Dynamic workflow, using raw inputs")
            return task_function(**kwargs)

    if state.mode == ExecutionState.Mode.TASK_EXECUTION:
        return self.compile_into_workflow(ctx, task_function, **kwargs)

    return None
def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
    """
    Decide whether this node's outputs should be persisted in a distributed run.

    In distributed training several nodes may produce partial outputs, and only the user knows which one
    should be kept. For distributed runs, ``self.task_config.should_persist_output`` is called with the
    current distributed training context. If it returns false, outputs are dropped by raising
    ``IgnoreOutputs``. Non-distributed runs return ``rval`` untouched.

    :param user_params: Execution parameters (unused here).
    :param rval: The value returned by the task.
    :return: ``rval`` when outputs should be persisted.
    :raises IgnoreOutputs: When the persistence predicate is not met.
    """
    if not self._is_distributed():
        return rval

    logger.info("Distributed context detected!")
    dctx = flytekit.current_context().distributed_training_context
    if self.task_config.should_persist_output(dctx):
        return rval

    logger.info("output persistence predicate not met, Flytekit will ignore outputs")
    raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. Ignoring outputs - {dctx}")
def execute(self, **kwargs) -> Any:
    """
    TODO: Figure out how to share FlyteContext ExecutionParameters with the notebook kernel (as notebook kernel
          is executed in a separate python process)
    For Spark, the notebooks today need to use the new_session or just getOrCreate session and get a handle to the
    singleton

    Runs the notebook with Papermill using ``kwargs`` as parameters, renders it to HTML, and collects the
    declared outputs.

    :return: A tuple with one entry per declared output, in declaration order. The implicit notebook outputs
        map to the executed and rendered notebook paths. Other outputs are converted from the literals
        extracted from the executed notebook.
    :raises RuntimeError: If a declared output is missing from the notebook outputs.
    """
    logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
    # Execute Notebook via Papermill.
    pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs)  # type: ignore

    outputs = self.extract_outputs(self.output_notebook_path)
    self.render_nb_html(self.output_notebook_path, self.rendered_output_path)

    m = {}
    if outputs:
        m = outputs.literals
    output_list = []
    for k, type_v in self.python_interface.outputs.items():
        if k == self._IMPLICIT_OP_NOTEBOOK:
            output_list.append(self.output_notebook_path)
        elif k == self._IMPLICIT_RENDERED_NOTEBOOK:
            output_list.append(self.rendered_output_path)
        elif k in m:
            v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v)
            output_list.append(v)
        else:
            # Report the declared type. ``v`` would be unbound here (or stale from a previous output).
            raise RuntimeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

    return tuple(output_list)
def get_config_file(c: typing.Union[str, ConfigFile, None]) -> typing.Optional[ConfigFile]:
    """
    Resolve ``c`` into a loaded ``ConfigFile``.

    * A ``str`` is treated as a path to a config file and loaded.
    * Any other non-``None`` value (e.g. an existing ``ConfigFile``) is returned unchanged.
    * ``None`` triggers a search. The first existing candidate wins:

      1. ``flytekit.config`` in the directory Python is being run from;
      2. ``~/.flyte/config``;
      3. the path in the ``FLYTECTL_CONFIG_ENV_VAR`` environment variable if set, otherwise
         ``~/.flyte/config.yaml``.

      If nothing is found, ``None`` is returned and the caller decides what to do.
    """
    if isinstance(c, str):
        logger.debug(f"Using specified config file at {c}")
        return ConfigFile(c)
    if c is not None:
        return c

    cwd_config = Path("flytekit.config")
    if cwd_config.exists():
        logger.info(f"Using configuration from Python process root {cwd_config.absolute()}")
        return ConfigFile(str(cwd_config.absolute()))

    home_config = Path(Path.home(), ".flyte", "config")  # _default_config_file_name in main.py
    if home_config.exists():
        logger.info(f"Using configuration from home directory {home_config.absolute()}")
        return ConfigFile(str(home_config.absolute()))

    # The env var is the one `flytectl sandbox` tells the user to set; fall back to its default location.
    default_flytectl = Path(Path.home(), ".flyte", "config.yaml")
    env_flytectl = getenv(FLYTECTL_CONFIG_ENV_VAR, None)
    flytectl_config = Path(env_flytectl) if env_flytectl else default_flytectl
    if flytectl_config.exists():
        logger.info(f"Using flytectl/YAML config {flytectl_config.absolute()}")
        return ConfigFile(str(flytectl_config.absolute()))

    return None
def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]):
    """
    Run an AWS CLI command with the S3 config applied, retrying with backoff on failure.

    ``--debug`` and ``--endpoint-url <endpoint>`` are inserted after the executable as configured.
    Credentials from ``s3_cfg`` are put into the child environment only when they are not already set
    in the current environment. Each attempt runs the command as-is and, if that fails, once more with
    the anonymous flag. At most ``s3_cfg.retries`` additional attempts are made, sleeping
    ``s3_cfg.backoff`` between attempts.

    :param s3_cfg: S3 configuration (debug flag, endpoint, credentials, retries, backoff).
    :param cmd: The command as an argument list; ``cmd[0]`` is the executable. It is not modified.
    :return: The return value of ``subprocess.check_call`` (0 on success).
    :raises Exception: The last error, once the retries are exhausted.
    """
    # Work on a copy so the caller's list is not mutated (otherwise reusing the same list across calls
    # would insert the flags again each time).
    cmd = list(cmd)
    env = _os.environ.copy()

    if s3_cfg.enable_debug:
        cmd.insert(1, "--debug")

    if s3_cfg.endpoint is not None:
        # Inserted in reverse so the command reads: <exe> --endpoint-url <endpoint> ...
        cmd.insert(1, s3_cfg.endpoint)
        cmd.insert(1, "--endpoint-url")

    # Only fill in credentials the current environment does not already provide.
    if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ:
        if s3_cfg.access_key_id:
            env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id

    if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ:
        if s3_cfg.secret_access_key:
            env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key

    retry = 0
    while True:
        try:
            try:
                return subprocess.check_call(cmd, env=env)
            except Exception as e:
                if retry > 0:
                    logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}")

                logger.debug(f"Appending anonymous flag and retrying command {cmd}")
                anonymous_cmd = cmd[:]  # strings only, so this is deep enough
                anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG)
                return subprocess.check_call(anonymous_cmd, env=env)

        except Exception as e:
            logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}")
            retry += 1
            if retry > s3_cfg.retries:
                raise
            secs = s3_cfg.backoff
            logger.info(f"Sleeping before retrying again, after {secs.total_seconds()} seconds")
            time.sleep(secs.total_seconds())
            logger.info("Retrying again")
""" Flytekit PyTorch ========================================= .. currentmodule:: flytekit.extras.pytorch .. autosummary:: :template: custom.rst :toctree: generated/ PyTorchCheckpoint """ from flytekit.loggers import logger try: from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer from .native import PyTorchModuleTransformer, PyTorchTensorTransformer except ImportError: logger.info( "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed." )
def dispatch_execute(
    self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
    """
    This function is largely similar to the base PythonTask, with the exception that we have to infer the Python
    interface before executing. Also, we refer to ``self.task_template`` rather than just ``self`` similar to task
    classes that derive from the base ``PythonTask``.

    Flow: ``pre_execute`` -> convert inputs (types guessed from the template interface) -> ``execute`` ->
    ``post_execute`` -> convert outputs to literals.

    :param ctx: The current Flyte context; its working dir is kept, only the user-space params change.
    :param input_literal_map: Task inputs as Flyte literals.
    :return: A ``LiteralMap`` of the outputs, or whatever ``post_execute`` returned if that is already a
        ``LiteralMap`` or a ``DynamicJobSpec``.
    :raises AssertionError: If an output is a tuple or cannot be converted to a literal.
    """
    # Invoked before the task is executed
    new_user_params = self.pre_execute(ctx.user_space_params)

    # Run under a context with the new user params, but keep the same working dir.
    exec_state = ctx.execution_state.with_params(user_space_params=new_user_params)
    with FlyteContextManager.with_context(ctx.with_execution_state(exec_state)) as exec_ctx:
        # There is no Python interface to read input types from, so reverse them from the task
        # template's Flyte interface.
        input_types = TypeEngine.guess_python_types(self.task_template.interface.inputs)
        native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, input_types)

        logger.info(f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}")
        try:
            native_outputs = self.execute(**native_inputs)
        except Exception as e:
            logger.exception(f"Exception when executing {e}")
            raise e
        logger.debug("Task executed successfully in user level")

        # post_execute may raise IgnoreOutputs, which is bubbled up to be handled at the callee layer.
        native_outputs = self.post_execute(new_user_params, native_outputs)

        # A dynamic-job spec (or an already-built LiteralMap from a no-op dynamic task) is not made of
        # Python native values, so skip the translation.
        if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
            return native_outputs

        output_names = list(self.task_template.interface.outputs.keys())
        output_count = len(output_names)
        if output_count == 0:
            outputs_by_name = {}
        elif output_count == 1:
            # A single output is taken as-is. NOTE: unlike the base PythonTask, no special handling of
            # length-one NamedTuple returns is done here.
            outputs_by_name = {output_names[0]: native_outputs}
        else:
            outputs_by_name = {output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

        # Build the LiteralMap by hand: task inputs and outputs violate the IDL assumption that every
        # value in a literal map has the same type.
        literals = {}
        for name, value in outputs_by_name.items():
            literal_type = self.task_template.interface.outputs[name].type
            py_type = type(value)
            if isinstance(value, tuple):
                raise AssertionError(
                    f"Output({name}) in task{self.task_template.id.name} received a tuple {value}, instead of {py_type}"
                )
            try:
                literals[name] = TypeEngine.to_literal(exec_ctx, value, py_type, literal_type)
            except Exception as e:
                raise AssertionError(f"failed to convert return value for var {name}") from e

        # After the execute has been successfully completed
        return _literal_models.LiteralMap(literals=literals)
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
    """
    This function is used only in the local execution path and is responsible for calling dispatch execute.
    Use this function when calling a task with native values (or Promises containing Flyte literals derived from
    Python native values).

    If ``self.metadata.cache`` is set, results are looked up in (and stored into) ``LocalTaskCache``, keyed by
    task name, cache version and input literal map. Otherwise ``dispatch_execute`` runs with the user-space
    params switched to a task sandbox.
    NOTE(review): the cached path calls ``dispatch_execute`` without that sandbox. Confirm this is intended.

    :param ctx: The current Flyte context.
    :param kwargs: Task inputs as Promises, native constants, or dicts/lists of those.
    :return: ``VoidPromise(self.name)`` when the task declares no outputs, otherwise Promises wrapping the
        output literals, via ``create_task_output``.
    :raises AssertionError: If the number of produced literals differs from the number of declared outputs.
    """
    # Unwrap the kwargs values; afterwards we essentially have a LiteralMap. During local execution inputs are
    # either Promises (outputs of previous task executions) or native constants bound to this task (default
    # values for a task input), possibly nested inside dicts or lists.
    input_literals = translate_inputs_to_literals(
        ctx,
        incoming_values=kwargs,
        flyte_interface_types=self.interface.inputs,  # type: ignore
        native_types=self.get_input_types(),
    )
    input_literal_map = _literal_models.LiteralMap(literals=input_literals)

    if not self.metadata.cache:
        es = ctx.execution_state
        sandboxed = es.user_space_params.with_task_sandbox()
        ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=sandboxed.build())).build()
        outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
    else:
        # TODO: how to get a nice `native_inputs` here?
        logger.info(
            f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
            f"and inputs: {input_literal_map}"
        )
        outputs_literal_map = LocalTaskCache.get(self.name, self.metadata.cache_version, input_literal_map)
        # The cache returns None iff the key does not exist in the cache
        if outputs_literal_map is not None:
            logger.info("Cache hit")
        else:
            logger.info("Cache miss, task will be executed now")
            outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
            # TODO: need `native_inputs`
            LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map)
            logger.info(
                f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
                f"and inputs: {input_literal_map}"
            )

    outputs_literals = outputs_literal_map.literals

    # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
    #    location, otherwise we dont really need to right? The higher level execute could just handle literalMap
    # After running, we again have to wrap the outputs, if any, back into Promise objects
    output_names = list(self.interface.outputs.keys())  # type: ignore
    if len(output_names) != len(outputs_literals):
        # Length check, clean up exception
        raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

    # Tasks that don't return anything still return a VoidPromise
    if not output_names:
        return VoidPromise(self.name)

    promises = [Promise(var, outputs_literals[var]) for var in output_names]
    return create_task_output(promises, self.python_interface)
def execute(self, **kwargs) -> Any:
    """
    Validate the task's single input dataset against the configured Great Expectations suite.

    The dataset is prepared according to its declared type (``FlyteFile``, ``FlyteSchema`` or ``str``) and
    the class of the configured data connector. It is then wrapped in a batch request (a
    RuntimeBatchRequest when runtime parameters apply) and validated through a temporary
    ``SimpleCheckpoint``.

    :param kwargs: Exactly one input: the dataset to validate.
    :return: The first validation result, converted to a JSON-serializable dict.
    :raises TypeError: If the task does not declare exactly one input, or its type is unsupported.
    :raises ValueError: If the datasource does not exist, or a ``RuntimeDataConnector`` is used without
        ``data_asset_name``.
    :raises AssertionError: If runtime parameters are used with a dataset that is neither ``str`` nor
        ``FlyteSchema``.
    :raises ValidationError: If validation fails.
    """
    context = ge.data_context.DataContext(self._context_root_dir)  # type: ignore

    if len(self.python_interface.inputs.keys()) != 1:
        raise TypeError("Expected one input argument to validate the dataset")

    dataset_key = list(self.python_interface.inputs.keys())[0]
    dataset = kwargs[dataset_key]
    datatype = self.python_interface.inputs[dataset_key]

    if not issubclass(datatype, (FlyteFile, FlyteSchema, str)):
        raise TypeError("'dataset' has to have FlyteFile/FlyteSchema/str datatype")

    # determine the type of data connector
    selected_datasource = list(filter(lambda x: x["name"] == self._datasource_name, context.list_datasources()))

    if not selected_datasource:
        raise ValueError("Datasource doesn't exist!")

    data_connector_class_lookup = {
        data_connector_name: data_connector_class["class_name"]
        for data_connector_name, data_connector_class in selected_datasource[0]["data_connectors"].items()
    }

    specified_data_connector_class = data_connector_class_lookup[self._data_connector_name]

    is_runtime = False
    if specified_data_connector_class == "RuntimeDataConnector":
        is_runtime = True
        if not self._data_asset_name:
            raise ValueError("data_asset_name has to be given in a RuntimeBatchRequest")

    # FlyteFile
    if issubclass(datatype, FlyteFile):
        dataset = self._flyte_file(dataset)

    # FlyteSchema
    # convert schema to parquet file
    if issubclass(datatype, FlyteSchema) and not is_runtime:
        dataset = self._flyte_schema(dataset)

    # minimalistic batch request
    final_batch_request = {
        "data_asset_name": self._data_asset_name if is_runtime else dataset,
        "datasource_name": self._datasource_name,
        "data_connector_name": self._data_connector_name,
    }

    # Great Expectations' RuntimeBatchRequest
    if self._batch_request_config and (self._batch_request_config.runtime_parameters or is_runtime):
        # Copy the configured runtime parameters: the query / batch data is added below, and writing into
        # the config's own dict would leak this run's dataset into the task configuration.
        runtime_parameters = dict(self._batch_request_config.runtime_parameters or {})
        final_batch_request.update(
            {
                "runtime_parameters": runtime_parameters,
                "batch_identifiers": self._batch_request_config.batch_identifiers,
                "batch_spec_passthrough": self._batch_request_config.batch_spec_passthrough,
            }
        )

        if is_runtime and issubclass(datatype, str):
            final_batch_request["runtime_parameters"]["query"] = dataset
        elif is_runtime and issubclass(datatype, FlyteSchema):
            final_batch_request["runtime_parameters"]["batch_data"] = dataset.open().all()
        else:
            raise AssertionError("Can only use runtime_parameters for query(str)/schema data")

    # Great Expectations' BatchRequest
    elif self._batch_request_config:
        final_batch_request.update(
            {
                "data_connector_query": self._batch_request_config.data_connector_query,
                "batch_spec_passthrough": self._batch_request_config.batch_spec_passthrough,
            }
        )

    if self._checkpoint_params:
        checkpoint = SimpleCheckpoint(
            f"_tmp_checkpoint_{self._expectation_suite_name}",
            context,
            **self._checkpoint_params,
        )
    else:
        checkpoint = SimpleCheckpoint(
            f"_tmp_checkpoint_{self._expectation_suite_name}",
            context,
        )

    # identify every run uniquely
    run_id = RunIdentifier(
        **{
            "run_name": self._datasource_name + "_run",
            "run_time": datetime.datetime.utcnow(),
        }
    )

    checkpoint_result = checkpoint.run(
        run_id=run_id,
        validations=[
            {
                "batch_request": final_batch_request,
                "expectation_suite_name": self._expectation_suite_name,
            }
        ],
    )
    final_result = convert_to_json_serializable(checkpoint_result.list_validation_results())[0]

    result_string = ""
    if final_result["success"] is False:
        for every_result in final_result["results"]:
            if every_result["success"] is False:
                expectation_config = every_result["expectation_config"]
                # Some expectations (e.g. table-level ones) may have no "column" kwarg; don't let a
                # KeyError mask the validation failure being reported.
                result_string += (
                    expectation_config["kwargs"].get("column", "")
                    + " -> "
                    + expectation_config["expectation_type"]
                    + "\n"
                )

        # raise a Great Expectations' exception
        raise ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + result_string)

    logger.info("Validation succeeded!")

    return final_result
def __init__(
    self,
    name: str,
    task_config: T,
    task_type="python-task",
    container_image: Optional[str] = None,
    requests: Optional[Resources] = None,
    limits: Optional[Resources] = None,
    environment: Optional[Dict[str, str]] = None,
    task_resolver: Optional[TaskResolverMixin] = None,
    secret_requests: Optional[List[Secret]] = None,
    **kwargs,
):
    """
    :param name: unique name for the task, usually the function's module and name.
    :param task_config: Configuration object for Task. Should be a unique type for that specific Task
    :param task_type: String task type to be associated with this Task
    :param container_image: String FQN for the image.
    :param requests: custom resource request settings.
    :param limits: custom resource limit settings.
    :param environment: Environment variables you want the task to have when run.
    :param task_resolver: Custom resolver - will pick up the default resolver if empty, or the resolver set
        in the compilation context if one is set.
    :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets
        will be mounted based on the configuration in the Secret and available through
        the SecretManager using the name of the secret as the group.
        Ideally the secret keys should also be semi-descriptive.
        The key values will be available from runtime, if the backend is configured to provide secrets and
        if secrets are available in the configured secrets store. Possible options for secret stores are

        - `Vault <https://www.vaultproject.io/>`_,
        - `Confidant <https://lyft.github.io/confidant/>`_,
        - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`_
        - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`_
          etc
    :raises AssertionError: if any entry of ``secret_requests`` is not a ``Secret``.
    """
    sec_ctx = None
    if secret_requests:
        for s in secret_requests:
            if not isinstance(s, Secret):
                raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
        sec_ctx = SecurityContext(secrets=secret_requests)

    super().__init__(
        task_type=task_type,
        name=name,
        task_config=task_config,
        security_ctx=sec_ctx,
        **kwargs,
    )
    self._container_image = container_image
    # TODO(katrogan): Implement resource overrides
    self._resources = ResourceSpec(
        requests=requests if requests else Resources(),
        limits=limits if limits else Resources(),
    )
    self._environment = environment

    compilation_state = FlyteContext.current_context().compilation_state
    if compilation_state and compilation_state.task_resolver:
        # A resolver set in the compilation context takes precedence over the one passed in.
        if task_resolver:
            logger.info(
                f"Not using the passed in task resolver {task_resolver} because one found in compilation context"
            )
        self._task_resolver = compilation_state.task_resolver
        # Ask the resolver once and reuse the answer instead of calling task_name() twice.
        resolved_name = self._task_resolver.task_name(self)
        if resolved_name is not None:
            self._name = resolved_name
    else:
        self._task_resolver = task_resolver or default_task_resolver
def __enter__(self):
    """
    Enter the timed context: log the context statement and record the start times.

    Records ``_start_wall_time`` (``time.perf_counter``) and ``_start_process_time``
    (``time.process_time``) on the instance.

    :returns: this timer, so it can be bound with ``with ... as timer`` (previously ``None`` was bound).
    """
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Entering timed context: %s", self._context_statement)
    self._start_wall_time = _time.perf_counter()
    self._start_process_time = _time.process_time()
    return self
def get_protocol(path: typing.Optional[str] = None):
    """
    Resolve the storage protocol to use for ``path``.

    A non-empty ``path`` is delegated to ``DataPersistencePlugins.get_protocol``.
    A missing or empty ``path`` falls back to the local ``"file"`` protocol, and
    the fallback is logged.
    """
    if not path:
        logger.info("Setting protocol to file")
        return "file"
    return DataPersistencePlugins.get_protocol(path)
""" from flytekit.configuration.internal import LocalSDK from flytekit.loggers import logger from .basic_dfs import ( ArrowToParquetEncodingHandler, PandasToParquetEncodingHandler, ParquetToArrowDecodingHandler, ParquetToPandasDecodingHandler, ) from .structured_dataset import ( StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, StructuredDatasetTransformerEngine, ) try: from .bigquery import ( ArrowToBQEncodingHandlers, BQToArrowDecodingHandler, BQToPandasDecodingHandler, PandasToBQEncodingHandlers, ) except ImportError: logger.info( "We won't register bigquery handler for structured dataset because " "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" )
def dispatch_execute(
    self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
    """
    This method translates Flyte's Type system based input values and invokes the actual call to the executor.
    This method is also invoked during runtime.

    * An empty ``LiteralMap`` is returned when the task itself declares no outputs.
    * A ``LiteralMap`` is returned when the task declares one or more outputs. Individual outputs may be None.
    * A ``DynamicJobSpec`` (or an already-constructed ``LiteralMap``) produced by ``post_execute`` is returned
      unchanged, e.g. when a dynamic workflow is executed.

    :param ctx: the current Flyte context; its execution mode and working dir are reused for the new
        execution context the task runs in.
    :param input_literal_map: the task inputs as Flyte literals.
    :raises AssertionError: if a multi-output task returns a different number of values than it declares,
        an output value is an unexpected tuple, or an output cannot be converted to a literal.
    """
    # Invoked before the task is executed
    new_user_params = self.pre_execute(ctx.user_space_params)

    # Create another execution context with the new user params, but let's keep the same working dir
    with ctx.new_execution_context(
        mode=ctx.execution_state.mode,
        execution_params=new_user_params,
        working_dir=ctx.execution_state.working_dir,
    ) as exec_ctx:
        # TODO We could support default values here too - but not part of the plan right now
        # Translate the input literals to Python native
        native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs)

        # TODO: Logger should auto inject the current context information to indicate if the task is running within
        #   a workflow or a subworkflow etc
        logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
        try:
            native_outputs = self.execute(**native_inputs)
        except Exception as e:
            logger.exception(f"Exception when executing {e}")
            # Bare raise re-raises with the original traceback unchanged.
            raise

        logger.info(f"Task executed successfully in user level, outputs: {native_outputs}")

        # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
        # bubbled up to be handled at the callee layer.
        native_outputs = self.post_execute(new_user_params, native_outputs)

        # Short circuit the translation to literal map because what's returned may be a dj spec (or an
        # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
        if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
            return native_outputs

        expected_output_names = list(self._outputs_interface.keys())
        if len(expected_output_names) == 1:
            # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
            # length one. That convention is used for naming outputs - and single-length-NamedTuples are
            # particularly troublesome but elegant handling of them is not a high priority
            # Again, we're using the output_tuple_name as a proxy.
            if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple):
                native_outputs_as_map = {expected_output_names[0]: native_outputs[0]}
            else:
                native_outputs_as_map = {expected_output_names[0]: native_outputs}
        elif len(expected_output_names) == 0:
            native_outputs_as_map = {}
        else:
            # Returning too few values used to silently drop outputs, too many raised a bare IndexError.
            if len(native_outputs) != len(expected_output_names):
                raise AssertionError(
                    f"Task {self.name} declares {len(expected_output_names)} outputs {expected_output_names}, "
                    f"but returned {len(native_outputs)} values"
                )
            native_outputs_as_map = dict(zip(expected_output_names, native_outputs))

        # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
        # built into the IDL that all the values of a literal map are of the same type.
        literals = {}
        for k, v in native_outputs_as_map.items():
            literal_type = self._outputs_interface[k].type
            py_type = self.get_type_for_output_var(k, v)

            if isinstance(v, tuple):
                raise AssertionError(f"Output({k}) in task {self.name} received a tuple {v}, instead of {py_type}")
            try:
                literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type)
            except Exception as e:
                raise AssertionError(f"failed to convert return value for var {k}") from e

        outputs_literal_map = _literal_models.LiteralMap(literals=literals)
        # After the execute has been successfully completed
        return outputs_literal_map
def to_python_value(
    self,
    ctx: FlyteContext,
    lv: Literal,
    expected_python_type: Type[GreatExpectationsType],
) -> GreatExpectationsType:
    """
    Validate a literal against the Great Expectations configuration attached to ``expected_python_type``.

    Accepts a non-empty string primitive, a schema, a blob, or a structured dataset literal. A temporary
    ``SimpleCheckpoint`` named after the expectation suite is run against a batch request assembled from
    the configuration.

    :param ctx: the Flyte context, forwarded to the schema/file helpers.
    :param lv: the literal to validate.
    :param expected_python_type: the annotated type carrying the Great Expectations configuration.
    :returns: the string value for primitive literals; otherwise the dataset object returned by the
        ``_flyte_schema`` / ``_flyte_file`` helper.
    :raises AssertionError: if the literal kind is unsupported, or runtime parameters are used with data
        that is not a query string or a schema, or with a connector that is not a RuntimeDataConnector.
    :raises ValueError: if the datasource or data connector is not found in the data context, or a
        RuntimeDataConnector is used without ``data_asset_name``.
    :raises ValidationError: if any expectation fails.
    """
    if not (
        lv
        and lv.scalar
        and (
            (lv.scalar.primitive and lv.scalar.primitive.string_value)
            or lv.scalar.schema
            or lv.scalar.blob
            or lv.scalar.structured_dataset
        )
    ):
        raise AssertionError("Can only validate a literal string/FlyteFile/FlyteSchema value")

    # fetch the configuration
    type_conf = GreatExpectationsTypeTransformer.get_config(expected_python_type)
    conf_dict = type_conf[1].to_dict()  # type: ignore

    ge_conf = GreatExpectationsFlyteConfig(**conf_dict)

    # fetch the data context
    context = ge.data_context.DataContext(ge_conf.context_root_dir)  # type: ignore

    # determine the type of data connector
    selected_datasource = [ds for ds in context.list_datasources() if ds["name"] == ge_conf.datasource_name]
    if not selected_datasource:
        raise ValueError("Datasource doesn't exist!")

    data_connector_class_lookup = {
        data_connector_name: data_connector_class["class_name"]
        for data_connector_name, data_connector_class in selected_datasource[0]["data_connectors"].items()
    }
    if ge_conf.data_connector_name not in data_connector_class_lookup:
        # Previously surfaced as a bare KeyError; make the misconfiguration explicit.
        raise ValueError(
            f"Data connector '{ge_conf.data_connector_name}' doesn't exist in datasource "
            f"'{ge_conf.datasource_name}'!"
        )
    specified_data_connector_class = data_connector_class_lookup[ge_conf.data_connector_name]

    is_runtime = specified_data_connector_class == "RuntimeDataConnector"
    if is_runtime and not ge_conf.data_asset_name:
        raise ValueError("data_asset_name has to be given in a RuntimeBatchRequest")

    # file path for FlyteSchema and FlyteFile
    temp_dataset = ""

    # return value
    return_dataset = ""

    # FlyteSchema
    if lv.scalar.schema or lv.scalar.structured_dataset:
        return_dataset, temp_dataset = self._flyte_schema(
            is_runtime=is_runtime,
            ctx=ctx,
            ge_conf=ge_conf,
            lv=lv,
            expected_python_type=type_conf[0],
        )

    # FlyteFile
    if lv.scalar.blob:
        return_dataset, temp_dataset = self._flyte_file(
            ctx=ctx,
            ge_conf=ge_conf,
            lv=lv,
            expected_python_type=type_conf[0],
        )

    if lv.scalar.primitive:
        dataset = return_dataset = lv.scalar.primitive.string_value
    else:
        dataset = temp_dataset

    batch_request_conf = ge_conf.batch_request_config

    # minimalistic batch request
    final_batch_request = {
        "data_asset_name": ge_conf.data_asset_name if is_runtime else dataset,
        "datasource_name": ge_conf.datasource_name,
        "data_connector_name": ge_conf.data_connector_name,
    }

    # Great Expectations' RuntimeBatchRequest
    if batch_request_conf and (batch_request_conf["runtime_parameters"] or is_runtime):
        # Copy so the per-run query / batch data is never written back into the configuration object.
        runtime_parameters = dict(batch_request_conf["runtime_parameters"] or {})
        final_batch_request.update(
            {
                "runtime_parameters": runtime_parameters,
                "batch_identifiers": batch_request_conf["batch_identifiers"],
                "batch_spec_passthrough": batch_request_conf["batch_spec_passthrough"],
            }
        )

        if is_runtime and lv.scalar.primitive:
            runtime_parameters["query"] = dataset
        elif is_runtime and (lv.scalar.schema or lv.scalar.structured_dataset):
            runtime_parameters["batch_data"] = return_dataset.open().all()
        else:
            raise AssertionError("Can only use runtime_parameters for query(str)/schema data")
    # Great Expectations' BatchRequest
    elif batch_request_conf:
        final_batch_request.update(
            {
                "data_connector_query": batch_request_conf["data_connector_query"],
                "batch_spec_passthrough": batch_request_conf["batch_spec_passthrough"],
            }
        )

    checkpoint = SimpleCheckpoint(
        f"_tmp_checkpoint_{ge_conf.expectation_suite_name}",
        context,
        **(ge_conf.checkpoint_params or {}),
    )

    # identify every run uniquely
    run_id = RunIdentifier(
        run_name=ge_conf.datasource_name + "_run",
        run_time=datetime.datetime.utcnow(),
    )

    checkpoint_result = checkpoint.run(
        run_id=run_id,
        validations=[
            {
                "batch_request": final_batch_request,
                "expectation_suite_name": ge_conf.expectation_suite_name,
            }
        ],
    )
    final_result = convert_to_json_serializable(checkpoint_result.list_validation_results())[0]

    if final_result["success"] is False:
        failed_lines = []
        for every_result in final_result["results"]:
            if every_result["success"] is False:
                expectation = every_result["expectation_config"]
                # Table-level expectations carry no "column" kwarg; a KeyError here would hide the failure.
                column = expectation["kwargs"].get("column", "<table>")
                failed_lines.append(f"{column} -> {expectation['expectation_type']}\n")

        # raise a Great Expectations' exception
        raise ValidationError("Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" + "".join(failed_lines))

    logger.info("Validation succeeded!")

    return typing.cast(GreatExpectationsType, return_dataset)