def test_sanitize_device(self):
    """Check that sanitize_device resolves CPU specifications to ht.cpu and rejects bad input.

    NOTE(review): another method named ``test_sanitize_device`` appears later in
    this file. If both live in the same class, the later definition replaces this
    one and this test never runs; confirm, then rename or remove one of them.
    """
    # Accepted forms: device names in any letter case and with surrounding
    # whitespace, the Device object itself, and None.
    for spec in ('cpu', 'cPu', ' CPU ', ht.cpu, None):
        with self.subTest(device=spec):
            self.assertIs(ht.sanitize_device(spec), ht.cpu)

    # Rejected forms: an unknown device name and an integer. Only the call itself
    # needs to raise; asserting on a return value inside assertRaises would never
    # execute on the expected (raising) path.
    for bad in ('fpu', 1):
        with self.subTest(device=bad):
            with self.assertRaises(ValueError):
                ht.sanitize_device(bad)
def test_sanitize_device_cpu(self):
    """Check that sanitize_device resolves CPU specifications to ht.cpu and rejects bad input."""
    # Accepted forms: device names in any letter case and with surrounding
    # whitespace, the Device object itself, and None.
    for spec in ("cpu", "cPu", " CPU ", ht.cpu, None):
        with self.subTest(device=spec):
            self.assertIs(ht.sanitize_device(spec), ht.cpu)

    # Rejected forms: an unknown device name and an integer. Only the call itself
    # needs to raise; asserting on a return value inside assertRaises would never
    # execute on the expected (raising) path.
    for bad in ("fpu", 1):
        with self.subTest(device=bad):
            with self.assertRaises(ValueError):
                ht.sanitize_device(bad)
def test_sanitize_device_gpu(self):
    """Check that sanitize_device resolves GPU specifications to ht.gpu and rejects bad input.

    Skipped when ``ht.torch.cuda.is_available()`` is false, so a CPU-only run
    reports the test as skipped instead of as a vacuous pass.
    """
    if not ht.torch.cuda.is_available():
        # skipTest raises, so ht.gpu below is only touched when CUDA is present.
        self.skipTest("CUDA is not available")

    # Accepted forms: device names in any letter case and with surrounding
    # whitespace, the Device object itself, and None.
    # NOTE(review): None presumably resolves to the current default device. The
    # environment-driven test in this file calls ``ht.use_device`` before making
    # the same None -> ht.gpu assertion, so this may fail unless the GPU was
    # already selected as the default; confirm.
    for spec in ("gpu", "gPu", " GPU ", ht.gpu, None):
        with self.subTest(device=spec):
            self.assertIs(ht.sanitize_device(spec), ht.gpu)

    # Rejected forms: an unknown device name and an integer. Only the call itself
    # needs to raise; asserting on a return value inside assertRaises would never
    # execute on the expected (raising) path.
    for bad in ("fpu", 1):
        with self.subTest(device=bad):
            with self.assertRaises(ValueError):
                ht.sanitize_device(bad)
def test_sanitize_device(self):
    """Check sanitize_device against the device chosen by the DEVICE environment variable.

    When DEVICE is "gpu", that device is selected via ``ht.use_device``. GPU name
    spellings, the ``ht.gpu`` object and None must then resolve to ``ht.gpu``.
    Otherwise the CPU spellings, ``ht.cpu`` and None must resolve to ``ht.cpu``.
    In both cases an unknown name and an integer must raise ValueError.
    """
    device_env = os.environ.get("DEVICE")
    if device_env == "gpu":
        # NOTE(review): the selection made here is not restored afterwards;
        # presumably intended, since DEVICE picks the device for the whole run.
        ht.use_device(device_env)
        expected, spellings = ht.gpu, ("gpu", "gPu", " GPU ")
    else:
        expected, spellings = ht.cpu, ("cpu", "cPu", " CPU ")

    # Names are accepted in any letter case and with surrounding whitespace; the
    # Device object itself and None are accepted too.
    for spec in (*spellings, expected, None):
        with self.subTest(device=spec):
            self.assertIs(ht.sanitize_device(spec), expected)

    # Rejection of invalid specifications does not depend on the selected device,
    # so it is checked on both paths. Only the call itself needs to raise.
    for bad in ("fpu", 1):
        with self.subTest(device=bad):
            with self.assertRaises(ValueError):
                ht.sanitize_device(bad)