def test_two(two_sample_inputs):
    """A dynamic task over a list of MyInput yields one FlyteFile per input."""
    first, second = two_sample_inputs[0], two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        # Collect the main product of each input.
        return [item.main_product for item in a]

    settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(settings)
    ) as ctx:
        exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with FlyteContextManager.with_context(ctx.with_execution_state(exec_state)) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [first, second]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            # Two inputs in -> two literals out.
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    """A system-level exception from the task should write a RECOVERABLE error document."""
    # Neutralize the data-plane helpers so nothing actually happens.
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    ctx = context_manager.FlyteContext.current_context()
    exec_state = ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
    with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(exec_state)) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        # A task whose dispatch raises a plain Exception -> treated as a system error.
        python_task = mock.MagicMock()
        python_task.dispatch_execute.side_effect = Exception("some system exception")

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, python_task, "inputs path", "outputs prefix")

        # Exception should've caused exactly one file: the error document.
        assert len(files) == 1
        path, proto = next(iter(files.items()))
        assert "error.pb" in path
        ed = error_models.ErrorDocument.from_flyte_idl(proto)
        # System errors default to recoverable
        assert ed.error.kind == error_models.ContainerError.Kind.RECOVERABLE
        assert "some system exception" in ed.error.message
        assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM
def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    """A ValueError raised inside user task code should yield a NON_RECOVERABLE user error."""
    # Neutralize the data-plane helpers so nothing actually happens.
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    @task
    def t1(a: int) -> str:
        # Should be interpreted as a non-recoverable user error
        raise ValueError(f"some exception {a}")

    ctx = context_manager.FlyteContext.current_context()
    exec_state = ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
    with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(exec_state)) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix")

        # Exception should've caused exactly one file: the error document.
        assert len(files) == 1
        path, proto = next(iter(files.items()))
        assert "error.pb" in path
        ed = error_models.ErrorDocument.from_flyte_idl(proto)
        assert ed.error.kind == error_models.ContainerError.Kind.NON_RECOVERABLE
        assert "some exception 5" in ed.error.message
        assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER
def test_dispatch_execute_normal(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    """A task that completes normally should write a single outputs.pb literal map."""
    # Neutralize the data-plane helpers so nothing actually happens.
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    @task
    def t1(a: int) -> str:
        return f"string is: {a}"

    ctx = context_manager.FlyteContext.current_context()
    exec_state = ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
    with context_manager.FlyteContextManager.with_context(ctx.with_execution_state(exec_state)) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix")

        # A successful run should've written exactly one file: the outputs.
        assert len(files) == 1
        path, proto = next(iter(files.items()))
        assert "outputs.pb" in path
        lm = _literal_models.LiteralMap.from_flyte_idl(proto)
        assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5"
def test_pb_guess_python_type():
    """Round-trip a protobuf message: literal type -> guessed python type -> literal map -> python value."""
    artifact_tag = catalog_pb2.CatalogArtifactTag(artifact_id="artifact_1", name="artifact_name")

    lt = TypeEngine.to_literal_type(catalog_pb2.CatalogArtifactTag)
    gt = TypeEngine.guess_python_type(lt)
    # The guessed type must be the original message class.
    assert gt == catalog_pb2.CatalogArtifactTag

    ctx = FlyteContextManager.current_context()
    lm = TypeEngine.dict_to_literal_map(ctx, {"a": artifact_tag}, {"a": gt})
    pv = TypeEngine.to_python_value(ctx, lm.literals["a"], gt)
    # Converting back must reproduce an equal message.
    assert pv == artifact_tag
def test_dc_dyn_directory(folders_and_files_setup):
    """Dynamic task extracting the external data dir (a FlyteDirectory URI) from each MyInput."""
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    def make_input(static_uri, external_uri):
        # Both inputs share the same local file and proxy settings; only the gs:// dirs differ.
        return MyInput(
            main_product=FlyteFile(folders_and_files_setup[0]),
            apriori_config=MyAprioriConfiguration(
                static_data_dir=FlyteDirectory(static_uri),
                external_data_dir=FlyteDirectory(external_uri),
            ),
            proxy_config=proxy_c,
            proxy_params=proxy_p,
        )

    my_input_gcs = make_input("gs://my-bucket/one", "gs://my-bucket/two")
    my_input_gcs_2 = make_input("gs://my-bucket/three", "gs://my-bucket/four")

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        return [item.apriori_config.external_data_dir for item in a]

    ctx = FlyteContextManager.current_context()
    settings = SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    cb = (
        ctx.new_builder()
        .with_serialization_settings(settings)
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        outputs = dynamic_job_spec.literals["o0"].collection.literals
        # The external dirs come back, in input order.
        assert outputs[0].scalar.blob.uri == "gs://my-bucket/two"
        assert outputs[1].scalar.blob.uri == "gs://my-bucket/four"
def test_wf1_with_fast_dynamic():
    """Fast-register settings should rewrite dynamic subtask commands to pyflyte-fast-execute."""

    @task
    def t1(a: int) -> str:
        return f"fast-{a + 2}"

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        return my_subwf(a=a)

    fast = FastSerializationSettings(
        enabled=True,
        destination_dir="/User/flyte/workflows",
        distribution_location="s3://my-s3-bucket/fast/123",
    )
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=fast,
    )
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            # a=5 -> five subtask nodes, all backed by the single t1 task.
            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
    assert context_manager.FlyteContextManager.size() == 1
def test_dynamic():
    """Fast-register must NOT rewrite commands of tasks fetched from Admin (`ft`)."""

    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        return [ft(a=i) for i in range(a)]

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
        fast_serialization_settings=FastSerializationSettings(enabled=True),
    )
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 2})
            # Test that it works
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 2
            assert len(dynamic_job_spec.tasks) == 1
            assert dynamic_job_spec.tasks[0].id == ft.id
            # Test that the fast execute stuff does not get applied because the commands of tasks fetched from
            # Admin should never change.
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert not args.startswith("pyflyte-fast-execute")
def test_fast():
    """Fast-register rewrites a pod-task's container args and keeps its resource overrides intact."""
    REQUESTS_GPU = Resources(cpu="123m", mem="234Mi", ephemeral_storage="123M", gpu="1")
    LIMITS_GPU = Resources(cpu="124M", mem="235Mi", ephemeral_storage="124M", gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        # Smallest valid pod spec: one primary container.
        primary_container = V1Container(name="flytetask")
        return Pod(pod_spec=V1PodSpec(containers=[primary_container]), primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            serialization_settings
        )
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1
            container = dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]
            args = " ".join(container["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            # Resource requests/limits supplied at task definition survive serialization.
            assert container["resources"]["limits"]["cpu"] == "124M"
            assert container["resources"]["requests"]["gpu"] == "1"
    assert context_manager.FlyteContextManager.size() == 1