예제 #1
0
 def test_load_unicode_error_msg(self):
     """Forcing ASCII decoding on a legacy Python 2 checkpoint must fail.

     The downloaded pickle contains a Python 2 module with Unicode data, so
     explicitly requesting ``encoding='ascii'`` has to raise
     ``UnicodeDecodeError``.
     """
     legacy_url = 'https://download.pytorch.org/test_data/legacy_conv2d.pt'
     checkpoint_path = download_file(legacy_url)
     with self.assertRaises(UnicodeDecodeError):
         torch.load(checkpoint_path, encoding='ascii')
 def test_load_unicode_error_msg(self):
     """Check that forcing ASCII decoding on a legacy checkpoint fails.

     The pickle contains a Python 2 module with Unicode data, so loading it
     with an explicit ``encoding='ascii'`` must raise ``UnicodeDecodeError``.
     """
     # The test suite runs on Python 3 only, so no interpreter-version
     # branch is needed: the ASCII decode failure is always expected.
     path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
     self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
예제 #3
0
    def test_serialization_backwards_compat(self):
        """Check that legacy serialized tensors and old pickle protocols still load.

        Part 1 loads ``legacy_serialized.pt`` and compares it with ``b``,
        which is rebuilt here as:

        - four float tensors alternating between ``a[0]`` and ``a[1]``;
        - the storage of ``a[0]``;
        - the storage of a cloned 3-element slice.

        ``b[0]``, ``b[2]`` and ``b[4]`` all alias ``a[0]``'s memory, and
        ``b[1]``/``b[3]`` alias ``a[1]``. The test therefore also checks that
        the loaded objects still share memory.

        Part 2 pickles a strided tensor through two old ``__reduce__``
        formats. It checks that the round trip through
        ``torch.save``/``torch.load`` preserves storage, offset, size and
        stride.
        """
        a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
        b = [a[i % 2] for i in range(4)]
        b += [a[0].storage()]
        b += [a[0].reshape(-1)[1:4].clone().storage()]
        path = download_file(
            'https://download.pytorch.org/test_data/legacy_serialized.pt')
        c = torch.load(path)
        self.assertEqual(b, c, atol=0, rtol=0)
        # assertIsInstance names the offending type on failure, whereas
        # assertTrue(isinstance(...)) only reports "False is not true".
        for loaded_tensor in c[:4]:
            self.assertIsInstance(loaded_tensor, torch.FloatTensor)
        self.assertIsInstance(c[4], torch.storage.TypedStorage)
        self.assertEqual(c[4].dtype, torch.float32)

        # If aliasing survived deserialization, a write through c[0] is
        # visible via c[2] and via the storage c[4] (likewise c[1] -> c[3]).
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], atol=0, rtol=0)
        self.assertEqual(c[4],
                         torch.FloatStorage(25).fill_(10),
                         atol=0,
                         rtol=0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], atol=0, rtol=0)

        # test some old tensor serialization mechanism
        class OldTensorBase:
            # Wraps a real tensor and exposes its state as the
            # (storage, storage_offset, size, stride) tuple old pickles used.
            def __init__(self, new_tensor):
                self.new_tensor = new_tensor

            def __getstate__(self):
                return (self.new_tensor.storage(),
                        self.new_tensor.storage_offset(),
                        tuple(self.new_tensor.size()),
                        self.new_tensor.stride())

        class OldTensorV1(OldTensorBase):
            # Unpickles as torch.Tensor() followed by __setstate__(state).
            def __reduce__(self):
                return (torch.Tensor, (), self.__getstate__())

        class OldTensorV2(OldTensorBase):
            # Unpickles as _rebuild_tensor(storage, offset, size, stride).
            def __reduce__(self):
                return (_rebuild_tensor, self.__getstate__())

        # Non-contiguous view (offset 2, strides [9, 3]) so the round trip
        # actually has to preserve a non-trivial offset and stride.
        x = torch.randn(30).as_strided([2, 3], [9, 3], 2)
        for old_cls in [OldTensorV1, OldTensorV2]:
            with tempfile.NamedTemporaryFile() as f:
                torch.save(old_cls(x), f)
                f.seek(0)
                load_x = torch.load(f)
                self.assertEqual(x.storage(), load_x.storage())
                self.assertEqual(x.storage_offset(), load_x.storage_offset())
                self.assertEqual(x.size(), load_x.size())
                self.assertEqual(x.stride(), load_x.stride())
예제 #4
0
    def test_serialization_map_location(self):
        """Load a GPU-saved checkpoint with every supported ``map_location`` form.

        ``gpu_tensors.pt`` is loaded both from a file path and from an
        in-memory ``BytesIO`` copy. The checkpoint presumably holds a tensor
        saved on ``cuda:0``, judging by the ``{'cuda:0': ...}`` mappings;
        verify against the fixture.

        Each ``map_location`` variant (callable, dict, string,
        ``torch.device``) must yield a tensor on the intended device, of the
        intended type, equal to ``[[1.0, 2.0], [3.0, 4.0]]``. GPU targets are
        only exercised when CUDA is available.
        """
        test_file_path = download_file(
            'https://download.pytorch.org/test_data/gpu_tensors.pt')

        def identity_map_location(storage, loc):
            # Returns the storage unchanged; the test expects the result
            # to end up on the CPU.
            return storage

        def load_bytes():
            with open(test_file_path, 'rb') as f:
                return io.BytesIO(f.read())

        # Each factory yields a fresh source for torch.load: the path itself
        # or an in-memory copy of the file's bytes.
        fileobject_lambdas = [lambda: test_file_path, load_bytes]

        def check_map_locations(map_locations, tensor_class, intended_device):
            # The loop variable is deliberately not named ``map_location`` so
            # it cannot be confused with the callable defined above.
            for make_source in fileobject_lambdas:
                for location in map_locations:
                    tensor = torch.load(make_source(), map_location=location)

                    self.assertEqual(tensor.device, intended_device)
                    self.assertIsInstance(tensor, tensor_class)
                    self.assertEqual(tensor,
                                     tensor_class([[1.0, 2.0], [3.0, 4.0]]))

        cpu_map_locations = [
            identity_map_location,
            {'cuda:0': 'cpu'},
            'cpu',
            torch.device('cpu'),
        ]
        check_map_locations(cpu_map_locations, torch.FloatTensor,
                            torch.device('cpu'))

        if torch.cuda.is_available():
            # Built only here: without CUDA, device_count() - 1 would be -1.
            gpu_0_map_locations = [
                {'cuda:0': 'cuda:0'},
                'cuda',
                'cuda:0',
                torch.device('cuda'),
                torch.device('cuda', 0),
            ]
            check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor,
                                torch.device('cuda', 0))

            last_device = torch.cuda.device_count() - 1
            gpu_last_map_locations = ['cuda:{}'.format(last_device)]
            check_map_locations(gpu_last_map_locations, torch.cuda.FloatTensor,
                                torch.device('cuda', last_device))
예제 #5
0
 def test_load_python2_unicode_module(self):
     """Check that a Python 2 pickle containing Unicode data loads."""
     # This Pickle contains some Unicode data!
     path = download_file(
         'https://download.pytorch.org/test_data/legacy_conv2d.pt')
     # record=True diverts any warnings emitted while loading into a list
     # instead of printing them. The list is intentionally not bound or
     # inspected; this test only checks that loading succeeds.
     with warnings.catch_warnings(record=True):
         self.assertIsNotNone(torch.load(path))
예제 #6
0
 def test_load_python2_unicode_module(self):
     """A legacy Python 2 pickle holding Unicode data must load successfully."""
     legacy_url = 'https://download.pytorch.org/test_data/legacy_conv2d.pt'
     loaded = torch.load(download_file(legacy_url))
     self.assertIsNotNone(loaded)