def _convert_resource_overrides(
    resources: typing.Optional[Resources], resource_name: str
) -> typing.List[_resources_model.ResourceEntry]:
    """Convert a ``Resources`` object into a list of ``ResourceEntry`` models.

    Args:
        resources: The resources to convert, or ``None`` when nothing was given.
        resource_name: Name of the argument being converted; it is only used to
            build the error message.

    Returns:
        One ``ResourceEntry`` per field that is not ``None``. Entries follow a
        fixed order: CPU, memory, GPU, storage, ephemeral storage. An empty list
        is returned when ``resources`` is ``None``.

    Raises:
        AssertionError: If ``resources`` is not a ``Resources`` instance. This
            exception type is kept for compatibility with existing callers.
    """
    if resources is None:
        return []
    if not isinstance(resources, Resources):
        raise AssertionError(f"{resource_name} should be specified as flytekit.Resources")

    # The order of this table is the order of the returned entries. Keep it
    # stable, because serialized lists are compared element-wise.
    fields = (
        (_resources_model.ResourceName.CPU, resources.cpu),
        (_resources_model.ResourceName.MEMORY, resources.mem),
        (_resources_model.ResourceName.GPU, resources.gpu),
        (_resources_model.ResourceName.STORAGE, resources.storage),
        (_resources_model.ResourceName.EPHEMERAL_STORAGE, resources.ephemeral_storage),
    )
    return [
        _resources_model.ResourceEntry(name, value)
        for name, value in fields
        if value is not None
    ]
def test_resource_limits_override():
    """Check how limit overrides on a map-task node are serialized.

    Only limits are passed to ``with_overrides``. The serialized node should
    carry those limits in order and an empty requests list.
    """

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mapped = map_task(t1)
        node_limits = Resources(cpu="2", mem="200", ephemeral_storage="1Gi")
        return mapped(a=a).with_overrides(limits=node_limits)

    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)

    nodes = wf_spec.template.nodes
    assert len(nodes) == 1

    overridden = nodes[0].task_node.overrides.resources
    assert overridden.requests == []

    expected_limits = [
        _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"),
        _resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "1Gi"),
    ]
    assert overridden.limits == expected_limits
def test_resources():
    """Check that task-level resources serialize into the task container.

    ``t1`` declares both requests (cpu=1) and limits (cpu=2, mem=400M).
    ``t2`` declares only requests (cpu=3), so its serialized limits must be
    empty.
    """

    @task(requests=Resources(cpu="1"), limits=Resources(cpu="2", mem="400M"))
    def t1(a: int) -> str:
        a = a + 2
        return "now it's " + str(a)

    @task(requests=Resources(cpu="3"))
    def t2(a: int) -> str:
        a = a + 200
        return "now it's " + str(a)

    # NOTE(review): my_wf is defined but never serialized or asserted on in
    # this test.
    @workflow
    def my_wf(a: int) -> str:
        x = t1(a=a)
        return x

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # Serialize the tasks inside a fresh compilation state. Presumably this
    # isolates them from any ambient compilation context (TODO confirm).
    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_new_compilation_state()
    ):
        task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
        # Entries are compared as ordered lists: CPU before MEMORY.
        assert task_spec.template.container.resources.requests == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "1")
        ]
        assert task_spec.template.container.resources.limits == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "2"),
            _resource_models.ResourceEntry(_resource_models.ResourceName.MEMORY, "400M"),
        ]

        task_spec2 = get_serializable(OrderedDict(), serialization_settings, t2)
        assert task_spec2.template.container.resources.requests == [
            _resource_models.ResourceEntry(_resource_models.ResourceName.CPU, "3")
        ]
        # t2 declared no limits, so none should be serialized.
        assert task_spec2.template.container.resources.limits == []
def test_resource_overrides():
    """Check request and limit overrides on a ``create_node`` map node.

    Both overrides are set through ``with_overrides``. Both must appear on the
    serialized task node, in CPU-then-memory order.
    """

    @task
    def t1(a: str) -> str:
        return f"*~*~*~{a}*~*~*~"

    @workflow
    def my_wf(a: typing.List[str]) -> typing.List[str]:
        mapped = map_task(t1)
        overridden_node = create_node(mapped, a=a).with_overrides(
            requests=Resources(cpu="1", mem="100"),
            limits=Resources(cpu="2", mem="200"),
        )
        return overridden_node.o0

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    wf_spec = get_serializable(OrderedDict(), settings, my_wf)

    nodes = wf_spec.template.nodes
    assert len(nodes) == 1

    node_overrides = nodes[0].task_node.overrides
    assert node_overrides is not None

    cpu = _resources_models.ResourceName.CPU
    memory = _resources_models.ResourceName.MEMORY
    resource_overrides = node_overrides.resources
    assert resource_overrides.requests == [
        _resources_models.ResourceEntry(cpu, "1"),
        _resources_models.ResourceEntry(memory, "100"),
    ]
    assert resource_overrides.limits == [
        _resources_models.ResourceEntry(cpu, "2"),
        _resources_models.ResourceEntry(memory, "200"),
    ]
def from_flyte_idl(cls, pb2_object):
    """Build an instance from its protobuf representation.

    The ``resources`` field of ``pb2_object`` is decoded with
    ``Resources.from_flyte_idl``. If the decoded resources have neither
    requests nor limits, there is no resource override, and the instance gets
    ``resources=None``.
    """
    decoded = Resources.from_flyte_idl(pb2_object.resources)
    has_any_entries = bool(decoded.requests) or bool(decoded.limits)
    return cls(resources=decoded if has_any_entries else None)