def test_basics():
    """Serialize a two-task workflow, one of its tasks, and a launch plan built on it."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = t2(a=y, b=b)
        return x, d

    # The compiled workflow template exposes both inputs, both outputs and one node per task call.
    workflow_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
    template = workflow_spec.template
    assert len(template.interface.inputs) == 2
    assert len(template.interface.outputs) == 2
    assert len(template.nodes) == 2
    assert template.id.resource_type == identifier_models.ResourceType.WORKFLOW

    # t1 was already serialized (and cached) while serializing my_wf above, so asking again with
    # fast serialization enabled does not actually produce a fast spec; "pyflyte-execute" is expected.
    fast_settings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    t1_spec = get_serializable(OrderedDict(), fast_settings, t1)
    assert "pyflyte-execute" in t1_spec.template.container.args

    launch_plan = LaunchPlan.create("testlp", my_wf)
    lp_spec = get_serializable(OrderedDict(), serialization_settings, launch_plan)
    assert lp_spec.id.name == "testlp"
def test_serialization_settings_transport():
    """Round-trip SerializationSettings through its compact transport string."""
    default_img = Image(name="default", fqn="test", tag="tag")
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir="/opt/blah/blah/blah",
        distribution_location="s3://my-special-bucket/blah/bha/asdasdasd/cbvsdsdf/asdddasdasdasdasdasdasd.tar.gz",
    )
    original = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"hello": "blah"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        flytekit_virtualenv_root="/opt/venv/blah",
        python_interpreter="/opt/venv/bin/python3",
        fast_serialization_settings=fast_settings,
    )

    transport = original.serialized_context
    with_ctx = original.with_serialized_context()

    # Attaching the serialized context must leave the source settings' env untouched and put the
    # transport string into the copy's env under SERIALIZED_CONTEXT_ENV_VAR.
    assert original.env == {"hello": "blah"}
    assert with_ctx.env
    assert with_ctx.env[SERIALIZED_CONTEXT_ENV_VAR] == transport

    # Decoding the transport string must reproduce equal settings.
    restored = SerializationSettings.from_transport(transport)
    assert restored is not None
    assert restored == original
    assert len(transport) == 376
def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
    """
    This command produces a Flyte backend registrable package of all entities in Flyte.
    For tasks, one pb file is produced for each task, representing one TaskTemplate object.
    For workflows, one pb file is produced for each workflow, representing a WorkflowClosure object. The closure
    object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
    This serialization step will set the name of the tasks to the fully qualified name of the task function.
    """
    # Guard: never overwrite an existing package unless the user explicitly forces it.
    if not force and os.path.exists(output):
        message = f"Output file {output} already exists, specify -f to override."
        raise click.BadParameter(click.style(message, fg="red"))

    # `fast` toggles fast serialization; the in-container source path becomes its destination dir.
    fast_settings = FastSerializationSettings(
        enabled=fast,
        destination_dir=in_container_source_path,
    )
    settings = SerializationSettings(
        image_config=image_config,
        fast_serialization_settings=fast_settings,
        python_interpreter=python_interpreter,
    )

    packages = ctx.obj[constants.CTX_PACKAGES]
    if not packages:
        display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

    try:
        serialize_and_package(packages, settings, source, output, fast)
    except NoSerializableEntitiesError:
        # Finding nothing to serialize is reported as a warning rather than an error.
        click.secho(f"No flyte objects found in packages {packages}", fg="yellow")
def test_launch_plan_with_fixed_input():
    """A workflow calling a launch plan with fixed and default inputs serializes to a single node."""

    @task
    def greet(day_of_week: str, number: int, am: bool) -> str:
        greeting = "Have a great " + day_of_week + " "
        greeting += "morning" if am else "evening"
        return greeting + "!" * number

    @workflow
    def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
        return greet(day_of_week=day_of_week, number=number, am=am)

    morning_greeting = LaunchPlan.create(
        "morning_greeting",
        go_greet,
        fixed_inputs={"am": True},
        default_inputs={"number": 1},
    )

    @workflow
    def morning_greeter_caller(day_of_week: str) -> str:
        greeting = morning_greeting(day_of_week=day_of_week)
        return greeting

    fast_settings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    wf_spec = get_serializable(OrderedDict(), fast_settings, morning_greeter_caller)
    template = wf_spec.template
    assert len(template.interface.inputs) == 1
    assert len(template.interface.outputs) == 1
    assert len(template.nodes) == 1
    # The launch-plan node is expected to carry two inputs — presumably day_of_week plus the
    # defaulted "number" — while the fixed "am" input is not exposed on the node.
    assert len(template.nodes[0].inputs) == 2
def test_fast_pod_task_serialization():
    """With fast serialization on, a pod task's primary container gets the fast-execute wrapper args."""
    pod = Pod(
        pod_spec=V1PodSpec(restart_policy="OnFailure", containers=[V1Container(name="primary")]),
        primary_container_name="primary",
    )

    @task(task_config=pod, environment={"FOO": "bar"})
    def simple_pod_task(i: int):
        pass

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    serialized = get_serializable(OrderedDict(), settings, simple_pod_task)

    # Fast-execute wrapper, followed after "--" by the regular pyflyte-execute invocation.
    fast_prefix = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
    ]
    execute_args = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_pod",
        "task-name",
        "simple_pod_task",
    ]
    primary_args = serialized.template.k8s_pod.pod_spec["containers"][0]["args"]
    assert primary_args == fast_prefix + execute_args
def test_fast():
    """A freshly defined task serialized with fast settings gets the pyflyte-fast-execute entrypoint."""
    # NOTE(review): another `test_fast` is defined later in this listing; if both end up in the same
    # module, the later definition shadows this one and this test never runs.

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    # t2 is declared but never serialized by this test.
    @task
    def t2(a: str, b: str) -> str:
        return b + a

    fast_settings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    t1_spec = get_serializable(OrderedDict(), fast_settings, t1)
    assert "pyflyte-fast-execute" in t1_spec.template.container.args
def test_wf1_with_fast_dynamic():
    """Tasks compiled inside a dynamic task inherit the fast-execute args from the context settings."""

    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    fast_ss = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )
    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(fast_ss)) as ctx:
        # Run the dynamic task as if inside a task container.
        task_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ctx.with_execution_state(task_state)) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            # One node per loop iteration, all backed by the single task t1.
            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )

    # Leaving both context blocks must pop the stack back to its base size.
    assert manager.size() == 1
def serialize_all(
    pkgs: typing.Optional[typing.List[str]] = None,
    local_source_root: typing.Optional[str] = None,
    folder: typing.Optional[str] = None,
    mode: typing.Optional[SerializationMode] = None,
    image: typing.Optional[str] = None,
    flytekit_virtualenv_root: typing.Optional[str] = None,
    python_interpreter: typing.Optional[str] = None,
    config_file: typing.Optional[str] = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast. Must be ``SerializationMode.DEFAULT`` or ``SerializationMode.FAST``;
        although it defaults to ``None``, ``None`` is rejected, so callers must pass it explicitly.
    :param image: The fully qualified and versioned default image to use; passed to
        ``ImageConfig.auto`` as ``img_name``.
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Python interpreter path recorded in the SerializationSettings
        (presumably the interpreter inside the container).
    :param config_file: Configuration file handed to ``ImageConfig.auto`` when resolving images.
    :raises AssertionError: if ``mode`` is not one of the two supported serialization modes.
    """
    if mode not in (SerializationMode.DEFAULT, SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")

    serialization_settings = SerializationSettings(
        image_config=ImageConfig.auto(config_file, img_name=image),
        fast_serialization_settings=FastSerializationSettings(
            enabled=mode == SerializationMode.FAST,
            # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
        ),
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
    )

    serialize_to_folder(pkgs, serialization_settings, local_source_root, folder)
def test_container():
    """Fast serialization must not inject a pyflyte entrypoint into a raw ContainerTask."""
    t2 = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
        requests=Resources(mem="400Mi", cpu="1"),
    )

    ssettings = (
        serialization_settings.new_builder()
        .with_fast_serialization_settings(FastSerializationSettings(enabled=True))
        .build()
    )
    task_spec = get_serializable(OrderedDict(), ssettings, t2)
    container = task_spec.template.container

    # `"pyflyte" not in container.args` is list membership: it only rejects an element exactly
    # equal to "pyflyte" and would still pass if "pyflyte-fast-execute" had been injected.
    # Check every command/arg token for the substring instead.
    tokens = list(container.command or []) + list(container.args or [])
    assert not any("pyflyte" in token for token in tokens), tokens
def test_dynamic():
    """Tasks fetched from Admin keep their original command even under fast serialization."""

    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        return [ft(a=i) for i in range(a)]

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as ctx:
        task_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ctx.with_execution_state(task_state)) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 2})
            # Test that it works
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 2
            assert len(dynamic_job_spec.tasks) == 1
            assert dynamic_job_spec.tasks[0].id == ft.id

            # `ft` (defined at module level) stands in for a task fetched from Admin. Its command must
            # not be rewritten with the fast-execute wrapper, because the commands of tasks fetched
            # from Admin should never change.
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert not args.startswith("pyflyte-fast-execute")
def register(
    ctx: click.Context,
    project: str,
    domain: str,
    image_config: ImageConfig,
    output: str,
    destination_dir: str,
    service_account: str,
    raw_data_prefix: str,
    version: typing.Optional[str],
    package_or_module: typing.Tuple[str],
):
    """
    see help
    """
    # Packages given via the parent command's option are not supported here; only positional args are.
    pkgs = ctx.obj[constants.CTX_PACKAGES]
    if pkgs:
        raise ValueError(
            "Unimplemented, just specify pkgs like folder/files as args at the end of the command"
        )
    cli_logger.debug("No pkgs")

    if len(package_or_module) == 0:
        display_help_with_error(
            ctx,
            "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
        )

    cli_logger.debug(
        f"Running pyflyte register from {os.getcwd()} "
        f"with images {image_config} "
        f"and image destinationfolder {destination_dir} "
        f"on {len(package_or_module)} package(s) {package_or_module}"
    )

    # Build (and persist) the FlyteRemote for this project/domain.
    remote = get_and_save_remote_with_click_context(ctx, project, domain)

    # TODO: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings
    # Zip everything under the common root of the given packages/modules, then upload it via the remote.
    detected_root = find_common_root(package_or_module)
    cli_logger.debug(f"Using {detected_root} as root folder for project")
    zip_file = fast_package(detected_root, output)
    md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
    cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")

    # The uploaded archive becomes the fast-registration distribution.
    # TODO: relies on the default Python interpreter for now, which will break custom Spark containers.
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir=destination_dir,
        distribution_location=native_url,
    )
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=image_config,
        fast_serialization_settings=fast_settings,
    )
    options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix)

    registerable_entities = load_packages_and_modules(
        serialization_settings, detected_root, list(package_or_module), options
    )
    if len(registerable_entities) == 0:
        display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
    cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")

    # Without an explicit version, derive one from the uploaded archive's hash and the settings.
    if not version:
        version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix)  # noqa
        cli_logger.info(f"Computed version is {version}")

    repo_register(registerable_entities, project, domain, version, remote.client)
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    Build the FlyteContext used to execute a task inside its container, and yield it.

    Identity of the current execution and task is read through ``get_one_of``. Each call is given a long
    ``FLYTE_INTERNAL_*`` environment variable name and a short ``_F_*`` alias; presumably whichever is set
    is used (see ``get_one_of``).

    The yielded context carries:

    * ``ExecutionParameters``: execution id, execution date from ``utcnow()``, stats client, user workspace
      dir, raw output prefix, optional checkpointer and task id;
    * a ``FileAccessProvider`` whose local sandbox is a fresh ``tempfile.mkdtemp(prefix="flyte")`` dir;
    * an execution state in ``TASK_EXECUTION`` mode;
    * when ``SERIALIZED_CONTEXT_ENV_VAR`` is set, ``SerializationSettings`` decoded from it.

    This is a generator: the context is only active while the caller is suspended at the ``yield``.
    NOTE(review): presumably wrapped with ``contextlib.contextmanager`` where it is declared; the decorator
    is not visible here.

    :param raw_output_data_prefix: Prefix for raw output data. Passed to both ExecutionParameters and
        FileAccessProvider.
    :param checkpoint_path: Checkpoint destination. When None, no checkpointer is created.
    :param prev_checkpoint: Source of a previous checkpoint. Only used when checkpoint_path is given.
    :param dynamic_addl_distro: Works in concert with dynamic_dest_dir. If present, indicates that if a
        dynamic task were to run, it should set fast serialize to true and use these values in
        FastSerializationSettings (as the distribution location). Only takes effect when a serialized
        context is present in the environment.
    :param dynamic_dest_dir: Destination directory used together with dynamic_addl_distro.
    :raises TypeError: re-raised, after logging, when constructing the FileAccessProvider fails with a
        TypeError. The inline comment attributes this to no data plugin matching raw_output_data_prefix.
    """
    # Identity of the running execution, plus the workflow and launch plan it belongs to.
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    # Identity of the task being executed.
    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    # Transport-encoded SerializationSettings handed over via the environment; "" when not provided.
    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    # NOTE(review): imported locally, presumably to avoid an import cycle with the flytekit package.
    from flytekit import __version__ as _api_version

    # Checkpointing is optional: a SyncCheckpoint is only created when a destination was supplied.
    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        # NOTE(review): utcnow() returns a naive datetime (no tzinfo) and is deprecated since Python 3.12.
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        # Re-target the decoded settings at this execution's project/domain and the task's version.
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            # Enable fast serialization so anything compiled at runtime uses this distribution.
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())
    # NOTE(review): when no serialized context is present, dynamic_addl_distro / dynamic_dest_dir are ignored.

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx
def test_fast():
    """A pod task compiled inside a dynamic task keeps its resources and gets fast-execute args."""
    # NOTE(review): an earlier `test_fast` also exists in this listing; if both live in the same module,
    # this definition shadows the earlier one.
    gpu_requests = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    gpu_limits = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    # Minimal pod config whose only container is also the primary one.
    pod_config = Pod(
        pod_spec=V1PodSpec(containers=[V1Container(name="flytetask")]),
        primary_container_name="flytetask",
    )

    @task(task_config=pod_config, requests=gpu_requests, limits=gpu_limits)
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=gpu_requests, limits=gpu_limits)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    manager = context_manager.FlyteContextManager
    with manager.with_context(manager.current_context().with_serialization_settings(settings)) as ctx:
        task_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ctx.with_execution_state(task_state)) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1

            primary = dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
            args = " ".join(primary["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            # The pod task's own requests/limits must survive compilation inside the dynamic task.
            assert primary["resources"]["limits"]["cpu"] == "124M"
            assert primary["resources"]["requests"]["gpu"] == "1"

    # Leaving both context blocks must pop the stack back to its base size.
    assert manager.size() == 1