def test_binding_data_map():
    """Round-trip a nested BindingDataMap (a map containing another map) through IDL."""

    def _int_binding(value):
        # Wrap an integer primitive in a scalar BindingData.
        return literals.BindingData(
            scalar=literals.Scalar(primitive=literals.Primitive(integer=value))
        )

    inner_map = literals.BindingDataMap(
        bindings={"first": _int_binding(5), "second": _int_binding(57)}
    )
    outer_map = literals.BindingDataMap(
        bindings={
            "three": _int_binding(2),
            "sample_map": literals.BindingData(map=inner_map),
        }
    )
    original = literals.BindingData(map=outer_map)

    for unset in ("scalar", "promise", "collection"):
        assert getattr(original, unset) is None
    assert original.value.bindings["three"].value.value.value == 2
    nested = original.value.bindings["sample_map"].value.bindings
    assert nested["second"].value.value.value == 57

    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    for unset in ("scalar", "promise", "collection"):
        assert getattr(restored, unset) is None
    assert restored.value.bindings["three"].value.value.value == 2
    restored_nested = restored.value.bindings["sample_map"].value.bindings
    assert restored_nested["first"].value.value.value == 5
def test_lp_default_handling():
    """Check how launch-plan default and fixed inputs interact with workflow inputs.

    ``my_wf`` has no Python defaults; ``my_wf2`` gives ``b`` a Python default of 42.
    """

    # Field-list form of the functional NamedTuple syntax; the keyword-argument
    # form is deprecated as of Python 3.13.
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]):
        a = a + 2
        return a, "world-" + str(a)

    @workflow
    def my_wf(a: int, b: int) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    three = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )

    # No defaults anywhere: both inputs are required and nothing is fixed.
    lp = launch_plan.LaunchPlan.create("test1", my_wf)
    assert len(lp.parameters.parameters) == 2
    assert lp.parameters.parameters["a"].required
    assert lp.parameters.parameters["a"].default is None
    assert lp.parameters.parameters["b"].required
    assert lp.parameters.parameters["b"].default is None
    assert len(lp.fixed_inputs.literals) == 0

    # A launch-plan default makes "a" optional but keeps it a parameter.
    lp_with_defaults = launch_plan.LaunchPlan.create("test2", my_wf, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert not lp_with_defaults.parameters.parameters["a"].required
    assert lp_with_defaults.parameters.parameters["a"].default == three
    assert len(lp_with_defaults.fixed_inputs.literals) == 0

    # A fixed input is dropped from the parameters and recorded as a fixed literal.
    lp_with_fixed = launch_plan.LaunchPlan.create("test3", my_wf, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    assert lp_with_fixed.fixed_inputs.literals["a"] == three

    @workflow
    def my_wf2(a: int, b: int = 42) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    lp = launch_plan.LaunchPlan.create("test4", my_wf2)
    assert len(lp.parameters.parameters) == 2
    assert len(lp.fixed_inputs.literals) == 0

    lp_with_defaults = launch_plan.LaunchPlan.create("test5", my_wf2, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert len(lp_with_defaults.fixed_inputs.literals) == 0
    # "a" comes from the launch-plan default; "b" is passed explicitly.
    assert lp_with_defaults(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test6", my_wf2, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # "a" comes from the fixed input; "b" is passed explicitly.
    assert lp_with_fixed(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test7", my_wf2, fixed_inputs={"b": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # The fixed "b" must win over the workflow's own Python default of 42.
    assert lp_with_fixed(a=3) == ("world-5", "world-5", 5, 5)
def test_get_sdk_value_from_literal():
    """get_sdk_value_from_literal turns void, integer and collection literals into Python values."""

    def _void():
        return _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))

    def _int(value):
        return _literals.Literal(
            scalar=_literals.Scalar(primitive=_literals.Primitive(integer=value))
        )

    convert = _type_helpers.get_sdk_value_from_literal
    integer_type = _sdk_types.Types.Integer

    # Void without a requested type -> None.
    assert convert(_void()).to_python_std() is None
    # Void stays None even when an Integer type is requested.
    assert convert(_void(), sdk_type=integer_type).to_python_std() is None
    assert convert(_int(1), sdk_type=integer_type).to_python_std() == 1
    # A collection may mix integers and voids.
    mixed = _literals.Literal(collection=_literals.LiteralCollection([_int(1), _void()]))
    assert convert(mixed).to_python_std() == [1, None]
def test_node_task_with_inputs():
    """A task Node keeps its target, id, metadata and input bindings across an IDL round trip."""
    metadata = _get_sample_node_metadata()
    task_node = _workflow.TaskNode(reference_id=_generic_id)

    def _int_binding_data(value):
        return _literals.BindingData(
            scalar=_literals.Scalar(primitive=_literals.Primitive(integer=value))
        )

    first = _literals.Binding(var="myvar", binding=_int_binding_data(5))
    second = _literals.Binding(var="myothervar", binding=_int_binding_data(99))
    node = _workflow.Node(
        id="some:node:id",
        metadata=metadata,
        inputs=[first, second],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    def _check(candidate, index, expected_binding):
        assert candidate.target == task_node
        assert candidate.id == "some:node:id"
        assert candidate.metadata == metadata
        assert len(candidate.inputs) == 2
        assert candidate.inputs[index] == expected_binding

    _check(node, 0, first)
    restored = _workflow.Node.from_flyte_idl(node.to_flyte_idl())
    assert node == restored
    _check(restored, 1, second)
def test_infer_sdk_type_from_literal():
    """infer_sdk_type_from_literal maps a string primitive to String and a void to Void."""
    string_literal = _literals.Literal(
        scalar=_literals.Scalar(primitive=_literals.Primitive(string_value="abc"))
    )
    inferred = _type_helpers.infer_sdk_type_from_literal(string_literal)
    assert inferred == _sdk_types.Types.String

    void_literal = _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))
    inferred = _type_helpers.infer_sdk_type_from_literal(void_literal)
    assert inferred is _base_sdk_types.Void
def test_branch_node():
    """Build an if / elif(AND) / else BranchNode and check it survives an IDL round trip."""
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var="myvar", binding=bd)
    binding2 = _literals.Binding(var="myothervar", binding=bd2)
    obj = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[binding, binding2],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )

    def _five_equals_two():
        # Fresh "5 == 2" comparison expression, used for every condition below.
        return _condition.BooleanExpression(
            comparison=_condition.ComparisonExpression(
                _condition.ComparisonExpression.Operator.EQ,
                _condition.Operand(primitive=_literals.Primitive(integer=5)),
                _condition.Operand(primitive=_literals.Primitive(integer=2)),
            )
        )

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(condition=_five_equals_two(), then_node=obj),
            other=[
                _workflow.IfBlock(
                    condition=_condition.BooleanExpression(
                        conjunction=_condition.ConjunctionExpression(
                            _condition.ConjunctionExpression.LogicalOperator.AND,
                            _five_equals_two(),
                            _five_equals_two(),
                        )
                    ),
                    then_node=obj,
                )
            ],
            else_node=obj,
        )
    )

    # Round-trip through IDL so the test verifies serialization, not just construction.
    bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
    assert bn == bn2
    assert bn2.if_else.case.then_node == obj
    assert len(bn2.if_else.other) == 1
    assert bn2.if_else.else_node == obj
def __init__(self, pb_object: Union[GeneratedProtocolMessageType, FlyteIdlEntity]):
    """
    Serialize ``pb_object`` into a binary scalar literal tagged with ``type(self).tag``.

    The object is first reduced to a raw protobuf message (via ``to_flyte_idl()`` when it
    is a FlyteIdlEntity). Suppose neither that message nor the original object has the
    class's ``pb_type``. If ``pb_type`` is a FlyteType, the message gets one conversion
    attempt through ``from_flyte_idl``; otherwise the object is rejected.

    :param Union[T, FlyteIdlEntity] pb_object: raw protobuf message or model object to wrap.
    :raises FlyteTypeException: when the object does not match ``pb_type`` and ``pb_type``
        is not a FlyteType.
    """
    # Model objects expose their raw proto through to_flyte_idl(); plain protos are used as-is.
    raw_proto = pb_object.to_flyte_idl() if isinstance(pb_object, FlyteIdlEntity) else pb_object

    expected_type = type(self).pb_type
    type_mismatch = expected_type != type(raw_proto) and expected_type != type(pb_object)
    if type_mismatch:
        if not isinstance(expected_type, FlyteType):
            raise _user_exceptions.FlyteTypeException(
                received_type=type(pb_object), expected_type=expected_type, received_value=pb_object
            )
        # Let the FlyteType parse the raw proto, then take its canonical IDL form.
        raw_proto = expected_type.from_flyte_idl(raw_proto).to_flyte_idl()

    payload = raw_proto.SerializeToString()
    # NOTE(review): super() is anchored at Protobuf rather than the enclosing class, which
    # requires this instance to be a Protobuf (or subclass of it) — confirm this is intended.
    super(Protobuf, self).__init__(
        scalar=_literals.Scalar(
            binary=_literals.Binary(value=bytes(payload) if _six.PY2 else payload, tag=type(self).tag)
        )
    )
def test_scalar_binary():
    """A binary Scalar exposes only its binary field and survives an IDL round trip."""
    unset_fields = ("primitive", "error", "blob", "schema", "none_type")

    def _verify(scalar):
        for field_name in unset_fields:
            assert getattr(scalar, field_name) is None
        assert scalar.binary is not None
        assert scalar.binary.tag == "taggy"
        assert scalar.binary.value == b"value"

    original = literals.Scalar(binary=literals.Binary(b"value", "taggy"))
    _verify(original)

    restored = literals.Scalar.from_flyte_idl(original.to_flyte_idl())
    assert restored == original
    _verify(restored)
def test_launch_workflow_with_subworkflows(flyteclient, flyte_workflows_register):
    """Launch parent_wf; verify I/O of its nodes and of the nodes inside its subworkflow."""
    execution = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}"
    ).launch_with_literals(
        PROJECT,
        "development",
        literals.LiteralMap(
            {"a": literals.Literal(literals.Scalar(literals.Primitive(integer=101)))}
        ),
    )
    execution.wait_for_completion()

    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

    # check subworkflow task execution inputs and outputs (these must be asserted,
    # a bare comparison expression checks nothing)
    subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}
def test_scalar_schema():
    """A schema Scalar keeps all six typed columns across an IDL round trip."""
    column_kind = _types.SchemaType.SchemaColumn.SchemaColumnType
    column_specs = [
        ("a", column_kind.INTEGER),
        ("b", column_kind.FLOAT),
        ("c", column_kind.STRING),
        ("d", column_kind.DATETIME),
        ("e", column_kind.DURATION),
        ("f", column_kind.BOOLEAN),
    ]
    schema_type = _types.SchemaType(
        [_types.SchemaType.SchemaColumn(name, kind) for name, kind in column_specs]
    )
    original = literals.Scalar(schema=literals.Schema(uri="asdf", type=schema_type))

    def _verify(scalar):
        for field_name in ("primitive", "error", "blob", "binary"):
            assert getattr(scalar, field_name) is None
        assert scalar.schema is not None
        assert scalar.none_type is None
        assert scalar.value.type.columns[0].name == "a"
        assert len(scalar.value.type.columns) == 6

    _verify(original)
    restored = literals.Scalar.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _verify(restored)
def test_launch_plan_spec():
    """LaunchPlanSpec round-trips through IDL both with and without a raw output prefix."""
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[]
    )
    boolean_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf"
    )
    parameters = interface.ParameterMap({"ppp": interface.Parameter(var=boolean_var)})
    fixed = literals.LiteralMap(
        {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))}
    )
    labels = common.Labels({})
    annotations = common.Annotations({"my": "annotation"})
    auth_role = common.AuthRole(assumable_iam_role="my:iam:role")
    max_parallelism = 100

    # First with an explicit prefix, then with an empty one.
    for output_prefix in ("s3://bucket", ""):
        spec = launch_plan.LaunchPlanSpec(
            workflow_id,
            metadata,
            parameters,
            fixed,
            labels,
            annotations,
            auth_role,
            common.RawOutputDataConfig(output_prefix),
            max_parallelism,
        )
        round_tripped = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert round_tripped == spec
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch basic_workflow.my_wf with literal inputs; check node, task and workflow I/O."""
    remote_lp = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}"
    )
    execution = remote_lp.launch_with_literals(
        PROJECT,
        "development",
        literals.LiteralMap(
            {
                "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))),
                "b": literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar"))),
            }
        ),
    )
    execution.wait_for_completion()

    # Expected (inputs, outputs) per node; the first task execution of each node is
    # expected to report the same I/O as the node itself.
    expected_io = {
        "n0": ({"a": 10}, {"t1_int_output": 12, "c": "world"}),
        "n1": ({"a": "world", "b": "foobar"}, {"o0": "foobarworld"}),
    }
    for node_id, (node_inputs, node_outputs) in expected_io.items():
        assert execution.node_executions[node_id].inputs == node_inputs
        assert execution.node_executions[node_id].outputs == node_outputs
    for node_id, (node_inputs, node_outputs) in expected_io.items():
        first_task = execution.node_executions[node_id].task_executions[0]
        assert first_task.inputs == node_inputs
        assert first_task.outputs == node_outputs

    assert execution.inputs["a"] == 10
    assert execution.inputs["b"] == "foobar"
    assert execution.outputs["o0"] == 12
    assert execution.outputs["o1"] == "foobarworld"
def __init__(self, pb_object):
    """
    Serialize ``pb_object`` into a binary scalar literal tagged with this class's tag.

    :param T pb_object: protobuf message exposing ``SerializeToString()``.
    """
    serialized = pb_object.SerializeToString()
    if _six.PY2:
        serialized = bytes(serialized)
    binary = _literals.Binary(value=serialized, tag=type(self).tag)
    super(Protobuf, self).__init__(scalar=_literals.Scalar(binary=binary))
def test_arrayjob_entrypoint_in_proc():
    """Run add_one in-process as a faked array job and check its single output.

    Setup writes the inputs under ``<tmp>/1`` plus an ``indexlookup.pb`` whose only entry
    is 1. It then sets the array-index environment variables and runs the task entrypoint.
    Finally, it reads the outputs back from ``<tmp>/1``. The environment variables are
    restored, or removed if they were unset, even when the task or an assertion fails.
    """
    config_path = os.path.join(os.path.dirname(__file__), "fake.config")
    with _TemporaryConfiguration(
        config_path, internal_overrides={"project": "test", "domain": "development"}
    ):
        with _utils.AutoDeletingTempDir("dir") as tmp_dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs)
            )
            input_dir = os.path.join(tmp_dir.name, "1")
            os.mkdir(input_dir)  # auto cleanup will take this subdir into account
            input_file = os.path.join(input_dir, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            # construct indexlookup.pb which has array: [1]
            mapped_index = _literals.Literal(
                _literals.Scalar(primitive=_literals.Primitive(integer=1))
            )
            index_lookup_collection = _literals.LiteralCollection([mapped_index])
            index_lookup_file = os.path.join(tmp_dir.name, "indexlookup.pb")
            _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file)

            # fake arrayjob task by setting environment variables
            env_overrides = {
                "BATCH_JOB_ARRAY_INDEX_VAR_NAME": "AWS_BATCH_JOB_ARRAY_INDEX",
                "AWS_BATCH_JOB_ARRAY_INDEX": "0",
            }
            saved_env = {name: os.environ.get(name) for name in env_overrides}
            os.environ.update(env_overrides)
            try:
                execute_task(
                    _task_defs.add_one.task_module,
                    _task_defs.add_one.task_function_name,
                    tmp_dir.name,
                    tmp_dir.name,
                    False,
                )
                raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
                    _literal_models.LiteralMap.from_flyte_idl(
                        _utils.load_proto_from_file(
                            _literals_pb2.LiteralMap,
                            os.path.join(input_dir, _constants.OUTPUT_FILE_NAME),
                        )
                    ),
                    _type_map_from_variable_map(_task_defs.add_one.interface.outputs),
                )
                assert raw_map["b"] == 10
                assert len(raw_map) == 1
            finally:
                # Restore the exact prior environment, including "was unset" and "was empty".
                for name, previous in saved_env.items():
                    if previous is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = previous
def test_model_promotion():
    """Promoting a LiteralCollection model to List(Integer) yields Integer elements."""
    list_type = containers.List(primitives.Integer)
    values = [0, 1, 2]
    element_literals = [
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=v)))
        for v in values
    ]
    list_model = literals.Literal(collection=literals.LiteralCollection(literals=element_literals))

    promoted = list_type.promote_from_model(list_model)
    assert len(promoted.collection.literals) == 3
    assert isinstance(promoted.collection.literals[0], primitives.Integer)
    assert promoted == list_type.from_python_std([0, 1, 2])
    assert promoted == list_type([primitives.Integer(v) for v in values])
def test_infer_proto_from_literal():
    """A binary literal tagged "<TAG_PREFIX>flyteidl.core.errors_pb2.ContainerError" infers ContainerError.

    The payload is left empty; the assertion only concerns the inferred ``pb_type``.
    """
    tag = f"{_proto.Protobuf.TAG_PREFIX}flyteidl.core.errors_pb2.ContainerError"
    literal = _literal_models.Literal(
        scalar=_literal_models.Scalar(
            # Binary.value is a bytes field, so use an empty bytes payload rather than "".
            binary=_literal_models.Binary(value=b"", tag=tag)
        )
    )
    sdk_type = _flyte_engine.FlyteDefaultTypeEngine().infer_sdk_type_from_literal(literal)
    assert sdk_type.pb_type == _errors_pb2.ContainerError
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch basic_workflow.my_wf with literal inputs and check the raw output literals."""
    remote_lp = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}"
    )
    int_input = literals.Literal(literals.Scalar(literals.Primitive(integer=10)))
    str_input = literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar")))
    execution = remote_lp.launch_with_literals(
        PROJECT, "development", literals.LiteralMap({"a": int_input, "b": str_input})
    )
    execution.wait_for_completion()

    output_literals = execution.outputs.literals
    assert output_literals["o0"].scalar.primitive.integer == 12
    assert output_literals["o1"].scalar.primitive.string_value == "foobarworld"
def test_blob_promote_from_model():
    """Blob.promote_from_model preserves the blob URI, format and dimensionality."""
    single = _core_types.BlobType.BlobDimensionality.SINGLE
    blob_type = _core_types.BlobType(format="f", dimensionality=single)
    model = _literal_models.Literal(
        scalar=_literal_models.Scalar(
            blob=_literal_models.Blob(_literal_models.BlobMetadata(blob_type), "some/path")
        )
    )

    promoted = blobs.Blob.promote_from_model(model)
    inner_blob = promoted.value.blob
    assert inner_blob.uri == "some/path"
    assert inner_blob.metadata.type.format == "f"
    assert inner_blob.metadata.type.dimensionality == single
def test_binding_data_collection():
    """A BindingDataCollection of two integer scalars survives an IDL round trip."""
    expected_values = [5, 57]
    collection = literals.BindingDataCollection(
        bindings=[
            literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=n)))
            for n in expected_values
        ]
    )
    original = literals.BindingData(collection=collection)

    def _verify(binding_data):
        assert binding_data.scalar is None
        assert binding_data.promise is None
        assert binding_data.collection is not None
        assert binding_data.map is None
        for index, expected in enumerate(expected_values):
            assert binding_data.value.bindings[index].value.value.value == expected

    _verify(original)
    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _verify(restored)
def test_construct_literal_map_from_variable_map():
    """The text value '15' for an INTEGER variable is parsed into an integer scalar."""
    integer_var = Variable(
        type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description"
    )
    literal_map = helpers.construct_literal_map_from_variable_map(
        {"inputa": integer_var}, {"inputa": "15"}
    )
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals["inputa"].value == expected
def test_binding_data_scalar():
    """A scalar BindingData exposes its integer and leaves promise/collection/map unset."""

    def _verify(binding_data):
        assert binding_data.value.value.value == 5
        for field_name in ("promise", "collection", "map"):
            assert getattr(binding_data, field_name) is None

    original = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=5)))
    _verify(original)

    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _verify(restored)
def test_construct_literal_map_from_parameter_map():
    """Text inputs are parsed against a ParameterMap; a missing required input raises."""
    integer_var = Variable(
        type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description"
    )
    parameter_map = ParameterMap(parameters={"inputa": Parameter(var=integer_var, required=True)})

    literal_map = helpers.construct_literal_map_from_parameter_map(parameter_map, {"inputa": "15"})
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals["inputa"].value == expected

    # "inputa" is required, so an empty input dictionary must be rejected.
    with pytest.raises(Exception):
        helpers.construct_literal_map_from_parameter_map(parameter_map, {})
def test_lp_serialize():
    """Serialize launch plans with and without a default for "a" and inspect their inputs."""

    # Field-list form of the functional NamedTuple syntax; the keyword-argument
    # form is deprecated as of Python 3.13.
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]):
        a = a + 2
        return a, "world-" + str(a)

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    lp_with_defaults = launch_plan.LaunchPlan.create(
        "serialize_test2", my_subwf, default_inputs={"a": 3}
    )

    serialization_settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    # Without a default, "a" is a required parameter and nothing is fixed.
    sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp)
    assert len(sdk_lp.default_inputs.parameters) == 1
    assert sdk_lp.default_inputs.parameters["a"].required
    assert len(sdk_lp.fixed_inputs.literals) == 0

    # With a default, "a" is optional and carries the default literal.
    sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp_with_defaults)
    assert len(sdk_lp.default_inputs.parameters) == 1
    assert not sdk_lp.default_inputs.parameters["a"].required
    assert sdk_lp.default_inputs.parameters["a"].default == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert len(sdk_lp.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is
    # specified, the required field needs to be None, not False.
    parameter_a = sdk_lp.default_inputs.parameters["a"]
    parameter_a = Parameter.from_flyte_idl(parameter_a.to_flyte_idl())
    assert parameter_a.default is not None
def test_scalar_void():
    """A void Scalar has only none_type set, before and after an IDL round trip."""
    empty_fields = ("primitive", "error", "blob", "binary", "schema")

    def _verify(scalar):
        for field_name in empty_fields:
            assert getattr(scalar, field_name) is None
        assert scalar.none_type is not None

    original = literals.Scalar(none_type=literals.Void())
    _verify(original)

    restored = literals.Scalar.from_flyte_idl(original.to_flyte_idl())
    assert restored == original
    _verify(restored)
def setUp(self):
    """Create an engine context and a distributed custom training job task for each test.

    The temporary directory used as the context's ``tmp_dir`` is entered through an
    ExitStack that is closed by a registered test cleanup. This keeps the path valid for
    the whole test, rather than exiting the directory's context manager as soon as
    ``setUp`` returns.
    """
    from contextlib import ExitStack

    resources = ExitStack()
    self.addCleanup(resources.close)
    input_dir = resources.enter_context(_utils.AutoDeletingTempDir("input_dir"))

    self._task_input = _literals.LiteralMap(
        {"input_1": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))}
    )

    self._context = _common_engine.EngineContext(
        execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"),
        execution_date=_datetime.datetime.utcnow(),
        stats=MockStats(),
        logging=None,
        tmp_dir=input_dir.name,
    )

    # Defining the distributed training task without specifying an output-persist
    # predicate (so it will use the default)
    @inputs(input_1=Types.Integer)
    @outputs(model=Types.Blob)
    @custom_training_job_task(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=2,
            volume_size_in_gb=25,
        ),
        algorithm_specification=AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")],
        ),
    )
    def my_distributed_task(wf_params, input_1, model):
        pass

    self._my_distributed_task = my_distributed_task
    assert type(self._my_distributed_task) == CustomTrainingJobTask
def test_scalar_primitive():
    """A float primitive Scalar survives an IDL round trip with every other field unset."""
    other_fields = ("error", "blob", "binary", "schema", "none_type")

    scalar = literals.Scalar(primitive=literals.Primitive(float_value=5.6))
    assert scalar.value.value == 5.6
    for field_name in other_fields:
        assert getattr(scalar, field_name) is None

    idl = scalar.to_flyte_idl()
    assert idl.primitive.float_value == 5.6

    round_tripped = literals.Scalar.from_flyte_idl(scalar.to_flyte_idl())
    assert round_tripped == scalar
    for field_name in other_fields:
        assert getattr(round_tripped, field_name) is None
def test_old_style_role():
    """A raw LaunchPlanSpec using the legacy ``auth`` field is still read into ``auth_role``."""
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[]
    )
    boolean_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf"
    )
    parameters = interface.ParameterMap({"ppp": interface.Parameter(var=boolean_var)})
    fixed = literals.LiteralMap(
        {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))}
    )

    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=workflow_id.to_flyte_idl(),
        entity_metadata=metadata.to_flyte_idl(),
        default_inputs=parameters.to_flyte_idl(),
        fixed_inputs=fixed.to_flyte_idl(),
        labels=common.Labels({}).to_flyte_idl(),
        annotations=common.Annotations({"my": "annotation"}).to_flyte_idl(),
        raw_output_data_config=common.RawOutputDataConfig("s3://bucket").to_flyte_idl(),
        auth=_launch_plan_idl.Auth(kubernetes_service_account="my:service:account"),
    )

    spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)
    # NOTE(review): the service account is expected to surface as assumable_iam_role;
    # confirm this mapping in LaunchPlanSpec.from_flyte_idl is intentional.
    assert spec.auth_role.assumable_iam_role == "my:service:account"
def test_structured_dataset():
    """A structured-dataset Scalar keeps its URI and four columns across an IDL round trip."""

    def _int_type():
        return _types.LiteralType(simple=_types.SimpleType.INTEGER)

    column_specs = [
        ("a", _int_type()),
        ("b", _types.LiteralType(simple=_types.SimpleType.STRING)),
        ("c", _types.LiteralType(collection_type=_int_type())),
        ("d", _types.LiteralType(map_value_type=_int_type())),
    ]
    columns = [
        _types.StructuredDatasetType.DatasetColumn(name, literal_type)
        for name, literal_type in column_specs
    ]
    dataset = literals.StructuredDataset(
        uri="s3://bucket",
        metadata=literals.StructuredDatasetMetadata(
            structured_dataset_type=_types.StructuredDatasetType(columns=columns, format="parquet")
        ),
    )
    original = literals.Scalar(structured_dataset=dataset)

    def _verify(scalar, unset_fields):
        for field_name in unset_fields:
            assert getattr(scalar, field_name) is None
        assert scalar.structured_dataset is not None
        assert scalar.value.uri == "s3://bucket"
        assert len(scalar.value.metadata.structured_dataset_type.columns) == 4

    _verify(original, ("error", "blob", "binary", "schema", "none_type"))
    restored = literals.Scalar.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _verify(restored, ("blob", "binary", "schema", "none_type"))
def __init__(self, value):
    """
    Store ``value`` as the schema payload of a scalar literal.

    :param flytekit.common.types.impl.schema.Schema value: Schema value to wrap
    """
    schema_scalar = _literals.Scalar(schema=value)
    super(Schema, self).__init__(scalar=schema_scalar)
import pytest

from flytekit.models import common as _common_models
from flytekit.models import execution as _execution
from flytekit.models import literals as _literals
from flytekit.models.core import execution as _core_exec
from flytekit.models.core import identifier as _identifier
from tests.flytekit.common import parameterizers as _parameterizers

_INPUT_MAP = _literals.LiteralMap(
    {"a": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))}
)
_OUTPUT_MAP = _literals.LiteralMap(
    {"b": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=2)))}
)


def test_execution_metadata():
    """ExecutionMetadata keeps mode, principal and nesting across an IDL round trip."""
    obj = _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)
    assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj.principal == "tester"
    assert obj.nesting == 1

    obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj2.principal == "tester"
    # Check every constructor field after the round trip, not just mode and principal.
    assert obj2.nesting == 1