Ejemplo n.º 1
0
    def test_pickle_placeholder(self):
        """Pickled LeafPlaceholders equal placeholders built from type strings.

        Each case pickles and unpickles a placeholder wrapping a typing
        construct, then compares it with a placeholder constructed directly
        from the expected string form of that construct.
        """

        def _roundtrip(obj):
            # Serialize and immediately deserialize through pickle.
            return pickle.loads(pickle.dumps(obj))

        # NOTE(review): the expected literal assumes the pre-3.14 typing repr;
        # Python 3.14 renders Union[str, int] as "str | int" -- confirm how
        # LeafPlaceholder stringifies types before upgrading.
        union_copy = _roundtrip(jax_util.LeafPlaceholder(Union[str, int]))
        self.assertEqual(
            union_copy, jax_util.LeafPlaceholder("typing.Union[str, int]"))

        any_copy = _roundtrip(jax_util.LeafPlaceholder(Any))
        self.assertEqual(any_copy, jax_util.LeafPlaceholder("typing.Any"))
Ejemplo n.º 2
0
    def test_synthesize_dataclass(self):
        """synthesize_dataclass builds a nested instance with filler values.

        The expected result uses "" for the str field, 0 for the int field,
        and LeafPlaceholder wrappers for the NDArray and Any fields.
        """

        @dataclasses.dataclass
        class Inner:
            x: jax_util.NDArray
            y: int
            z: Any

        @dataclasses.dataclass
        class Outer:
            a: str
            b: Inner

        actual = jax_util.synthesize_dataclass(Outer)

        placeholder_x = jax_util.LeafPlaceholder(jax_util.NDArray)
        placeholder_z = jax_util.LeafPlaceholder(Any)
        # Placeholders are not real NDArrays, so the checker is silenced here.
        expected = Outer(
            a="",
            b=Inner(x=placeholder_x, y=0, z=placeholder_z))  # type:ignore

        self.assertEqual(actual, expected)
Ejemplo n.º 3
0
    def test_synthesize_dataclass(self):
        """synthesize_dataclass builds a nested instance with filler values.

        The expected result uses "" for the str field, 0 for the int field,
        and LeafPlaceholder wrappers for the NDArray and Any fields.
        """

        @dataclasses.dataclass
        class Inner:
            x: jax_util.NDArray
            y: int
            z: Any

        @dataclasses.dataclass
        class Outer:
            a: str
            b: Inner  # pytype: disable=invalid-annotation  # enable-bare-annotations

        actual = jax_util.synthesize_dataclass(Outer)

        placeholder_x = jax_util.LeafPlaceholder(jax_util.NDArray)
        placeholder_z = jax_util.LeafPlaceholder(Any)
        # Placeholders are not real NDArrays, so the checker is silenced here.
        expected = Outer(
            a="",
            b=Inner(x=placeholder_x, y=0, z=placeholder_z))  # type:ignore

        self.assertEqual(actual, expected)