def test_serialization_images():
    """
    Verify that the ``container_image`` argument to ``@task`` resolves image
    template placeholders (``{{.image.<name>.fqn}}`` / ``{{.image.<name>.version}}``)
    against the image config at serialization time:

    - t1: named image fqn + default image version
    - t2: default image fqn + default image version (fully templated)
    - t3: no container_image at all -> default image
    - t4: fully hard-coded image, no templating
    - t5: hard-coded fqn + templated default version
    """

    @task(container_image="{{.image.xyz.fqn}}:{{.image.default.version}}")
    def t1(a: int) -> int:
        return a

    @task(container_image="{{.image.default.fqn}}:{{.image.default.version}}")
    def t2():
        pass

    @task
    def t3():
        pass

    @task(container_image="docker.io/org/myimage:latest")
    def t4():
        pass

    @task(container_image="docker.io/org/myimage:{{.image.default.version}}")
    def t5(a: int) -> int:
        return a

    # FIX: the original set FLYTE_INTERNAL_IMAGE and never restored it, leaking
    # process-global state into every subsequent test. Save and restore it.
    previous_image = os.environ.get("FLYTE_INTERNAL_IMAGE")
    os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version"
    try:
        # Point the config machinery at the test image config that defines the
        # "xyz" named image used by t1.
        set_flyte_config_file(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")
        )
        rs = context_manager.SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=get_image_config(),
        )
        t1_spec = get_serializable(OrderedDict(), rs, t1)
        assert t1_spec.template.container.image == "docker.io/xyz:version"
        # Smoke-check that the serialized spec round-trips to protobuf.
        t1_spec.to_flyte_idl()
        t2_spec = get_serializable(OrderedDict(), rs, t2)
        assert t2_spec.template.container.image == "docker.io/default:version"
        t3_spec = get_serializable(OrderedDict(), rs, t3)
        assert t3_spec.template.container.image == "docker.io/default:version"
        t4_spec = get_serializable(OrderedDict(), rs, t4)
        assert t4_spec.template.container.image == "docker.io/org/myimage:latest"
        t5_spec = get_serializable(OrderedDict(), rs, t5)
        assert t5_spec.template.container.image == "docker.io/org/myimage:version"
    finally:
        # Restore whatever value (or absence) was there before the test ran.
        if previous_image is None:
            os.environ.pop("FLYTE_INTERNAL_IMAGE", None)
        else:
            os.environ["FLYTE_INTERNAL_IMAGE"] = previous_image
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
):
    """
    Discover every registerable Flyte entity reachable from ``pkgs`` and write each
    one out as a serialized protobuf file (one ``.pb`` file per entity).

    In order to register, we have to comply with Admin's endpoints. Those endpoints
    take the following objects:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    However, if we were to merely call .to_flyte_idl() on all the discovered
    entities, what we would get are:

        flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
        flyteidl.core.workflow_pb2.WorkflowTemplate
        flyteidl.core.tasks_pb2.TaskTemplate

    For Workflows and Tasks therefore, there is special logic in the serialize
    function that translates these objects.

    :param list[Text] pkgs: dot-delimited packages to scan for entities
    :param Text local_source_root: directory from which entity modules are loaded
    :param Text folder: output directory for the .pb files (cwd when falsy)
    :param SerializationMode mode: DEFAULT or FAST serialization (defaults to DEFAULT)
    :param Text image: default container image, forwarded into the settings env
    :param Text config_path: config file path to embed in the serialized env
    :param Text flytekit_virtualenv_root: virtualenv root inside the container,
        used to compute the entrypoint path
    :return:
    """
    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path
        if config_path
        else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        entrypoint_settings=flyte_context.EntrypointSettings(
            # NOTE(review): assumes flytekit_virtualenv_root is not None here;
            # _os.path.join would raise otherwise — confirm callers always pass it.
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )
    with flyte_context.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=serialization_settings
    ) as ctx:
        loaded_entities = []
        # First pass: legacy-style entities found by scanning the packages.
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            # Assign a placeholder identifier; real project/domain/version are
            # substituted at registration time.
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            loaded_entities.append(o)
            ctx.serialization_settings.add_instance_var(InstanceVar(module=m, name=k, o=o))
        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")
        mode = mode if mode else SerializationMode.DEFAULT
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if isinstance(entity, PythonTask) or isinstance(entity, Workflow) or isinstance(entity, LaunchPlan):
                if isinstance(entity, PythonTask):
                    # Only tasks distinguish FAST mode (code is shipped separately).
                    if mode == SerializationMode.DEFAULT:
                        serializable = get_serializable(ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        serializable = get_serializable(ctx.serialization_settings, entity, fast=True)
                    else:
                        raise AssertionError(f"Unrecognized serialization mode: {mode}")
                else:
                    serializable = get_serializable(ctx.serialization_settings, entity)
                loaded_entities.append(serializable)
                # Every workflow also gets its auto-generated default launch plan.
                if isinstance(entity, Workflow):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    launch_plan = get_serializable(ctx.serialization_settings, lp)
                    loaded_entities.append(launch_plan)
        # Zero-pad the file index so the files sort lexicographically in order.
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized = entity.serialize()
            fname_index = str(i).zfill(zero_padded_length)
            fname = "{}_{}_{}.pb".format(fname_index, entity.id.name, entity.id.resource_type)
            click.echo(f" Writing type: {entity.id.resource_type_name()}, {entity.id.name} to\n {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(serialized, fname)
        click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the
    trailing index in the file name with the entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast; ``None`` is treated as regular (DEFAULT)
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Path to the Python interpreter inside the container,
        forwarded into the serialization settings.
    """
    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path
        if config_path
        else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    # BUG FIX: mode defaults to None, and the validation below used to reject
    # None outright — so calling serialize_all() without an explicit mode always
    # raised. Default to DEFAULT first, matching the other serialize entry points.
    mode = mode if mode else SerializationMode.DEFAULT
    if not (mode == SerializationMode.DEFAULT or mode == SerializationMode.FAST):
        raise AssertionError(f"Unrecognized serialization mode: {mode}")
    fast_serialization_settings = flyte_context.FastSerializationSettings(
        enabled=mode == SerializationMode.FAST,
        # TODO: if we want to move the destination dir as a serialization argument, we should initialize it here
    )
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
        fast_serialization_settings=fast_serialization_settings,
    )
    ctx = flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            # Placeholder identifier; actual project/domain/version are injected
            # at registration time.
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            old_style_entities.append(o)

        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        # New (annotated) API entities are gathered from the context in one call.
        new_api_model_values = get_registrable_entities(ctx)
        loaded_entities = serialized_old_style_entities + new_api_model_values
        if folder is None:
            folder = "."
        persist_registrable_entities(loaded_entities, folder)

    click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")
def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
    """
    Entrypoint for all PythonTask extensions.

    Sets up logging, a scratch directory, execution parameters (ids, stats,
    dates), and a cloud-provider-appropriate file access provider, then
    dispatches the actual task execution.

    :param task_def: the annotated task to execute
    :param inputs: input location/reference forwarded to _dispatch_execute —
        presumably a path/URI to a serialized LiteralMap; confirm against caller
    :param output_prefix: where outputs are written (forwarded to _dispatch_execute)
    :param raw_output_data_prefix: remote prefix handed to the S3/GCS proxy for
        raw output data
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    # Internal config takes precedence over SDK config for log level.
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    # Imported here (not at module top) — NOTE(review): presumably to avoid a
    # circular import; confirm before hoisting.
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get() or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    # Pick the remote data proxy matching the configured cloud provider.
    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var: _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var: _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
            ) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
def serialize_all(
    pkgs: List[str] = None,
    local_source_root: str = None,
    folder: str = None,
    mode: SerializationMode = None,
    image: str = None,
    config_path: str = None,
    flytekit_virtualenv_root: str = None,
    python_interpreter: str = None,
):
    """
    This function will write to the folder specified the following protobuf types ::

        flyteidl.admin.launch_plan_pb2.LaunchPlan
        flyteidl.admin.workflow_pb2.WorkflowSpec
        flyteidl.admin.task_pb2.TaskSpec

    These can be inspected by calling (in the launch plan case) ::

        flyte-cli parse-proto -f filename.pb -p flyteidl.admin.launch_plan_pb2.LaunchPlan

    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the
    trailing index in the file name with the entity type.

    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    :param folder: Where to write the output protobuf files
    :param mode: Regular vs fast
    :param image: The fully qualified and versioned default image to use
    :param config_path: Path to the config file, if any, to be used during serialization
    :param flytekit_virtualenv_root: The full path of the virtual env in the container.
    :param python_interpreter: Path to the Python interpreter inside the container,
        forwarded into the serialization settings.
    """
    # m = module (i.e. python file)
    # k = value of dir(m), type str
    # o = object (e.g. SdkWorkflow)
    env = {
        _internal_config.CONFIGURATION_PATH.env_var: config_path
        if config_path
        else _internal_config.CONFIGURATION_PATH.get(),
        _internal_config.IMAGE.env_var: image,
    }
    serialization_settings = flyte_context.SerializationSettings(
        project=_PROJECT_PLACEHOLDER,
        domain=_DOMAIN_PLACEHOLDER,
        version=_VERSION_PLACEHOLDER,
        image_config=flyte_context.get_image_config(img_name=image),
        env=env,
        flytekit_virtualenv_root=flytekit_virtualenv_root,
        python_interpreter=python_interpreter,
        entrypoint_settings=flyte_context.EntrypointSettings(
            # NOTE(review): assumes flytekit_virtualenv_root is not None;
            # _os.path.join would raise otherwise — confirm callers always pass it.
            path=_os.path.join(flytekit_virtualenv_root, _DEFAULT_FLYTEKIT_RELATIVE_ENTRYPOINT_LOC)
        ),
    )
    ctx = flyte_context.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    with flyte_context.FlyteContextManager.with_context(ctx) as ctx:
        old_style_entities = []
        # This first for loop is for legacy API entities - SdkTask, SdkWorkflow, etc. The _get_entity_to_module
        # function that this iterate calls only works on legacy objects
        for m, k, o in iterate_registerable_entities_in_order(pkgs, local_source_root=local_source_root):
            name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
            _logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
            # Placeholder identifier; real project/domain/version are substituted
            # at registration time.
            o._id = _identifier.Identifier(
                o.resource_type, _PROJECT_PLACEHOLDER, _DOMAIN_PLACEHOLDER, name, _VERSION_PLACEHOLDER
            )
            old_style_entities.append(o)

        serialized_old_style_entities = []
        for entity in old_style_entities:
            if entity.has_registered:
                _logging.info(f"Skipping entity {entity.id} because already registered")
                continue
            serialized_old_style_entities.append(entity.serialize())

        click.echo(f"Found {len(flyte_context.FlyteEntities.entities)} tasks/workflows")

        mode = mode if mode else SerializationMode.DEFAULT

        # OrderedDict keyed by entity so each entity is serialized once, and the
        # insertion order determines the output file index below.
        new_api_serializable_entities = OrderedDict()
        # TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
        #  object, which gets added to the FlyteEntities.entities list, which we're iterating over.
        for entity in flyte_context.FlyteEntities.entities.copy():
            # TODO: Add a reachable check. Since these entities are always added by the constructor, weird things can
            #  happen. If someone creates a workflow inside a workflow, we don't actually want the inner workflow to be
            #  registered. Or do we? Certainly, we don't want inner tasks to be registered because we don't know how
            #  to reach them, but perhaps workflows should be okay to take into account generated workflows.
            #  Also a user may import dir_b.workflows from dir_a.workflows but workflow packages might only
            #  specify dir_a
            if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan):
                if isinstance(entity, PythonTask):
                    # Only tasks distinguish FAST mode (code shipped separately).
                    if mode == SerializationMode.DEFAULT:
                        get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)
                    elif mode == SerializationMode.FAST:
                        get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, fast=True)
                    else:
                        raise AssertionError(f"Unrecognized serialization mode: {mode}")
                else:
                    get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity)
                # Every workflow also gets its auto-generated default launch plan.
                if isinstance(entity, WorkflowBase):
                    lp = LaunchPlan.get_default_launch_plan(ctx, entity)
                    get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp)

        new_api_model_values = list(new_api_serializable_entities.values())
        # Drop entities (e.g. references) that should not be registered with Admin.
        new_api_model_values = list(filter(_should_register_with_admin, new_api_model_values))
        new_api_model_values = [v.to_flyte_idl() for v in new_api_model_values]

        loaded_entities = serialized_old_style_entities + new_api_model_values
        # Zero-pad the file index so the files sort lexicographically in order.
        zero_padded_length = _determine_text_chars(len(loaded_entities))
        for i, entity in enumerate(loaded_entities):
            fname_index = str(i).zfill(zero_padded_length)
            # The trailing _1/_2/_3 encodes the entity's resource type
            # (task/workflow/launch plan); see ResourceType in the docstring.
            if isinstance(entity, _idl_admin_TaskSpec):
                fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_WorkflowSpec):
                fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name)
            elif isinstance(entity, _idl_admin_LaunchPlan):
                fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
            else:
                raise Exception(f"Bad format {type(entity)}")
            click.echo(f" Writing to file: {fname}")
            if folder:
                fname = _os.path.join(folder, fname)
            _write_proto_to_file(entity, fname)

    click.secho(f"Successfully serialized {len(loaded_entities)} flyte objects", fg="green")