def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    """Validate the call args before they get passed to the run method of the Work.

    Currently, this performs a check against strings that look like filesystem paths and may need to be wrapped with
    a Lightning Path by the user.
    """

    def warn_if_pathlike(obj: Union[os.PathLike, str]):
        if isinstance(obj, Path):
            return
        if os.sep in str(obj) and os.path.exists(obj):
            # NOTE: The existence check is wrong in general, as the file will never exist on the disk
            # where the flow is running unless we are running locally
            warnings.warn(
                f"You passed the value {obj!r} as an argument to the `run()` method of {self.work_name} and"
                f" it looks like this is a path to a file or a folder. Consider wrapping this path in a"
                f" `lightning_app.storage.Path` object to be able to access these files in your Work.",
                UserWarning,
            )

    apply_to_collection(args, dtype=(os.PathLike, str), function=warn_if_pathlike)
    apply_to_collection(kwargs, dtype=(os.PathLike, str), function=warn_if_pathlike)
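# A minimal, self-contained sketch of the validation pattern above. It assumes
# `apply_to_collection` is importable from `lightning_utilities.core.apply_func`
# (the import path varies across Lightning versions); `example_args` is made up.
import os
import warnings

from lightning_utilities.core.apply_func import apply_to_collection

def _warn_if_pathlike(obj):
    # Same heuristic as above: warn on strings that contain a path separator
    # and point to something that exists on the local filesystem.
    if os.sep in str(obj) and os.path.exists(obj):
        warnings.warn(f"{obj!r} looks like a filesystem path", UserWarning)

example_args = (os.getcwd(), 42, {"checkpoint": os.getcwd()})
# apply_to_collection recurses into the tuple and the nested dict, calling the
# function on every str/PathLike leaf; the return value is discarded here.
apply_to_collection(example_args, dtype=(os.PathLike, str), function=_warn_if_pathlike)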
def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Utility function to sanitize the state of a component.

    Sanitization enables the state to be deep-copied and hashed.
    """
    from lightning_app.storage import Drive, Path
    from lightning_app.storage.payload import BasePayload

    def sanitize_path(path: Path) -> Path:
        path_copy = Path(path)
        path_copy._sanitize()
        return path_copy

    def sanitize_payload(payload: BasePayload) -> BasePayload:
        return type(payload).from_dict(content=payload.to_dict())

    def sanitize_drive(drive: Drive) -> Dict:
        return drive.to_dict()

    state = apply_to_collection(state, dtype=Path, function=sanitize_path)
    state = apply_to_collection(state, dtype=BasePayload, function=sanitize_payload)
    state = apply_to_collection(state, dtype=Drive, function=sanitize_drive)
    return state
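# A hedged illustration of the per-dtype dispatch used above, with plain
# stand-in types (set, float) instead of Lightning's Path/Payload/Drive;
# import path assumed as before.
import copy

from lightning_utilities.core.apply_func import apply_to_collection

state = {"lr": 0.1, "tags": {"run", "v1"}, "steps": [1, 2]}
# Chain one apply_to_collection call per dtype, as _sanitize_state does:
state = apply_to_collection(state, dtype=set, function=sorted)  # sets -> sorted lists
state = apply_to_collection(state, dtype=float, function=str)   # floats -> strings
# The sanitized result deep-copies cleanly and contains no unhashable leaves.
assert copy.deepcopy(state) == {"lr": "0.1", "tags": ["run", "v1"], "steps": [1, 2]}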
def test_apply_to_collection_frozen_dataclass():
    @dataclasses.dataclass(frozen=True)
    class Foo:
        input: torch.Tensor

    foo = Foo(torch.tensor(0))

    with pytest.raises(MisconfigurationException, match="frozen dataclass was passed"):
        apply_to_collection(foo, torch.Tensor, lambda t: t.to(torch.int))
def _state_to_json(state: Dict[str, Any]) -> Dict[str, Any]:
    """Utility function to make sure that state dict is json serializable."""
    from lightning_app.storage import Path
    from lightning_app.storage.payload import BasePayload

    state_paths_cleaned = apply_to_collection(state, dtype=(Path, BasePayload), function=lambda x: x.to_dict())
    state_diff_cleaned = apply_to_collection(state_paths_cleaned, dtype=type(NotPresent), function=lambda x: None)
    return state_diff_cleaned
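# A self-contained sketch of the same cleanup, using a hypothetical _NotPresent
# sentinel in place of Lightning's NotPresent; import path assumed as before.
import json

from lightning_utilities.core.apply_func import apply_to_collection

class _NotPresent:
    """Stand-in for a sentinel marking values that were deleted from the state."""

state = {"x": 1, "gone": _NotPresent(), "nested": [_NotPresent(), 2]}
# Replace every sentinel (at any nesting depth) with None so the result is
# JSON serializable; by default apply_to_collection keeps None results.
cleaned = apply_to_collection(state, dtype=_NotPresent, function=lambda x: None)
assert json.dumps(cleaned) == '{"x": 1, "gone": null, "nested": [null, 2]}'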
def test_apply_to_collection_include_none():
    to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]

    def fn(x):
        if isinstance(x, float):
            return x

    reduced = apply_to_collection(to_reduce, (int, float), fn)
    assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]

    reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
    assert reduced == [3.4, 5.6, (9.1, {})]
def _process_call_args(
    args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the
    LightningWork.

    Currently, this method only applies sanitization to Lightning Path and Drive objects.

    Args:
        args: The tuple of positional arguments passed to the run method.
        kwargs: The dictionary of named arguments passed to the run method.

    Returns:
        The positional and keyword arguments in the same order they were passed in.
    """

    def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]:
        if isinstance(obj, Path):
            # create a copy of the Path and erase the consumer
            # the LightningWork on the receiving end of the caller queue will become the new consumer
            # this is necessary to make the Path deepdiff-hashable
            path_copy = Path(obj)
            path_copy._sanitize()
            path_copy._consumer = None
            return path_copy
        return obj.to_dict()

    return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize)
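# Hedged sketch: wrapping args and kwargs in a single (args, kwargs) tuple, as
# done above, lets one apply_to_collection call process both in one pass
# (stand-in str/upper transformation; import path assumed as before).
from lightning_utilities.core.apply_func import apply_to_collection

args, kwargs = ("a", 1), {"key": "b"}
args, kwargs = apply_to_collection((args, kwargs), dtype=str, function=str.upper)
assert args == ("A", 1)
assert kwargs == {"key": "B"}  # mapping keys are preserved; only values are transformed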
def _process_call_args(
    self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Process the arguments that were passed in to the ``run()`` method of the
    :class:`lightning_app.core.work.LightningWork`.

    This method currently only implements special treatment for the
    :class:`lightning_app.storage.path.Path` objects. Any Path objects that get passed into the run method get
    attached to the Work automatically, i.e., the Work becomes the `origin` or the `consumer` if it was not
    already. Additionally, if the file or folder under the Path exists, we transfer it.

    Args:
        args: The tuple of positional arguments passed to the run method.
        kwargs: The dictionary of named arguments passed to the run method.

    Returns:
        The positional and keyword arguments in the same order they were passed in.
    """

    def _attach_work_and_get(transporter: Union[Path, Payload, dict]) -> Union[Path, Drive, dict, Any]:
        if not transporter.origin_name:
            # If the path/payload is not attached to an origin, there is no need to attach or transfer anything
            return transporter

        transporter._attach_work(self.work)
        transporter._attach_queues(self.work._request_queue, self.work._response_queue)
        if transporter.exists_remote():
            # All paths/payloads passed to the `run` method need to be copied (if they exist)
            if isinstance(transporter, Payload):
                transporter.get()
            else:
                transporter.get(overwrite=True)
        return transporter

    def _handle_drive(drive_dict: dict):
        return _maybe_create_drive(self.work_name, drive_dict)

    args, kwargs = apply_to_collection((args, kwargs), dtype=(Path, Payload), function=_attach_work_and_get)
    return apply_to_collection((args, kwargs), dtype=dict, function=_handle_drive)
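# Hedged sketch of the dtype-match precedence that _handle_drive relies on:
# when a value itself matches `dtype`, the function receives the whole value
# before any recursion into it, so each dict can be inspected and selectively
# rebuilt (stand-in for _maybe_create_drive; import path assumed as before).
from lightning_utilities.core.apply_func import apply_to_collection

def _maybe_rebuild(d: dict):
    # Rebuild only dicts "tagged" as drives; return everything else unchanged.
    return "REBUILT-DRIVE" if d.get("type") == "drive" else d

call_args = ({"type": "drive"}, {"plain": 1})
out = apply_to_collection(call_args, dtype=dict, function=_maybe_rebuild)
assert out == ("REBUILT-DRIVE", {"plain": 1})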
def test_recursive_application_to_collection():
    ntc = namedtuple("Foo", ["bar"])

    @dataclasses.dataclass
    class Feature:
        input_ids: torch.Tensor
        segment_ids: np.ndarray

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, Feature):
                return NotImplemented
            return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()

    @dataclasses.dataclass
    class ModelExample:
        example_ids: List[str]
        feature: Feature
        label: torch.Tensor
        some_constant: int = dataclasses.field(init=False)

        def __post_init__(self):
            self.some_constant = 7

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, ModelExample):
                return NotImplemented
            return (
                self.example_ids == o.example_ids
                and self.feature == o.feature
                and torch.equal(self.label, o.label)
                and self.some_constant == o.some_constant
            )

    @dataclasses.dataclass
    class WithClassVar:
        class_var: ClassVar[int] = 0
        dummy: Any

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithClassVar):
                return NotImplemented
            if isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            return self.dummy == o.dummy

    @dataclasses.dataclass
    class WithInitVar:
        dummy: Any
        override: InitVar[Optional[Any]] = None

        def __post_init__(self, override: Optional[Any]):
            if override is not None:
                self.dummy = override

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithInitVar):
                return NotImplemented
            if isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            return self.dummy == o.dummy

    @dataclasses.dataclass
    class WithClassAndInitVar:
        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
        dummy: Any
        override: InitVar[Optional[Any]] = torch.tensor(1)

        def __post_init__(self, override: Optional[Any]):
            if override is not None:
                self.dummy = override

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithClassAndInitVar):
                return NotImplemented
            if isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            return self.dummy == o.dummy

    model_example = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
        label=torch.tensor([7.0, 8.0, 9.0]),
    )

    to_reduce = {
        "a": torch.tensor([1.0]),  # Tensor
        "b": [torch.tensor([2.0])],  # list
        "c": (torch.tensor([100.0]),),  # tuple
        "d": ntc(bar=5.0),  # named tuple
        "e": np.array([10.0]),  # numpy array
        "f": "this_is_a_dummy_str",  # string
        "g": 12.0,  # number
        "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
        "i": model_example,  # nested dataclass
        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
    }

    model_example_result = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
        label=torch.tensor([14.0, 16.0, 18.0]),
    )

    expected_result = {
        "a": torch.tensor([2.0]),
        "b": [torch.tensor([4.0])],
        "c": (torch.tensor([200.0]),),
        "d": ntc(bar=torch.tensor([10.0])),
        "e": np.array([20.0]),
        "f": "this_is_a_dummy_str",
        "g": 24.0,
        "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
        "i": model_example_result,
        "j": WithClassVar(torch.arange(0, 6, 2)),
        "k": WithInitVar(torch.tensor([4.0])),
        "l": WithClassAndInitVar(model_example_result, None),
    }

    reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
    assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
    assert all(
        isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
    ), "At least one type was not correctly preserved"

    assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"

    assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
    assert all(
        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
    ), "At least one value of list reduction did not come out as expected"

    assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
    assert all(
        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
    ), "At least one value of tuple reduction did not come out as expected"

    assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
    assert isinstance(
        reduced["d"].bar, numbers.Number
    ), "Failure in type promotion while reducing fields of named tuples"
    assert reduced["d"].bar == expected_result["d"].bar

    assert isinstance(reduced["e"], np.ndarray), "Type Promotion in reduction of numpy arrays failed"
    assert reduced["e"] == expected_result["e"], "Reduction of numpy array did not yield the expected result"

    assert isinstance(reduced["f"], str), "A string should not be reduced"
    assert reduced["f"] == expected_result["f"], "String not preserved during reduction"

    assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
    assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
        assert dataclasses.is_dataclass(actual) and not isinstance(
            actual, type
        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
        for field in dataclasses.fields(actual):
            if dataclasses.is_dataclass(field.type):
                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"

    _assert_dataclass_reduction(reduced["h"], expected_result["h"])

    _assert_dataclass_reduction(reduced["i"], expected_result["i"])

    dataclass_type = "ClassVar-containing"
    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"

    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")

    dataclass_type = "Class-and-InitVar-containing"
    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
    assert torch.equal(
        WithClassAndInitVar.class_var, torch.tensor(0)
    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"

    # mapping support
    reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
    assert reduced == {"a": "1", "b": "2"}
    reduced = apply_to_collection(OrderedDict([("b", 2), ("a", 1)]), int, lambda x: str(x))
    assert reduced == OrderedDict([("b", "2"), ("a", "1")])

    # custom mappings
    class _CustomCollection(dict):
        def __init__(self, initial_dict):
            super().__init__(initial_dict)

    to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3})
    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
    assert reduced == _CustomCollection({"a": "1", "b": "2", "c": "3"})

    # defaultdict
    to_reduce = defaultdict(int, {"a": 1, "b": 2, "c": 3})
    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
    assert reduced == defaultdict(int, {"a": "1", "b": "2", "c": "3"})