Beispiel #1
0
def _convert_resource_overrides(
    resources: typing.Optional[Resources], resource_name: str
) -> typing.List[_resources_model.ResourceEntry]:
    """Translate a ``Resources`` spec into a list of ``ResourceEntry`` models.

    Only fields that are set (not ``None``) on ``resources`` produce an entry.
    Entries are emitted in a fixed order: CPU, memory, GPU, storage,
    ephemeral storage.

    :param resources: The resource spec to convert, or ``None``.
    :param resource_name: Label interpolated into the error message when
        ``resources`` has the wrong type.
    :returns: A new list of ``ResourceEntry`` objects; empty when ``resources``
        is ``None`` or none of its fields are set.
    :raises AssertionError: If ``resources`` is not a ``Resources`` instance.
    """
    if resources is None:
        return []
    if not isinstance(resources, Resources):
        # AssertionError (rather than TypeError) is kept for backward
        # compatibility with callers that may expect this exception type.
        raise AssertionError(f"{resource_name} should be specified as flytekit.Resources")

    # (model resource name, attribute value) pairs, in the order entries are emitted.
    candidates = (
        (_resources_model.ResourceName.CPU, resources.cpu),
        (_resources_model.ResourceName.MEMORY, resources.mem),
        (_resources_model.ResourceName.GPU, resources.gpu),
        (_resources_model.ResourceName.STORAGE, resources.storage),
        (_resources_model.ResourceName.EPHEMERAL_STORAGE, resources.ephemeral_storage),
    )
    return [
        _resources_model.ResourceEntry(name, value)
        for name, value in candidates
        if value is not None
    ]
Beispiel #2
0
def test_resource_limits_override():
    """``with_overrides(limits=...)`` on a map-task node is serialized as limit
    entries on the node's task override, while requests stay empty."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mapped_t1 = map_task(t1)
        node = mapped_t1(a=a)
        return node.with_overrides(limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi"))

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, my_wf)

    nodes = spec.template.nodes
    assert len(nodes) == 1

    override_resources = nodes[0].task_node.overrides.resources
    assert override_resources.requests == []

    entry = _resources_models.ResourceEntry
    names = _resources_models.ResourceName
    expected_limits = [
        entry(names.CPU, "2"),
        entry(names.MEMORY, "200"),
        entry(names.EPHEMERAL_STORAGE, "1Gi"),
    ]
    assert override_resources.limits == expected_limits
Beispiel #3
0
def test_resources():
    """Task-level ``requests``/``limits`` are serialized into the container's
    resource entries; a task without limits serializes them as an empty list."""

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        return "now it's " + str(a + 2)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        return "now it's " + str(a + 200)

    @workflow
    def my_wf(a: int) -> str:
        return t1(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )

    manager = context_manager.FlyteContextManager
    compile_ctx = manager.current_context().with_new_compilation_state()
    with manager.with_context(compile_ctx):
        entry = _resource_models.ResourceEntry
        names = _resource_models.ResourceName

        first = get_serializable(OrderedDict(), settings, t1).template.container.resources
        assert first.requests == [entry(names.CPU, "1")]
        assert first.limits == [entry(names.CPU, "2"), entry(names.MEMORY, "400M")]

        second = get_serializable(OrderedDict(), settings, t2).template.container.resources
        assert second.requests == [entry(names.CPU, "3")]
        assert second.limits == []


def test_resource_overrides():
    """Both ``requests`` and ``limits`` passed to ``with_overrides`` on a
    ``create_node`` map-task node end up in the node's task override."""

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mapped_t1 = map_task(t1)
        node = create_node(mapped_t1, a=a)
        node = node.with_overrides(
            requests=Resources(cpu="1", mem="100"),
            limits=Resources(cpu="2", mem="200"),
        )
        return node.o0

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    spec = get_serializable(OrderedDict(), settings, my_wf)

    nodes = spec.template.nodes
    assert len(nodes) == 1

    overrides = nodes[0].task_node.overrides
    assert overrides is not None

    entry = _resources_models.ResourceEntry
    names = _resources_models.ResourceName
    assert overrides.resources.requests == [entry(names.CPU, "1"), entry(names.MEMORY, "100")]
    assert overrides.resources.limits == [entry(names.CPU, "2"), entry(names.MEMORY, "200")]
Beispiel #5
0
 def from_flyte_idl(cls, pb2_object):
     """Build an instance of ``cls`` from its protobuf representation.

     :param pb2_object: Protobuf message whose ``resources`` field is parsed
         with ``Resources.from_flyte_idl``.
     :returns: ``cls(resources=<parsed Resources>)`` when the parsed value has
         any requests or limits, otherwise ``cls(resources=None)``.
     """
     parsed = Resources.from_flyte_idl(pb2_object.resources)
     # Collapse an override with neither requests nor limits to "no override".
     has_entries = bool(parsed.requests or parsed.limits)
     return cls(resources=parsed if has_entries else None)