def _test_real_serialization(self):
    """Round-trip a CPU TikTorch wrapper through serialize/unserialize.

    Wraps a fresh ``TinyConvNet3D`` in ``TikTorch`` and writes it to
    ``testfile.nn`` inside a scratch directory from ``tempdir()``. It then
    reads the file back with ``TikTorch.unserialize`` and checks that the
    result is a ``TikTorch`` instance.

    NOTE(review): the leading underscore means default ``test*``-prefix
    discovery (unittest/pytest) will not collect this method. Presumably it
    is disabled on purpose; confirm.
    """
    net = TinyConvNet3D()
    wrapper = TikTorch(model=net)
    with tempdir() as scratch_dir:
        # Same target as before: "<scratch_dir>/testfile.nn".
        checkpoint = '{}/{}'.format(scratch_dir, 'testfile.nn')
        wrapper.serialize(to_path=checkpoint)
        restored = TikTorch.unserialize(from_path=checkpoint)
        self.assertIsInstance(restored, TikTorch)
def _test_gpu_serialization(self):
    """Round-trip a CUDA-placed TikTorch wrapper and check it stays on CUDA.

    Wraps a fresh ``TinyConvNet3D`` in ``TikTorch``, calls ``.cuda()`` on the
    wrapper, then serializes it to ``testfile.nn`` inside a ``tempdir()``
    scratch directory. The test unserializes the file and asserts two things:
    the result is a ``TikTorch`` instance, and its ``is_cuda`` flag is truthy.

    NOTE(review): this presumably needs a CUDA-capable device to run; the
    leading underscore keeps it out of default ``test*`` discovery. Confirm
    whether that is the intended way of disabling it.
    """
    net = TinyConvNet3D()
    wrapper = TikTorch(model=net)
    wrapper.cuda()
    with tempdir() as scratch_dir:
        # Same target as before: "<scratch_dir>/testfile.nn".
        checkpoint = "{}/{}".format(scratch_dir, "testfile.nn")
        wrapper.serialize(to_path=checkpoint)
        restored = TikTorch.unserialize(from_path=checkpoint)
        self.assertIsInstance(restored, TikTorch)
        self.assertTrue(restored.is_cuda)