Beispiel #1
0
    def test_serialization_2gb_file(self):
        """Round-trip a module whose weight exceeds 2 GiB through a file-like object.

        Conv2d(20000, 3200, kernel_size=3) has a weight of shape
        (3200, 20000, 3, 3) = 576,000,000 elements. Assuming the default
        float32 dtype, that is about 2.3 GB, so the serialized stream crosses
        the 2 GiB boundary.
        """
        big_model = torch.nn.Conv2d(20000, 3200, kernel_size=3)

        with BytesIOContext() as f:
            torch.save(big_model, f)
            f.seek(0)
            state = torch.load(f)

        # Check that the round-trip actually reproduced the module, not merely
        # that loading did not raise. The weight is only compared by shape to
        # avoid an extra multi-GB comparison; the small bias is compared exactly.
        self.assertEqual(state.weight.size(), big_model.weight.size())
        self.assertTrue(torch.equal(state.bias, big_model.bias))
Beispiel #2
0
 def test_serialization_filelike(self):
     """Save sample data into an in-memory file-like object and load it back.

     The data comes from self._test_serialization_data(); the comparison is
     delegated to self._test_serialization_assert().
     """
     original = self._test_serialization_data()
     with BytesIOContext() as buf:
         torch.save(original, buf)
         buf.seek(0)  # rewind so torch.load reads from the start
         restored = torch.load(buf)
     self._test_serialization_assert(original, restored)
Beispiel #3
0
    def test_meta_serialization(self):
        """Round-trip a Conv2d allocated on the 'meta' device and check its weight shape."""
        meta_conv = torch.nn.Conv2d(
            20000, 320000, kernel_size=3, device='meta'
        )

        with BytesIOContext() as stream:
            torch.save(meta_conv, stream)
            stream.seek(0)
            reloaded = torch.load(stream)

        self.assertEqual(reloaded.weight.size(), meta_conv.weight.size())
    def test_empty_class_serialization(self):
        """Copying and save/load of an empty Tensor subclass must not raise.

        Covers a subclass instance built from data and one built with no
        arguments. This is a smoke test: the loaded values are not inspected.
        """
        def roundtrip(t):
            # copy.copy, then save into and load back from an in-memory stream.
            copy.copy(t)
            with BytesIOContext() as stream:
                torch.save(t, stream)
                stream.seek(0)
                torch.load(stream)

        roundtrip(TestEmptySubclass([1.]))
        # Note that data_ptr() == 0 for this argument-less instance
        roundtrip(TestEmptySubclass())
Beispiel #5
0
    def test_tensor_subclass_wrapper_serialization(self):
        """A TestWrapperSubclass survives save/load with its type, elem and extra attribute.

        An ad-hoc attribute ``foo`` is set on the wrapper before saving, and
        the test checks that it is still present after loading.
        """
        inner = torch.rand(2)
        wrapper = TestWrapperSubclass(inner)

        expected_foo = "bar"
        wrapper.foo = expected_foo
        self.assertEqual(wrapper.foo, expected_foo)

        with BytesIOContext() as stream:
            torch.save(wrapper, stream)
            stream.seek(0)
            loaded = torch.load(stream)

        self.assertIsInstance(loaded, TestWrapperSubclass)
        self.assertEqual(loaded.elem, wrapper.elem)
        self.assertEqual(loaded.foo, expected_foo)
Beispiel #6
0
    def test_tensor_subclass_getstate_overwrite(self):
        """A TestGetStateSubclass survives save/load and reports ``reloaded``.

        Besides type, ``elem`` and the ad-hoc ``foo`` attribute, the loaded
        object must have a truthy ``reloaded`` attribute. Presumably the
        subclass's custom state hooks set this flag (the subclass is defined
        elsewhere; confirm there).
        """
        inner = torch.rand(2)
        subject = TestGetStateSubclass(inner)

        expected_foo = "bar"
        subject.foo = expected_foo
        self.assertEqual(subject.foo, expected_foo)

        with BytesIOContext() as stream:
            torch.save(subject, stream)
            stream.seek(0)
            loaded = torch.load(stream)

        self.assertIsInstance(loaded, TestGetStateSubclass)
        self.assertEqual(loaded.elem, subject.elem)
        self.assertEqual(loaded.foo, expected_foo)
        self.assertTrue(loaded.reloaded)
Beispiel #7
0
 def test_serialization_offset_filelike(self):
     """Interleave pickle records and torch.save payloads in one stream past 2 GiB.

     Writes int, tensor, int, large tensor into a single file-like object,
     then reads them back in the same order. Each torch.load must therefore
     leave the stream positioned exactly at the following pickle record,
     including at offsets beyond the 32-bit range.
     """
     a = torch.randn(5, 5)
     # 1024 * 1024 * 512 float32 elements = exactly 2 GiB of raw tensor data.
     b = torch.randn(1024, 1024, 512, dtype=torch.float32)
     i, j = 41, 43
     with BytesIOContext() as f:
         pickle.dump(i, f)
         torch.save(a, f)
         pickle.dump(j, f)
         torch.save(b, f)
         # assertGreater reports both operands on failure, unlike
         # assertTrue(x > y), which only says "False is not true".
         self.assertGreater(f.tell(), 2 * 1024 * 1024 * 1024)
         f.seek(0)
         i_loaded = pickle.load(f)
         a_loaded = torch.load(f)
         j_loaded = pickle.load(f)
         b_loaded = torch.load(f)
     self.assertTrue(torch.equal(a, a_loaded))
     self.assertTrue(torch.equal(b, b_loaded))
     self.assertEqual(i, i_loaded)
     self.assertEqual(j, j_loaded)