Ejemplo n.º 1
0
    def test_to(self) -> None:
        """Check that ``Environment.to`` forwards its arguments to stored values.

        A stub exposing a ``to`` method is stored in the environment next to
        a tensor, a ``PaddedSequenceWithMask`` and a plain int. After calling
        ``Environment.to(device=...)`` the stub must have been called with no
        positional arguments and exactly that ``device`` keyword.
        """

        class _Recorder:
            """Stub that remembers the arguments of its last ``to`` call."""

            def to(self, *args, **kwargs):
                self.args = (args, kwargs)
                return self

        cpu = torch.device("cpu")

        env = Environment()
        env["key"] = _Recorder()
        # The remaining entries are not asserted on; the call below only has
        # to complete without raising while they are present.
        env["x"] = torch.tensor(0)
        env["y"] = PaddedSequenceWithMask(
            torch.tensor(0.0), torch.tensor(True)
        )
        env["z"] = 10

        env.to(device=cpu)

        expected = ((), {"device": cpu})
        assert env["key"].args == expected
Ejemplo n.º 2
0
 def _to(self, x: Environment) -> Environment:
     """Move ``x`` to the device of the encoder's first parameter.

     If ``self.encoder`` has no parameters, ``x`` is returned untouched.

     Args:
         x: The environment to move.

     Returns:
         The same ``x`` object that was passed in.
     """
     # Only the first parameter is needed, so fetch it lazily instead of
     # materializing the full parameter list.
     first_param = next(iter(self.encoder.parameters()), None)
     if first_param is not None:
         # NOTE(review): the result of ``x.to`` is discarded, so this relies
         # on ``to`` updating ``x`` in place -- confirm against Environment.
         x.to(first_param.device)
     return x