def test_load_unicode_error_msg(self):
    """Check that forcing ASCII decoding on a legacy checkpoint raises.

    The downloaded pickle contains a Python 2 module with Unicode data,
    so loading it must fail when the caller explicitly asks for the
    'ascii' encoding.
    """
    legacy_url = 'https://download.pytorch.org/test_data/legacy_conv2d.pt'
    checkpoint_path = download_file(legacy_url)
    with self.assertRaises(UnicodeDecodeError):
        torch.load(checkpoint_path, encoding='ascii')
def test_load_unicode_error_msg(self):
    """Check how a legacy Unicode checkpoint loads on the running interpreter.

    The downloaded pickle contains a Python 2 module with Unicode data.
    On Python 3 the load must fail when the caller explicitly asks for the
    'ascii' encoding; on older interpreters the test only checks that the
    module loads at all.
    """
    checkpoint_path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')

    if sys.version_info < (3, 0):
        # Just checks the module loaded
        self.assertIsNotNone(torch.load(checkpoint_path))
        return

    with self.assertRaises(UnicodeDecodeError):
        torch.load(checkpoint_path, encoding='ascii')
def test_serialization_backwards_compat(self):
    """Check that legacy serialized data and old pickle protocols still load.

    Phase 1 loads a downloaded legacy file and compares it against a
    freshly built reference list: four float tensors (entries 0/2 and 1/3
    are the same tensor object in the reference) followed by two storages.
    It then fills loaded entries in place and asserts the change is
    visible through the paired entries, which is what one would expect if
    the loaded objects still alias each other.

    Phase 2 pickles wrapper objects whose ``__reduce__`` mimics two older
    tensor serialization schemes and asserts the reloaded tensor keeps the
    original storage, storage offset, size and stride.
    """
    # Reference data: two distinct 5x5 float tensors, referenced twice each,
    # plus the storage of the first and the storage of a cloned slice.
    a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].reshape(-1)[1:4].clone().storage()]

    path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
    c = torch.load(path)
    self.assertEqual(b, c, atol=0, rtol=0)

    # assertIsInstance reports the offending object on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true".
    for index in range(4):
        self.assertIsInstance(c[index], torch.FloatTensor)
    self.assertIsInstance(c[4], torch.storage.TypedStorage)
    self.assertEqual(c[4].dtype, torch.float32)

    # In-place writes must be observable through the paired entries.
    c[0].fill_(10)
    self.assertEqual(c[0], c[2], atol=0, rtol=0)
    self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), atol=0, rtol=0)
    c[1].fill_(20)
    self.assertEqual(c[1], c[3], atol=0, rtol=0)

    # test some old tensor serialization mechanism
    class OldTensorBase(object):
        def __init__(self, new_tensor):
            self.new_tensor = new_tensor

        def __getstate__(self):
            return (self.new_tensor.storage(),
                    self.new_tensor.storage_offset(),
                    tuple(self.new_tensor.size()),
                    self.new_tensor.stride())

    class OldTensorV1(OldTensorBase):
        # Reduces to a bare torch.Tensor() call followed by a state restore.
        def __reduce__(self):
            return (torch.Tensor, (), self.__getstate__())

    class OldTensorV2(OldTensorBase):
        # Reduces to a _rebuild_tensor(*state) call.
        def __reduce__(self):
            return (_rebuild_tensor, self.__getstate__())

    # Strided view with a non-zero offset, so offset/size/stride all matter.
    x = torch.randn(30).as_strided([2, 3], [9, 3], 2)
    for old_cls in [OldTensorV1, OldTensorV2]:
        with tempfile.NamedTemporaryFile() as f:
            old_x = old_cls(x)
            torch.save(old_x, f)
            f.seek(0)
            load_x = torch.load(f)
            self.assertEqual(x.storage(), load_x.storage())
            self.assertEqual(x.storage_offset(), load_x.storage_offset())
            self.assertEqual(x.size(), load_x.size())
            self.assertEqual(x.stride(), load_x.stride())
def test_serialization_map_location(self):
    """Check that every supported ``map_location`` form places tensors correctly.

    A downloaded checkpoint is loaded both from its path and from an
    in-memory byte stream, once per ``map_location`` value. Each loaded
    tensor must land on the intended device, have the expected tensor
    class, and equal ``[[1.0, 2.0], [3.0, 4.0]]``. The CUDA cases run
    only when CUDA is available.
    """
    test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt')

    def keep_storage(storage, loc):
        # Callable form of map_location: hand the storage back untouched.
        return storage

    def load_bytes():
        with open(test_file_path, 'rb') as f:
            return io.BytesIO(f.read())

    # Factories rather than values, so each load gets a fresh source.
    source_factories = [lambda: test_file_path, load_bytes]

    cpu_targets = [
        keep_storage,
        {'cuda:0': 'cpu'},
        'cpu',
        torch.device('cpu'),
    ]
    first_gpu_targets = [
        {'cuda:0': 'cuda:0'},
        'cuda',
        'cuda:0',
        torch.device('cuda'),
        torch.device('cuda', 0),
    ]
    last_gpu_targets = ['cuda:{}'.format(torch.cuda.device_count() - 1)]

    def check_map_locations(map_locations, tensor_class, intended_device):
        for make_source in source_factories:
            for target in map_locations:
                loaded = torch.load(make_source(), map_location=target)
                self.assertEqual(loaded.device, intended_device)
                self.assertIsInstance(loaded, tensor_class)
                self.assertEqual(loaded, tensor_class([[1.0, 2.0], [3.0, 4.0]]))

    check_map_locations(cpu_targets, torch.FloatTensor, torch.device('cpu'))

    if torch.cuda.is_available():
        check_map_locations(first_gpu_targets, torch.cuda.FloatTensor, torch.device('cuda', 0))
        last_device = torch.device('cuda', torch.cuda.device_count() - 1)
        check_map_locations(last_gpu_targets, torch.cuda.FloatTensor, last_device)
def test_load_python2_unicode_module(self):
    """Check that a legacy checkpoint holding Unicode data loads to a value.

    The downloaded pickle contains some Unicode data; the test only
    requires that ``torch.load`` returns something other than ``None``.
    """
    path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
    # record=True is kept so warnings raised during the load are captured;
    # the captured list was never inspected, so it is no longer bound to a name.
    with warnings.catch_warnings(record=True):
        self.assertIsNotNone(torch.load(path))
def test_load_python2_unicode_module(self):
    """Check that a legacy checkpoint holding Unicode data loads to a value.

    The downloaded pickle contains some Unicode data; the test only
    requires that ``torch.load`` returns something other than ``None``.
    """
    legacy_url = 'https://download.pytorch.org/test_data/legacy_conv2d.pt'
    loaded_module = torch.load(download_file(legacy_url))
    self.assertIsNotNone(loaded_module)