def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflows_register):
    """Execute the registered ``subworkflows.parent_wf`` launch plan and verify node I/O.

    Checks the top-level nodes ``n0`` and ``n1`` as well as the nodes of the subworkflow
    executed by ``n1`` (exposed through ``subworkflow_node_executions``).
    """
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_launch_plan = remote.fetch_launch_plan(
        name="workflows.basic.subworkflows.parent_wf", version=f"v{VERSION}"
    )
    execution = remote.execute(flyte_launch_plan, {"a": 101}, wait=True)

    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

    # check subworkflow task execution inputs and outputs; these comparisons must be
    # asserted, otherwise they are evaluated and silently discarded.
    subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}
def test_execute_sqlite3_task(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Register an ad-hoc SQLite3 query task, run it remotely and inspect the result frame."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    chinook_db_url = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"

    sql_task = SQLite3Task(
        "basic_querying",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(uri=chinook_db_url, compressed=True),
    )
    registered = remote.register(sql_task)
    execution = remote.execute(registered, inputs={"limit": 10}, wait=True)

    frame = execution.outputs["results"].open().all()
    assert frame.__class__.__name__ == "DataFrame"
    # Both selected columns must be present in the returned frame.
    for column in ("TrackId", "Name"):
        assert column in frame
def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
    """``_execute`` forwards the raw-output prefix and IAM role supplied through ``options``."""
    mock_wf_exec.return_value = True
    client = MagicMock()
    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
    remote._client = client

    options = Options(
        raw_output_data_config=common_models.RawOutputDataConfig(output_location_prefix="raw_output"),
        security_context=security.SecurityContext(
            run_as=security.Identity(iam_role="iam:some:role")
        ),
    )

    def check_execution_spec(*args, **kwargs):
        # The execution spec is the fourth positional argument of create_execution.
        spec = args[3]
        assert spec.security_context.run_as.iam_role == "iam:some:role"
        assert spec.raw_output_data_config.output_location_prefix == "raw_output"

    client.create_execution.side_effect = check_execution_spec
    remote._execute(MagicMock(), inputs={}, project="proj", domain="dev", options=options)
def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Execute an already-registered @workflow function, then its launch plan, and compare with a fetch."""
    from mock_flyte_repo.workflows.basic.basic_workflow import my_wf

    # Strip the package prefix so the name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    wf_execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=version, wait=True)
    assert wf_execution.outputs["o0"] == 12
    assert wf_execution.outputs["o1"] == "xyzworld"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    lp_execution = remote.execute(launch_plan, inputs={"a": 14, "b": "foobar"}, version=version, wait=True)
    assert lp_execution.outputs["o0"] == 16
    assert lp_execution.outputs["o1"] == "foobarworld"

    # A freshly fetched copy of the same execution must report identical I/O.
    fetched = remote.fetch_execution(name=lp_execution.id.name)
    assert lp_execution.inputs == fetched.inputs
    assert lp_execution.outputs == fetched.outputs
def test_execute_python_workflow_dict_of_string_to_string(
    flyteclient, flyte_workflows_register, flyte_remote_env
):
    """Execute the registered dict[str, str] workflow directly and through its launch plan."""
    from mock_flyte_repo.workflows.basic.dict_str_wf import my_wf

    # Strip the package prefix so the name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    first_input: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"}
    execution = remote.execute(my_wf, inputs={"d": first_input}, version=version, wait=True)
    # The workflow output is a JSON string; decode it before comparing.
    assert json.loads(execution.outputs["o0"]) == {"k1": "v1", "k2": "v2"}

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    second_input = {"k2": "vvvv", "abc": "def"}
    execution = remote.execute(launch_plan, inputs={"d": second_input}, version=version, wait=True)
    assert json.loads(execution.outputs["o0"]) == {"k2": "vvvv", "abc": "def"}
def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
    """``_execute`` forwards labels, annotations and the k8s service account from ``options``."""
    mock_wf_exec.return_value = True
    client = MagicMock()
    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
    remote._client = client

    def check_execution_spec(*args, **kwargs):
        # The execution spec is the fourth positional argument of create_execution.
        spec = args[3]
        assert spec.security_context.run_as.k8s_service_account == "svc"
        assert spec.labels == common_models.Labels({"a": "my_label_value"})
        assert spec.annotations == common_models.Annotations({"b": "my_annotation_value"})

    client.create_execution.side_effect = check_execution_spec
    entity = MagicMock()
    options = Options(
        labels=common_models.Labels({"a": "my_label_value"}),
        annotations=common_models.Annotations({"b": "my_annotation_value"}),
        security_context=security.SecurityContext(
            run_as=security.Identity(k8s_service_account="svc")
        ),
    )
    remote._execute(entity, inputs={}, project="proj", domain="dev", options=options)
def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Execute the local ``parent_wf`` (default launch plan) and verify node and subworkflow I/O."""
    from mock_flyte_repo.workflows.basic.subworkflows import parent_wf

    # make sure the task name is the same as the name used during registration
    parent_wf._name = parent_wf.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    execution = remote.execute(parent_wf, {"a": 101}, version=f"v{VERSION}", wait=True)

    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

    # check subworkflow task execution inputs and outputs; these comparisons must be
    # asserted, otherwise they are evaluated and silently discarded.
    subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}
def test_fetch_execute_launch_plan_list_of_floats(flyteclient, flyte_workflows_register):
    """Fetch the list-of-floats launch plan and check its stringified output."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    lp_name = "workflows.basic.list_float_wf.my_wf"
    flyte_launch_plan = remote.fetch_launch_plan(name=lp_name, version=f"v{VERSION}")

    values: typing.List[float] = [42.24, 999.1, 0.0001]
    execution = remote.execute(flyte_launch_plan, inputs={"xs": values}, wait=True)
    assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"
def test_fetch_execute_task_list_of_floats(flyteclient, flyte_workflows_register):
    """Fetch the ``concat_list`` task and check it stringifies a list of floats."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    task_name = "workflows.basic.list_float_wf.concat_list"
    flyte_task = remote.fetch_task(name=task_name, version=f"v{VERSION}")

    values: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7]
    execution = remote.execute(flyte_task, {"xs": values}, wait=True)
    assert execution.outputs["o0"] == "[0.1, 0.2, 0.3, 0.4, -99999.7]"
def test_fetch_execute_task(flyteclient, flyte_workflows_register):
    """Fetch the registered ``t1`` task, run it, and check decoded and raw I/O."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_task = remote.fetch_task(name="workflows.basic.basic_workflow.t1", version=f"v{VERSION}")

    execution = remote.execute(flyte_task, {"a": 10}, wait=True)
    outputs = execution.outputs
    assert outputs["t1_int_output"] == 12
    assert outputs["c"] == "world"

    # The raw literal maps are decoded on demand with an explicit Python type.
    assert execution.raw_inputs.get("a", int) == 10
    assert execution.raw_outputs.get("c", str) == "world"
def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Round-trip a Python list through the joblib workflow and load the returned file."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    workflow_name = "workflows.basic.joblib.joblib_workflow"
    flyte_workflow = remote.fetch_workflow(name=workflow_name, version=f"v{VERSION}")

    expected = [1, 2, 3]
    execution = remote.execute(flyte_workflow, {"obj": expected}, wait=True)

    # The output is a file reference; download it locally before deserializing.
    joblib_file = execution.outputs["o0"]
    joblib_file.download()
    loaded = joblib.load(joblib_file.path)

    assert execution.outputs["o0"].extension() == "joblib"
    assert loaded == expected
def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register):
    """Fetch the ``convert_to_string`` task and check it serializes a dict to JSON."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    task_name = "workflows.basic.dict_str_wf.convert_to_string"
    flyte_task = remote.fetch_task(name=task_name, version=f"v{VERSION}")

    mapping: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"}
    execution = remote.execute(flyte_task, {"d": mapping}, wait=True)
    decoded = json.loads(execution.outputs["o0"])
    assert decoded == {"key1": "value1", "key2": "value2"}
def test_generate_http_domain_sandbox_rewrite(mock_client):
    """``generate_http_domain`` maps admin endpoint ``localhost:30081`` to ``http://localhost:30080``.

    The config is written to a temporary YAML file that is removed when the test ends.
    """
    import os

    fd, temp_filename = tempfile.mkstemp(suffix=".yaml")
    try:
        # Wrap the descriptor returned by mkstemp so it is closed rather than leaked.
        with os.fdopen(fd, "w") as f:
            # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
            flytectl_config_file = """admin:
    endpoint: localhost:30081
    authType: Pkce
    insecure: true
"""
            f.write(flytectl_config_file)

        remote = FlyteRemote(
            config=Config.auto(config_file=temp_filename),
            default_project="project",
            default_domain="domain",
        )
        assert remote.generate_http_domain() == "http://localhost:30080"
    finally:
        # mkstemp never deletes the file itself.
        os.remove(temp_filename)
def test_fetch_execute_launch_plan_with_child_workflows(flyteclient, flyte_workflows_register):
    """Execute the ``child_workflow.parent_wf`` launch plan and check every node's I/O."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_launch_plan = remote.fetch_launch_plan(
        name="workflows.basic.child_workflow.parent_wf", version=f"v{VERSION}"
    )
    execution = remote.execute(flyte_launch_plan, {"a": 3}, wait=True)

    # (node id, expected inputs, expected "o0" output), checked in order.
    expectations = [
        ("n0", {"a": 3}, 6),
        ("n1", {"a": 6}, 12),
        ("n2", {"a": 6, "b": 12}, 18),
    ]
    for node_id, expected_inputs, expected_o0 in expectations:
        node = execution.node_executions[node_id]
        assert node.inputs == expected_inputs
        assert node.outputs["o0"] == expected_o0
def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Execute an already-registered @task-decorated python function by its local object."""
    from mock_flyte_repo.workflows.basic.basic_workflow import t1

    # Strip the package prefix so the name matches the one used during registration.
    t1._name = t1.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")

    execution = remote.execute(t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True)
    outputs = execution.outputs
    assert outputs["t1_int_output"] == 12
    assert outputs["c"] == "world"
def test_remote_fetch_execution(mock_client_manager):
    """``fetch_execution`` returns an execution whose id matches the one from the client."""
    execution_id = WorkflowExecutionIdentifier("p1", "d1", "n1")
    admin_workflow_execution = Execution(id=execution_id, spec=MagicMock(), closure=MagicMock())

    client = MagicMock()
    client.get_execution.return_value = admin_workflow_execution

    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
    remote._client = client

    fetched = remote.fetch_execution(name="n1")
    assert fetched.id == admin_workflow_execution.id
def test_execute_with_wrong_input_key(mock_wf_exec):
    """``_execute`` raises FlyteValueException for an input key absent from the entity interface."""
    mock_wf_exec.return_value = True
    client = MagicMock()
    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
    remote._client = client

    # The entity declares "foo" but the call below supplies "bar".
    entity = MagicMock()
    entity.interface.inputs = {"foo": int}
    with pytest.raises(user_exceptions.FlyteValueException):
        remote._execute(entity, inputs={"bar": 3}, project="proj", domain="dev")
def test_fetch_execute_workflow(flyteclient, flyte_workflows_register):
    """Run the hello-world workflow to completion, then launch and terminate a second run."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_workflow = remote.fetch_workflow(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}")

    finished = remote.execute(flyte_workflow, {}, wait=True)
    assert finished.outputs["o0"] == "hello world"
    duration = finished.closure.duration
    assert isinstance(duration, datetime.timedelta)
    assert duration > datetime.timedelta(seconds=1)

    # Launch without waiting so the execution can be terminated.
    to_terminate = remote.execute(flyte_workflow, {})
    remote.terminate(to_terminate, cause="just because")
def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Execute the registered list-of-floats workflow directly and through its launch plan."""
    from mock_flyte_repo.workflows.basic.list_float_wf import my_wf

    # Strip the package prefix so the name matches the one used during registration.
    my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    version = f"v{VERSION}"

    values: typing.List[float] = [42.24, 999.1, 0.0001]
    execution = remote.execute(my_wf, inputs={"xs": values}, version=version, wait=True)
    assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"

    launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name)
    execution = remote.execute(launch_plan, inputs={"xs": [-1.1, 0.12345]}, version=version, wait=True)
    assert execution.outputs["o0"] == "[-1.1, 0.12345]"
def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register):
    """Execute the basic_workflow launch plan with explicit args; check node, task and workflow I/O.

    NOTE(review): this function lacks the ``test_`` prefix, so pytest does not collect it.
    Confirm whether that is intentional before relying on it.
    """
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_launch_plan = remote.fetch_launch_plan(
        name="workflows.basic.basic_workflow.my_wf", version=f"v{VERSION}"
    )
    execution = remote.execute(flyte_launch_plan, {"a": 10, "b": "foobar"}, wait=True)

    n0_inputs = {"a": 10}
    n0_outputs = {"t1_int_output": 12, "c": "world"}
    n1_inputs = {"a": "world", "b": "foobar"}
    n1_outputs = {"o0": "foobarworld"}
    nodes = execution.node_executions

    # Node-level I/O.
    assert nodes["n0"].inputs == n0_inputs
    assert nodes["n0"].outputs == n0_outputs
    assert nodes["n1"].inputs == n1_inputs
    assert nodes["n1"].outputs == n1_outputs

    # The first task execution of each node reports the same I/O.
    assert nodes["n0"].task_executions[0].inputs == n0_inputs
    assert nodes["n0"].task_executions[0].outputs == n0_outputs
    assert nodes["n1"].task_executions[0].inputs == n1_inputs
    assert nodes["n1"].task_executions[0].outputs == n1_outputs

    # Workflow-level I/O.
    assert execution.inputs["a"] == 10
    assert execution.inputs["b"] == "foobar"
    assert execution.outputs["o0"] == 12
    assert execution.outputs["o1"] == "foobarworld"
def test_passing_of_kwargs(mock_client):
    """Extra keyword arguments given to FlyteRemote reach the client constructor unchanged."""
    extra_kwargs = dict(
        credentials=1,
        options=2,
        private_key=3,
        compression=4,
        root_certificates=5,
        certificate_chain=6,
    )
    FlyteRemote(
        config=Config.auto(),
        default_project="project",
        default_domain="domain",
        **extra_kwargs,
    )
    assert mock_client.called
    assert mock_client.call_args.kwargs == extra_kwargs
def test_more_stuff(mock_client):
    """Exercise ``_upload_file`` guards and the determinism of ``_version_from_hash``."""
    r = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain")

    # Can't upload a folder
    with pytest.raises(ValueError):
        with tempfile.TemporaryDirectory() as tmp_dir:
            r._upload_file(pathlib.Path(tmp_dir))

    # Test that this copies the file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        mm = MagicMock()
        mm.signed_url = os.path.join(tmp_dir, "tmp_file")
        mock_client.return_value.get_upload_signed_url.return_value = mm
        r._upload_file(pathlib.Path(__file__))

    serialization_settings = flytekit.configuration.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
    )

    # Hashing produces a non-empty version string.
    computed_v = r._version_from_hash(b"", serialization_settings)
    assert len(computed_v) > 0

    # Identical inputs must hash to the same version. Compare against the first result;
    # comparing computed_v2 with itself would always pass.
    computed_v2 = r._version_from_hash(b"", serialization_settings)
    assert computed_v2 == computed_v

    # An additional hash component must change the version.
    computed_v3 = r._version_from_hash(b"", serialization_settings, "hi")
    assert computed_v2 != computed_v3
def get_and_save_remote_with_click_context(
    ctx: click.Context, project: str, domain: str, save: bool = True
) -> FlyteRemote:
    """
    Create a FlyteRemote from the config file location stored on the click context.

    NB: By default this mutates ``ctx.obj``, storing the new remote under
    ``FLYTE_REMOTE_INSTANCE_KEY``.

    :param ctx: the click context object; ``ctx.obj`` supplies the config file location
    :param project: default project for the remote instance
    :param domain: default domain for the remote instance
    :param save: if False, ``ctx.obj`` is left untouched
    :return: FlyteRemote instance
    """
    config_path = ctx.obj.get(CTX_CONFIG_FILE)
    config = Config.auto(config_path)

    file_suffix = f" with file {config_path}" if config_path else ""
    cli_logger.info(f"Creating remote with config {config}" + file_suffix)

    remote = FlyteRemote(config, default_project=project, default_domain=domain)
    if save:
        ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = remote
    return remote
def test_spark_template_with_remote():
    """Registering a Spark task serializes Spark custom fields; a plain python task has none."""

    @task(task_config=Spark(spark_conf={"spark": "1"}))
    def my_spark(a: str) -> int:
        return 10

    @task
    def my_python_task(a: str) -> int:
        return 10

    remote = FlyteRemote(
        config=Config.for_endpoint(endpoint="localhost", insecure=True),
        default_project="p1",
        default_domain="d1",
    )
    mock_client = MagicMock()
    remote._client = mock_client

    remote.register_task(
        my_spark,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
    # Check that the serialized spark task has the mainApplicationFile field set.
    assert serialized_spec.template.custom["mainApplicationFile"]
    assert serialized_spec.template.custom["sparkConf"]

    remote.register_task(
        my_python_task,
        serialization_settings=SerializationSettings(image_config=MagicMock()),
        version="v1",
    )
    serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
    # Check that the serialized python task has no custom section (hence no mainApplicationFile) by default.
    assert serialized_spec.template.custom is None
def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Poll a hello-world execution until it finishes, then verify node, task and workflow I/O."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    flyte_launch_plan = remote.fetch_launch_plan(
        name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}"
    )
    execution = remote.execute(flyte_launch_plan, {})

    poll_interval = datetime.timedelta(seconds=1)
    timeout = datetime.timedelta(seconds=60)
    # A monotonic clock is immune to wall-clock adjustments while polling.
    deadline = time.monotonic() + timeout.total_seconds()

    execution = remote.sync_execution(execution, sync_nodes=True)
    while time.monotonic() < deadline:
        if execution.is_done:
            break
        # Reading outputs of an unfinished execution must raise.
        with pytest.raises(
            FlyteAssertion,
            match="Please wait until the node execution has completed before requesting the outputs",
        ):
            execution.outputs
        time.sleep(poll_interval.total_seconds())
        execution = remote.sync_execution(execution, sync_nodes=True)

    # Fail clearly on timeout instead of with a confusing error from the output checks below.
    assert execution.is_done, f"execution did not complete within {timeout}"

    if execution.node_executions:
        assert execution.node_executions["start-node"].closure.phase == 3  # SUCCEEDED

    for key in execution.node_executions:
        assert execution.node_executions[key].closure.phase == 3

    assert execution.node_executions["n0"].inputs == {}
    assert execution.node_executions["n0"].outputs["o0"] == "hello world"
    assert execution.node_executions["n0"].task_executions[0].inputs == {}
    assert execution.node_executions["n0"].task_executions[0].outputs["o0"] == "hello world"
    assert execution.inputs == {}
    assert execution.outputs["o0"] == "hello world"
def test_form_config():
    """The default project and domain passed to the constructor are exposed as attributes."""
    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
    observed = (remote.default_project, remote.default_domain)
    assert observed == ("p1", "d1")
def test_fetch_not_exist_launch_plan(flyteclient):
    """Fetching a launch plan name that does not exist raises FlyteEntityNotExistException."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    missing_name = "workflows.basic.list_float_wf.fake_wf"
    with pytest.raises(FlyteEntityNotExistException):
        remote.fetch_launch_plan(name=missing_name, version=f"v{VERSION}")
def test_fetch_execute_launch_plan(flyteclient, flyte_workflows_register):
    """Fetch the hello-world launch plan, run it to completion and check its output."""
    remote = FlyteRemote(Config.auto(), PROJECT, "development")
    lp_name = "workflows.basic.hello_world.my_wf"
    flyte_launch_plan = remote.fetch_launch_plan(name=lp_name, version=f"v{VERSION}")

    execution = remote.execute(flyte_launch_plan, {}, wait=True)
    assert execution.outputs["o0"] == "hello world"
from .resources import hello_wf

#####
# THESE TESTS ARE NOT RUN IN CI. THEY ARE HERE TO MAKE LOCAL TESTING EASIER.
# Update the values below to point at the image and versions you want to test against.
IMAGE_STR = "flytecookbook:core-f7af27e23b3935a166645cf96a68583cdd263a87"
FETCH_VERSION = "a351b7c7445a8a818cdf87bf1c1cf38b63beddf1"
RELEASED_EXAMPLES_VERSION = "a351b7c7445a8a818cdf87bf1c1cf38b63beddf1"
#####

image_config = ImageConfig.auto(img_name=IMAGE_STR)

rr = FlyteRemote(
    Config.for_sandbox(),
    default_project="flytesnacks",
    default_domain="development",
)


def get_get_version():
    """Return a version-string factory that shares one randomly generated prefix.

    The prefix is ``"sandbox_test_"`` plus three hex characters from a fresh UUID4, and is
    logged once when this function is called. The returned ``fn(suffix="")`` yields the
    prefix alone, or ``"<prefix>_<suffix>"`` when a non-empty suffix is given.
    """
    version_prefix = "sandbox_test_" + uuid.uuid4().hex[:3]
    logger.warning(f"Test version prefix is {version_prefix}")

    def fn(suffix: str = "") -> str:
        return version_prefix + (f"_{suffix}" if suffix else "")

    return fn