예제 #1
0
    def _validate_call_args(self, args: Tuple[Any, ...],
                            kwargs: Dict[str, Any]) -> None:
        """Validate the call args before they get passed to the run method of the Work.

        Currently, this performs a check against strings that look like filesystem paths and may need to be wrapped
        with a Lightning Path by the user. Every ``str`` / ``os.PathLike`` found inside ``args`` and ``kwargs`` is
        inspected; a ``UserWarning`` is emitted when its text contains a path separator and it exists on the local
        filesystem. Values that already are a ``Path`` are skipped.

        Args:
            args: The tuple of positional arguments passed to the run method.
            kwargs: The dictionary of named arguments passed to the run method.
        """
        # ``os.altsep`` is ``None`` on POSIX and ``"/"`` on Windows, where both separators are valid in a path.
        separators = tuple(sep for sep in (os.sep, os.altsep) if sep)

        def warn_if_pathlike(obj: Union[os.PathLike, str]) -> None:
            if isinstance(obj, Path):
                # Already wrapped by the user, nothing to suggest.
                return
            text = str(obj)
            if any(sep in text for sep in separators) and os.path.exists(obj):
                # NOTE: The existence check is wrong in general, as the file will never exist on the disk
                # where the flow is running unless we are running locally
                warnings.warn(
                    f"You passed the value {obj!r} as an argument to the `run()` method of {self.work_name} and"
                    f" it looks like this is a path to a file or a folder. Consider wrapping this path in a"
                    f" `lightning_app.storage.Path` object to be able to access these files in your Work.",
                    UserWarning,
                )

        # The return values are discarded: this pass only emits warnings.
        apply_to_collection(args,
                            dtype=(os.PathLike, str),
                            function=warn_if_pathlike)
        apply_to_collection(kwargs,
                            dtype=(os.PathLike, str),
                            function=warn_if_pathlike)
예제 #2
0
def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a sanitized version of a component's state.

    Lightning ``Path`` objects are replaced by sanitized copies, payloads are rebuilt from their own dict form and
    ``Drive`` objects are replaced by their dict representation. This lets the state be deep-copied and hashed.
    """
    from lightning_app.storage import Drive, Path
    from lightning_app.storage.payload import BasePayload

    def _copy_and_sanitize_path(original: Path) -> Path:
        sanitized = Path(original)
        sanitized._sanitize()
        return sanitized

    def _rebuild_payload(payload: BasePayload):
        payload_cls = type(payload)
        return payload_cls.from_dict(content=payload.to_dict())

    def _drive_to_dict(drive: Drive) -> Dict:
        return drive.to_dict()

    # Applied in this exact order: paths first, then payloads, then drives.
    conversions = (
        (Path, _copy_and_sanitize_path),
        (BasePayload, _rebuild_payload),
        (Drive, _drive_to_dict),
    )
    sanitized_state = state
    for target_type, convert in conversions:
        sanitized_state = apply_to_collection(sanitized_state, dtype=target_type, function=convert)
    return sanitized_state
예제 #3
0
def test_apply_to_collection_frozen_dataclass():
    """Rewriting a field of a frozen dataclass through ``apply_to_collection`` must raise a clear error."""
    @dataclasses.dataclass(frozen=True)
    class Foo:
        input: torch.Tensor

    frozen_instance = Foo(input=torch.tensor(0))
    expected_message = "frozen dataclass was passed"

    def _cast_to_int(tensor):
        return tensor.to(torch.int)

    with pytest.raises(MisconfigurationException, match=expected_message):
        apply_to_collection(frozen_instance, torch.Tensor, _cast_to_int)
예제 #4
0
def _state_to_json(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a JSON-serializable version of ``state``.

    Lightning ``Path`` and payload objects are replaced by their ``to_dict()`` output, and every instance of
    ``type(NotPresent)`` is replaced by ``None``.
    """
    from lightning_app.storage import Path
    from lightning_app.storage.payload import BasePayload

    def _to_dict(obj):
        return obj.to_dict()

    def _to_none(_):
        return None

    with_dicts = apply_to_collection(state, dtype=(Path, BasePayload), function=_to_dict)
    return apply_to_collection(with_dicts, dtype=type(NotPresent), function=_to_none)
예제 #5
0
def test_apply_to_collection_include_none():
    """``include_none=False`` drops the ``None`` results instead of keeping them in the output."""
    to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]

    def keep_floats_only(value):
        # Every non-float number is mapped to ``None``.
        return value if isinstance(value, float) else None

    numeric_types = (int, float)

    with_nones = apply_to_collection(to_reduce, numeric_types, keep_floats_only)
    assert with_nones == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]

    without_nones = apply_to_collection(to_reduce, numeric_types, keep_floats_only, include_none=False)
    assert without_nones == [3.4, 5.6, (9.1, {})]
예제 #6
0
    def _process_call_args(
            args: Tuple[Any, ...],
            kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        """Prepare the run-method arguments before they are put on the caller queue and sent to the LightningWork.

        Lightning ``Path`` objects are replaced by sanitized copies without a consumer, and ``Drive`` objects are
        replaced by their ``to_dict()`` representation.

        Args:
            args: The tuple of positional arguments passed to the run method.
            kwargs: The dictionary of named arguments passed to the run method.

        Returns:
            The positional and keyword arguments in the same order they were passed in.
        """
        def _detached_path_copy(original: Path) -> Path:
            # Copy the Path and erase its consumer: the LightningWork on the receiving end of the caller queue
            # becomes the new consumer. Dropping it is what makes the Path deepdiff-hashable.
            detached = Path(original)
            detached._sanitize()
            detached._consumer = None
            return detached

        def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]:
            if not isinstance(obj, Path):
                # Only ``Drive`` can reach this branch given the ``dtype`` filter below.
                return obj.to_dict()
            return _detached_path_copy(obj)

        return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize)
예제 #7
0
    def _process_call_args(
            self, args: Tuple[Any, ...],
            kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        """Process the arguments that were passed in to the ``run()`` method of the
        :class:`lightning_app.core.work.LightningWork`.

        Any :class:`lightning_app.storage.path.Path` or ``Payload`` objects that get passed into the run method and
        that have an origin get attached to the Work automatically, i.e., the Work becomes the `origin` or the
        `consumer` if they were not already before. Additionally, if the file or folder behind them exists remotely,
        we transfer it. Finally, dicts that were produced from a ``Drive`` on the caller side are turned back into a
        ``Drive`` via ``_maybe_create_drive``.

        Args:
            args: The tuple of positional arguments passed to the run method.
            kwargs: The dictionary of named arguments passed to the run method.

        Returns:
            The positional and keyword arguments in the same order they were passed in.
        """
        def _attach_work_and_get(transporter: Union[Path, Payload]) -> Union[Path, Payload]:
            if not transporter.origin_name:
                # If path/payload is not attached to an origin, there is no need to attach or transfer anything
                return transporter

            transporter._attach_work(self.work)
            transporter._attach_queues(self.work._request_queue,
                                       self.work._response_queue)
            if transporter.exists_remote():
                # All paths/payloads passed to the `run` method under a Lightning obj need to be copied (if they exist)
                if isinstance(transporter, Payload):
                    transporter.get()
                else:
                    transporter.get(overwrite=True)
            return transporter

        def _handle_drive(state: dict) -> Any:
            # ``state`` may be the serialized form of a Drive; ``_maybe_create_drive`` decides.
            return _maybe_create_drive(self.work_name, state)

        args, kwargs = apply_to_collection((args, kwargs),
                                           dtype=(Path, Payload),
                                           function=_attach_work_and_get)

        # ``apply_to_collection`` calls ``function`` on the first object that matches ``dtype`` and does not
        # descend into it. ``kwargs`` is itself a dict, so applying it to ``(args, kwargs)`` hands the whole
        # keyword mapping to ``_handle_drive`` and never looks at the individual values; a Drive passed by
        # keyword would stay a plain dict. Process each keyword value on its own instead.
        # NOTE(review): a Drive nested inside another (non-Drive) dict argument is still not restored — confirm
        # whether that case needs support.
        args = apply_to_collection(args, dtype=dict, function=_handle_drive)
        kwargs = {
            name: apply_to_collection(value, dtype=dict, function=_handle_drive)
            for name, value in kwargs.items()
        }
        return args, kwargs
예제 #8
0
def test_recursive_application_to_collection():
    """``apply_to_collection`` must recurse into containers, named tuples and dataclasses while preserving types.

    Every tensor, number and numpy array is doubled; strings and non-matching leaves are left untouched.
    """
    ntc = namedtuple("Foo", ["bar"])

    @dataclasses.dataclass
    class Feature:
        input_ids: torch.Tensor
        segment_ids: np.ndarray

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, Feature):
                return NotImplemented
            else:
                return torch.equal(self.input_ids, o.input_ids) and np.equal(
                    self.segment_ids, o.segment_ids).all()

    @dataclasses.dataclass
    class ModelExample:
        example_ids: List[str]
        feature: Feature
        label: torch.Tensor
        some_constant: int = dataclasses.field(init=False)

        def __post_init__(self):
            self.some_constant = 7

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, ModelExample):
                return NotImplemented
            else:
                return (self.example_ids == o.example_ids
                        and self.feature == o.feature
                        and torch.equal(self.label, o.label)
                        and self.some_constant == o.some_constant)

    @dataclasses.dataclass
    class WithClassVar:
        class_var: ClassVar[int] = 0
        dummy: Any

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithClassVar):
                return NotImplemented
            elif isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            else:
                return self.dummy == o.dummy

    @dataclasses.dataclass
    class WithInitVar:
        dummy: Any
        override: InitVar[Optional[Any]] = None

        def __post_init__(self, override: Optional[Any]):
            if override is not None:
                self.dummy = override

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithInitVar):
                return NotImplemented
            elif isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            else:
                return self.dummy == o.dummy

    @dataclasses.dataclass
    class WithClassAndInitVar:
        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
        dummy: Any
        override: InitVar[Optional[Any]] = torch.tensor(1)

        def __post_init__(self, override: Optional[Any]):
            if override is not None:
                self.dummy = override

        def __eq__(self, o: object) -> bool:
            if not isinstance(o, WithClassAndInitVar):
                return NotImplemented
            elif isinstance(self.dummy, torch.Tensor):
                return torch.equal(self.dummy, o.dummy)
            else:
                return self.dummy == o.dummy

    model_example = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]),
                        segment_ids=np.array([4.0, 5.0, 6.0])),
        label=torch.tensor([7.0, 8.0, 9.0]),
    )

    to_reduce = {
        "a":
        torch.tensor([1.0]),  # Tensor
        "b": [torch.tensor([2.0])],  # list
        "c": (torch.tensor([100.0]), ),  # tuple
        "d":
        ntc(bar=5.0),  # named tuple
        "e":
        np.array([10.0]),  # numpy array
        "f":
        "this_is_a_dummy_str",  # string
        "g":
        12.0,  # number
        "h":
        Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]),
                segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
        "i":
        model_example,  # nested dataclass
        "j":
        WithClassVar(torch.arange(3)),  # dataclass with class variable
        "k":
        WithInitVar("this_gets_overridden",
                    torch.tensor([2.0])),  # dataclass with init-only variable
        "l":
        WithClassAndInitVar(
            model_example,
            None),  # nested dataclass with class and init-only variables
    }

    model_example_result = ModelExample(
        example_ids=["i-1", "i-2", "i-3"],
        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]),
                        segment_ids=np.array([8.0, 10.0, 12.0])),
        label=torch.tensor([14.0, 16.0, 18.0]),
    )

    expected_result = {
        "a":
        torch.tensor([2.0]),
        "b": [torch.tensor([4.0])],
        "c": (torch.tensor([200.0]), ),
        # The named-tuple field is a plain float, so doubling it yields a plain number (asserted below), not a tensor.
        "d":
        ntc(bar=10.0),
        "e":
        np.array([20.0]),
        "f":
        "this_is_a_dummy_str",
        "g":
        24.0,
        "h":
        Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]),
                segment_ids=np.array([8.0, 10.0, 12.0])),
        "i":
        model_example_result,
        "j":
        WithClassVar(torch.arange(0, 6, 2)),
        "k":
        WithInitVar(torch.tensor([4.0])),
        "l":
        WithClassAndInitVar(model_example_result, None),
    }

    reduced = apply_to_collection(to_reduce,
                                  (torch.Tensor, numbers.Number, np.ndarray),
                                  lambda x: x * 2)

    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
    assert all(
        x in reduced
        for x in to_reduce), "Not all entries of the dict were preserved"
    assert all(
        isinstance(reduced[k], type(expected_result[k]))
        for k in to_reduce), "At least one type was not correctly preserved"

    assert isinstance(
        reduced["a"],
        torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
    assert torch.equal(
        expected_result["a"], reduced["a"]
    ), "Reduction of a tensor does not yield the expected value"

    assert isinstance(reduced["b"],
                      list), "Reduction Result of a list should be a list"
    assert all(
        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
    ), "At least one value of list reduction did not come out as expected"

    assert isinstance(reduced["c"],
                      tuple), "Reduction Result of a tuple should be a tuple"
    assert all(
        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
    ), "At least one value of tuple reduction did not come out as expected"

    assert isinstance(reduced["d"],
                      ntc), "Type Consistency for named tuple not given"
    assert isinstance(
        reduced["d"].bar, numbers.Number
    ), "Failure in type promotion while reducing fields of named tuples"
    assert reduced["d"].bar == expected_result["d"].bar

    assert isinstance(
        reduced["e"],
        np.ndarray), "Type Promotion in reduction of numpy arrays failed"
    assert reduced["e"] == expected_result[
        "e"], "Reduction of numpy array did not yield the expected result"

    assert isinstance(reduced["f"], str), "A string should not be reduced"
    assert reduced["f"] == expected_result[
        "f"], "String not preserved during reduction"

    assert isinstance(
        reduced["g"],
        numbers.Number), "Reduction of a number should result in a number"
    assert reduced["g"] == expected_result[
        "g"], "Reduction of a number did not yield the desired result"

    def _assert_dataclass_reduction(actual,
                                    expected,
                                    dataclass_type: str = ""):
        assert dataclasses.is_dataclass(actual) and not isinstance(
            actual, type
        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
        for field in dataclasses.fields(actual):
            if dataclasses.is_dataclass(field.type):
                _assert_dataclass_reduction(getattr(actual, field.name),
                                            getattr(expected, field.name),
                                            "nested")
        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"

    _assert_dataclass_reduction(reduced["h"], expected_result["h"])

    _assert_dataclass_reduction(reduced["i"], expected_result["i"])

    dataclass_type = "ClassVar-containing"
    _assert_dataclass_reduction(reduced["j"], expected_result["j"],
                                dataclass_type)
    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"

    _assert_dataclass_reduction(reduced["k"], expected_result["k"],
                                "InitVar-containing")

    dataclass_type = "Class-and-InitVar-containing"
    _assert_dataclass_reduction(reduced["l"], expected_result["l"],
                                dataclass_type)
    assert torch.equal(
        WithClassAndInitVar.class_var, torch.tensor(0)
    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"

    # mapping support
    reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
    assert reduced == {"a": "1", "b": "2"}
    reduced = apply_to_collection(OrderedDict([("b", 2), ("a", 1)]), int,
                                  lambda x: str(x))
    assert reduced == OrderedDict([("b", "2"), ("a", "1")])

    # custom mappings
    class _CustomCollection(dict):
        def __init__(self, initial_dict):
            super().__init__(initial_dict)

    to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3})
    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
    assert reduced == _CustomCollection({"a": "1", "b": "2", "c": "3"})

    # defaultdict
    to_reduce = defaultdict(int, {"a": 1, "b": 2, "c": 3})
    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
    assert reduced == defaultdict(int, {"a": "1", "b": "2", "c": "3"})