예제 #1
0
def test_literal_types():
    """Simple and schema LiteralTypes populate only their own field and round-trip through IDL."""
    # A scalar literal type only sets ``simple``.
    simple_lt = _types.LiteralType(simple=_types.SimpleType.INTEGER)
    assert simple_lt.simple == _types.SimpleType.INTEGER
    assert simple_lt.schema is None
    assert simple_lt.collection_type is None
    assert simple_lt.map_value_type is None
    assert simple_lt == _types.LiteralType.from_flyte_idl(simple_lt.to_flyte_idl())

    # A schema literal type with one column for each column type.
    column_kind = _types.SchemaType.SchemaColumn.SchemaColumnType
    column_specs = [
        ("a", column_kind.INTEGER),
        ("b", column_kind.FLOAT),
        ("c", column_kind.STRING),
        ("d", column_kind.DATETIME),
        ("e", column_kind.DURATION),
        ("f", column_kind.BOOLEAN),
    ]
    schema_type = _types.SchemaType(
        [_types.SchemaType.SchemaColumn(col_name, col_kind) for col_name, col_kind in column_specs]
    )
    schema_lt = _types.LiteralType(schema=schema_type)
    assert schema_lt.simple is None
    assert schema_lt.schema == schema_type
    assert schema_lt.collection_type is None
    assert schema_lt.map_value_type is None
    assert schema_lt == _types.LiteralType.from_flyte_idl(schema_lt.to_flyte_idl())
예제 #2
0
def test_default_python_task():
    """The decorated spark task exposes its interface, metadata and spark-specific custom fields."""
    spark_task_obj = default_task
    assert isinstance(spark_task_obj, _spark_task.SdkSparkTask)
    assert isinstance(spark_task_obj, _sdk_runnable.SdkRunnableTask)

    in1 = spark_task_obj.interface.inputs["in1"]
    assert in1.description == ""
    assert in1.type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    out1 = spark_task_obj.interface.outputs["out1"]
    assert out1.description == ""
    assert out1.type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    assert spark_task_obj.type == _common_constants.SdkTaskType.SPARK_TASK
    assert spark_task_obj.task_function_name == "default_task"
    assert spark_task_obj.task_module == __name__

    meta = spark_task_obj.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ""
    assert meta.discoverable is False
    assert meta.discovery_version == ""
    assert meta.retries.retries == 0

    resources = spark_task_obj.container.resources
    assert len(resources.limits) == 0
    assert len(resources.requests) == 0

    assert spark_task_obj.custom["sparkConf"]["A"] == "B"
    assert spark_task_obj.custom["hadoopConf"]["C"] == "D"
    assert spark_task_obj.hadoop_conf["C"] == "D"
    assert spark_task_obj.spark_conf["A"] == "B"
    # Drop the last character of the entrypoint path. The substring check then
    # matches whether __file__ ends in ".py" or ".pyc".
    entrypoint_prefix = _os.path.abspath(_entrypoint.__file__)[:-1]
    assert entrypoint_prefix in spark_task_obj.custom["mainApplicationFile"]
    assert spark_task_obj.custom["executorPath"] == _sys.executable

    idl_task = spark_task_obj.to_flyte_idl()
    assert idl_task.custom["sparkConf"]["A"] == "B"
    assert idl_task.custom["hadoopConf"]["C"] == "D"
예제 #3
0
def test_simple_pytorch_task():
    """The decorated pytorch task exposes its interface, metadata and worker count."""
    pt_task = simple_pytorch_task
    assert isinstance(pt_task, _pytorch_task.SdkPyTorchTask)
    assert isinstance(pt_task, _sdk_runnable.SdkRunnableTask)

    in1 = pt_task.interface.inputs["in1"]
    assert in1.description == ""
    assert in1.type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    out1 = pt_task.interface.outputs["out1"]
    assert out1.description == ""
    assert out1.type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    assert pt_task.type == _common_constants.SdkTaskType.PYTORCH_TASK
    assert pt_task.task_function_name == "simple_pytorch_task"
    assert pt_task.task_module == __name__

    meta = pt_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ""
    assert meta.discoverable is False
    assert meta.discovery_version == ""
    assert meta.retries.retries == 0

    resources = pt_task.container.resources
    assert len(resources.limits) == 0
    assert len(resources.requests) == 0

    assert pt_task.custom["workers"] == 1
    # Should strip out the venv component of the args.
    container_args = pt_task._get_container_definition().args
    assert container_args[0] == "pyflyte-execute"

    idl_task = pt_task.to_flyte_idl()
    assert idl_task.custom["workers"] == 1
def test_simple_hpo_job_task():
    """Check that the HPO task wraps its training job correctly.

    The inputs of the underlying ``SdkTrainingJobTask`` must be embedded in the
    HPO task's interface, next to the HPO-specific
    ``hyperparameter_tuning_job_config`` input. The training job's custom spec
    must be nested under ``custom["trainingJob"]``.
    """
    hpo_task = simple_xgboost_hpo_job_task
    assert isinstance(hpo_task, SdkSimpleHyperparameterTuningJobTask)
    assert isinstance(hpo_task, _sdk_task.SdkTask)

    inputs = hpo_task.interface.inputs
    multipart_csv_type = _idl_types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    ))
    # Checking if the input of the underlying SdkTrainingJobTask has been embedded
    for input_name in ("train", "validation"):
        assert inputs[input_name].description == ""
        assert (inputs[input_name].type ==
                _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
        assert inputs[input_name].type == multipart_csv_type
    assert inputs["static_hyperparameters"].description == ""
    assert (inputs["static_hyperparameters"].type ==
            _sdk_types.Types.Generic.to_flyte_literal_type())

    # Checking if the hpo-specific input is defined
    hpo_config_input = inputs["hyperparameter_tuning_job_config"]
    assert hpo_config_input.description == ""
    assert (hpo_config_input.type ==
            _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type())

    model_output = hpo_task.interface.outputs["model"]
    assert model_output.description == ""
    assert model_output.type == _sdk_types.Types.Blob.to_flyte_literal_type()
    assert hpo_task.type == _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK

    # Checking if the spec of the TrainingJob is embedded into the custom field
    # of this SdkSimpleHyperparameterTuningJobTask
    assert hpo_task.to_flyte_idl().custom["trainingJob"] == (
        builtin_algorithm_training_job_task2.to_flyte_idl().custom)

    metadata = hpo_task.metadata
    assert metadata.timeout == _datetime.timedelta(seconds=0)
    assert metadata.discoverable is True
    assert metadata.discovery_version == "1"
    assert metadata.retries.retries == 2
    assert metadata.deprecated_error_message == ""

    algorithm_spec = hpo_task.custom["trainingJob"]["algorithmSpecification"]
    assert "metricDefinitions" in algorithm_spec
    assert len(algorithm_spec["metricDefinitions"]) == 1
예제 #5
0
def test_guessing_containers():
    """Collection and map literal types are guessed back as typing.List / typing.Dict."""
    bool_lt = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
    list_lt = model_types.LiteralType(collection_type=bool_lt)
    assert TypeEngine.guess_python_type(list_lt) == typing.List[bool]

    duration_lt = model_types.LiteralType(simple=model_types.SimpleType.DURATION)
    map_lt = model_types.LiteralType(map_value_type=duration_lt)
    assert TypeEngine.guess_python_type(map_lt) == typing.Dict[str, timedelta]
예제 #6
0
 def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType:
     """Translate a python column type into a Flyte ``LiteralType``.

     Types listed in ``self._SUPPORTED_TYPES`` are looked up directly.
     A generic whose ``__origin__`` is ``list`` becomes a collection of its
     element's literal type. A generic whose ``__origin__`` is ``dict``
     becomes a map of its value's literal type; the key type is not inspected.
     Both cases are resolved recursively.

     :raises AssertionError: if ``t`` matches none of the above.
     """
     supported = self._SUPPORTED_TYPES
     if t in supported:
         return supported[t]

     origin = getattr(t, "__origin__", None)
     if origin == list:
         element_type = t.__args__[0]
         return type_models.LiteralType(
             collection_type=self._get_dataset_column_literal_type(element_type))
     if origin == dict:
         value_type = t.__args__[1]
         return type_models.LiteralType(
             map_value_type=self._get_dataset_column_literal_type(value_type))

     raise AssertionError(
         f"type {t} is currently not supported by StructuredDataset")
예제 #7
0
def test_default_python_task():
    """The default python task exposes the expected interface and metadata."""
    assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)

    in1 = default_task.interface.inputs['in1']
    assert in1.description == ''
    assert in1.type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    out1 = default_task.interface.outputs['out1']
    assert out1.description == ''
    assert out1.type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK
    assert default_task.task_function_name == 'default_task'
    assert default_task.task_module == __name__

    meta = default_task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ''
    assert meta.discoverable is False
    assert meta.discovery_version == ''
    assert meta.retries.retries == 0
예제 #8
0
def test_workflow_template():
    """A WorkflowTemplate with a single task node survives an IDL round trip."""
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    wf_inputs = {"a": _interface.Variable(int_type, "description1")}
    wf_outputs = {
        "b": _interface.Variable(int_type, "description2"),
        "c": _interface.Variable(int_type, "description3"),
    }

    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(wf_inputs, wf_outputs),
        nodes=[node],
        outputs=[],
    )
    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
예제 #9
0
def test_workflow_template_with_queuing_budget():
    """A WorkflowTemplate whose metadata sets a queuing budget survives an IDL round trip."""
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    wf_inputs = {'a': _interface.Variable(int_type, "description1")}
    wf_outputs = {
        'b': _interface.Variable(int_type, "description2"),
        'c': _interface.Variable(int_type, "description3"),
    }

    task_node = _workflow.TaskNode(reference_id=_generic_id)
    node = _workflow.Node(
        id='some:node:id',
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    budgeted_metadata = _workflow.WorkflowMetadata(queuing_budget=timedelta(seconds=10))
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=budgeted_metadata,
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(wf_inputs, wf_outputs),
        nodes=[node],
        outputs=[],
    )
    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
예제 #10
0
 def to_flyte_literal_type(cls):
     """
     Build a blob literal type with an empty format string and SINGLE dimensionality.

     :rtype: flytekit.models.types.LiteralType
     """
     single_blob = _core_types.BlobType(
         format="",
         dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
     )
     return _idl_types.LiteralType(blob=single_blob)
예제 #11
0
def test_literal_collections(literal_type):
    """A collection LiteralType wraps ``literal_type`` and round-trips through IDL."""
    collection_lt = _types.LiteralType(collection_type=literal_type)
    assert collection_lt.collection_type == literal_type
    assert collection_lt.simple is None
    assert collection_lt.schema is None
    assert collection_lt.map_value_type is None

    restored = _types.LiteralType.from_flyte_idl(collection_lt.to_flyte_idl())
    assert collection_lt == restored
예제 #12
0
def test_variable_map():
    """A VariableMap holding one boolean variable round-trips through IDL."""
    bool_type = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    bool_var = interface.Variable(bool_type, 'asdf asdf asdf')
    original = interface.VariableMap({'vvv': bool_var})

    restored = interface.VariableMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
예제 #13
0
 def to_flyte_literal_type(cls):
     """
     Build a BINARY literal type whose metadata stores ``cls.descriptor``
     under the ``cls.PB_FIELD_KEY`` key.

     :rtype: flytekit.models.types.LiteralType
     """
     type_metadata = {cls.PB_FIELD_KEY: cls.descriptor}
     return _idl_types.LiteralType(simple=_idl_types.SimpleType.BINARY, metadata=type_metadata)
예제 #14
0
def test_launch_plan_spec():
    """LaunchPlanSpec round-trips through IDL with and without a raw output prefix."""
    identifier_model = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    parameter_map = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    one = literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    auth_role_model = common.AuthRole(assumable_iam_role="my:iam:role")
    max_parallelism = 100

    # Build one spec with a raw output prefix and one with an empty prefix.
    for raw_output_prefix in ("s3://bucket", ""):
        spec = launch_plan.LaunchPlanSpec(
            identifier_model,
            launch_plan_metadata_model,
            parameter_map,
            fixed_inputs,
            labels_model,
            annotations_model,
            auth_role_model,
            common.RawOutputDataConfig(raw_output_prefix),
            max_parallelism,
        )
        round_tripped = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert round_tripped == spec
예제 #15
0
def test_annotated_literal_types():
    """A LiteralType carrying a TypeAnnotation keeps it across an IDL round trip."""
    annotation = TypeAnnotation(annotations={"foo": "bar"})
    annotated_lt = _types.LiteralType(simple=_types.SimpleType.INTEGER, annotation=annotation)

    assert annotated_lt.simple == _types.SimpleType.INTEGER
    assert annotated_lt.schema is None
    assert annotated_lt.collection_type is None
    assert annotated_lt.map_value_type is None
    assert annotated_lt.annotation.annotations == {"foo": "bar"}

    restored = _types.LiteralType.from_flyte_idl(annotated_lt.to_flyte_idl())
    assert annotated_lt == restored
예제 #16
0
def test_unloadable_proto_from_literal_type():
    """A BINARY literal type naming a proto class that does not exist raises FlyteAssertion."""
    with pytest.raises(_user_exceptions.FlyteAssertion):
        engine = _flyte_engine.FlyteDefaultTypeEngine()
        missing_proto_type = _type_models.LiteralType(
            simple=_type_models.SimpleType.BINARY,
            metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerErrorNoExist"},
        )
        engine.get_sdk_type_from_literal_type(missing_proto_type)
예제 #17
0
def test_parameter_map():
    """A ParameterMap holding one boolean parameter round-trips through IDL."""
    bool_type = types.LiteralType(simple=types.SimpleType.BOOLEAN)
    param = interface.Parameter(var=interface.Variable(bool_type, 'asdf asdf asdf'))

    original = interface.ParameterMap({'ppp': param})
    restored = interface.ParameterMap.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
예제 #18
0
def test_guessing_basic():
    """Each simple literal type is guessed back as the matching python type."""
    simple = model_types.SimpleType
    expected_by_simple_type = [
        (simple.BOOLEAN, bool),
        (simple.INTEGER, int),
        (simple.STRING, str),
        (simple.DURATION, timedelta),
        (simple.DATETIME, datetime.datetime),
        (simple.FLOAT, float),
        (simple.NONE, None),
    ]
    for simple_type, expected in expected_by_simple_type:
        guessed = TypeEngine.guess_python_type(model_types.LiteralType(simple=simple_type))
        assert guessed is expected, f"{simple_type} guessed as {guessed}"
예제 #19
0
 def get_literal_type(self, t: Type[T]) -> LiteralType:
     """
     Only univariate Lists are supported in Flyte

     Converts a list type into a collection ``LiteralType``. The element
     type comes from ``self.get_sub_type(t)`` and is converted with
     ``TypeEngine.to_literal_type``.

     :raises ValueError: if the element type cannot be converted. The
         original exception is chained as ``__cause__``.
     """
     try:
         sub_type = TypeEngine.to_literal_type(self.get_sub_type(t))
         return _type_models.LiteralType(collection_type=sub_type)
     except Exception as e:
         # Chain the cause so the underlying conversion traceback is preserved.
         raise ValueError(f"Type of Generic List type is not supported, {e}") from e
예제 #20
0
def test_parameter():
    """A Parameter exposes its variable and round-trips through IDL."""
    variable = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    param = interface.Parameter(var=variable)
    assert param.var == variable

    restored = interface.Parameter.from_flyte_idl(param.to_flyte_idl())
    assert param == restored
    assert restored.var == variable
예제 #21
0
def test_task_template__k8s_pod_target():
    """A TaskTemplate carrying a K8sPod target exposes its fields and round-trips through IDL."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    task_metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    typed_interface = interface_models.TypedInterface(
        {"a": interface_models.Variable(int_type, "description1")},  # inputs
        {  # outputs
            "b": interface_models.Variable(int_type, "description2"),
            "c": interface_models.Variable(int_type, "description3"),
        },
    )
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
        pod_spec={"str": "val", "int": 1},
    )
    template = task.TaskTemplate(
        task_id,
        "python",
        task_metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=pod,
    )

    template_id = template.id
    assert template_id.resource_type == identifier.ResourceType.TASK
    assert template_id.project == "project"
    assert template_id.domain == "domain"
    assert template_id.name == "name"
    assert template_id.version == "version"
    assert template.type == "python"
    assert template.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert template.k8s_pod.metadata == task.K8sObjectMetadata(
        labels={"label": "foo"}, annotations={"anno": "bar"})
    assert template.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # Compare textual protobuf renderings of the IDL before and after a round trip.
    idl_text = text_format.MessageToString(template.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(template.to_flyte_idl())
    assert idl_text == text_format.MessageToString(round_tripped.to_flyte_idl())
    assert template.config == {"a": "b"}
예제 #22
0
def test_generic_proto_from_literal_type():
    """A STRUCT literal type naming a proto class resolves to an SDK type for that class."""
    engine = _flyte_engine.FlyteDefaultTypeEngine()
    proto_literal_type = _type_models.LiteralType(
        simple=_type_models.SimpleType.STRUCT,
        metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError"},
    )
    resolved = engine.get_sdk_type_from_literal_type(proto_literal_type)
    assert resolved.pb_type == _errors_pb2.ContainerError
예제 #23
0
 def get_literal_type(self, t: Type[dict]) -> LiteralType:
     """
     Map a python dict type to a Flyte ``LiteralType``.

     When ``get_dict_types`` reports a ``str`` key type, the result is a map
     literal type whose values use the literal type of the value type.
     Otherwise the ``Generic`` literal type is returned. That covers a falsy
     result from ``get_dict_types`` as well as non-``str`` keys.

     :raises ValueError: if the value type cannot be converted. The original
         exception is chained as ``__cause__``.
     """
     tp = self.get_dict_types(t)
     if tp and tp[0] == str:
         try:
             sub_type = TypeEngine.to_literal_type(tp[1])
             return _type_models.LiteralType(map_value_type=sub_type)
         except Exception as e:
             raise ValueError(f"Type of Generic Dict type is not supported, {e}") from e
     return _primitives.Generic.to_flyte_literal_type()
예제 #24
0
def test_basic_workflow_promote(mock_task_fetch):
    """Promoting the stored workflow template reproduces the locally declared workflow.

    ``mock_task_fetch`` stands in for retrieving the task from Admin. Its
    return value is set to a task promoted from a hand-built TaskTemplate.
    """
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer)

    # Rebuild the TaskTemplate as stored in Admin so the workflow can be promoted from it.
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},  # inputs
        {  # outputs
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3"),
        },
    )
    task_id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        "project",
        "domain",
        "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
        "version",
    )
    task_template = _task_model.TaskTemplate(
        task_id,
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container(),
    )
    # Promoting a workflow fetches its tasks from Admin; serve the promoted task instead.
    mock_task_fetch.return_value = _task.SdkTask.promote_from_model(task_template)

    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(get_workflow_template())

    declared_interface = TestPromoteExampleWf.interface
    assert promoted_wf.interface.inputs["wf_input"] == declared_interface.inputs["wf_input"]
    for output_name in ("wf_output_b", "wf_output_c"):
        assert promoted_wf.interface.outputs[output_name] == declared_interface.outputs[output_name]

    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[0].inputs[0]
예제 #25
0
def test_construct_literal_map_from_variable_map():
    """A text input is parsed into an integer scalar according to its variable's type."""
    int_variable = Variable(
        type=types.LiteralType(simple=types.SimpleType.INTEGER),
        description="some description",
    )
    literal_map = helpers.construct_literal_map_from_variable_map(
        {'inputa': int_variable}, {'inputa': '15'})

    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals['inputa'].value == expected
예제 #26
0
def test_construct_literal_map_from_parameter_map():
    """A required parameter is parsed from text, and omitting it raises."""
    int_variable = Variable(
        type=types.LiteralType(simple=types.SimpleType.INTEGER),
        description="some description",
    )
    param_map = ParameterMap(parameters={"inputa": Parameter(var=int_variable, required=True)})

    literal_map = helpers.construct_literal_map_from_parameter_map(param_map, {"inputa": "15"})
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert literal_map.literals["inputa"].value == expected

    # The parameter is required, so an empty input dictionary must fail.
    with pytest.raises(Exception):
        helpers.construct_literal_map_from_parameter_map(param_map, {})
예제 #27
0
def test_lp_closure():
    """A LaunchPlanClosure keeps its expected inputs/outputs through an IDL round trip."""
    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf')
    parameter_map = interface.ParameterMap({'ppp': interface.Parameter(var=bool_var)})
    # Serialization smoke check; the result is intentionally discarded.
    parameter_map.to_flyte_idl()
    variable_map = interface.VariableMap({'vvv': bool_var})

    closure = launch_plan.LaunchPlanClosure(
        state=launch_plan.LaunchPlanState.ACTIVE,
        expected_inputs=parameter_map,
        expected_outputs=variable_map,
    )
    assert closure.expected_inputs == parameter_map
    assert closure.expected_outputs == variable_map

    restored = launch_plan.LaunchPlanClosure.from_flyte_idl(closure.to_flyte_idl())
    assert closure == restored
    assert restored.expected_inputs == parameter_map
    assert restored.expected_outputs == variable_map
예제 #28
0
def test_old_style_role():
    """A spec that only sets the deprecated ``auth`` field still yields an ``auth_role``.

    The assertion pins the current mapping: the deprecated
    ``auth.kubernetes_service_account`` value is surfaced as
    ``auth_role.assumable_iam_role``.
    """
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    lp_metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    one = literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    # Build the IDL message directly: the deprecated ``auth`` field is set
    # and the newer ``auth_role`` field is not.
    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=workflow_id.to_flyte_idl(),
        entity_metadata=lp_metadata.to_flyte_idl(),
        default_inputs=default_inputs.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=common.Labels({}).to_flyte_idl(),
        annotations=common.Annotations({"my": "annotation"}).to_flyte_idl(),
        raw_output_data_config=common.RawOutputDataConfig("s3://bucket").to_flyte_idl(),
        auth=_launch_plan_idl.Auth(kubernetes_service_account="my:service:account"),
    )

    lp_spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)
    assert lp_spec.auth_role.assumable_iam_role == "my:service:account"
예제 #29
0
def test_structured_dataset():
    """A Scalar wrapping a StructuredDataset exposes it and survives an IDL round trip.

    The same field checks run on the original scalar and on the round-tripped
    one, so a field that changes during serialization is caught on either side.
    """
    my_cols = [
        _types.StructuredDatasetType.DatasetColumn(
            "a", _types.LiteralType(simple=_types.SimpleType.INTEGER)),
        _types.StructuredDatasetType.DatasetColumn(
            "b", _types.LiteralType(simple=_types.SimpleType.STRING)),
        _types.StructuredDatasetType.DatasetColumn(
            "c",
            _types.LiteralType(collection_type=_types.LiteralType(
                simple=_types.SimpleType.INTEGER))),
        _types.StructuredDatasetType.DatasetColumn(
            "d",
            _types.LiteralType(map_value_type=_types.LiteralType(
                simple=_types.SimpleType.INTEGER))),
    ]
    ds = literals.StructuredDataset(
        uri="s3://bucket",
        metadata=literals.StructuredDatasetMetadata(
            structured_dataset_type=_types.StructuredDatasetType(
                columns=my_cols, format="parquet")),
    )
    obj = literals.Scalar(structured_dataset=ds)
    obj2 = literals.Scalar.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2

    for scalar in (obj, obj2):
        # Only the structured_dataset variant should be populated.
        assert scalar.error is None
        assert scalar.blob is None
        assert scalar.binary is None
        assert scalar.schema is None
        assert scalar.none_type is None
        assert scalar.structured_dataset is not None
        assert scalar.value.uri == "s3://bucket"
        assert len(scalar.value.metadata.structured_dataset_type.columns) == len(my_cols)
예제 #30
0
 def to_flyte_literal_type(cls):
     """
     Wrap this class's ``schema_type`` in a schema ``LiteralType``.

     :rtype: flytekit.models.types.LiteralType
     """
     schema_type = cls.schema_type
     return _idl_types.LiteralType(schema=schema_type)