def test_save_different_dtype_unallocated(self):
    """Round-trip torch.save/torch.load of an empty tensor together with
    storages/tensors of every other dtype, on every available device.

    Unallocated (zero-element) data viewing the same (empty) buffer under
    different dtypes must serialize and deserialize without error.
    """
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')

    def save_load_check(first, second):
        # Serialize the pair into an in-memory buffer and verify that both
        # objects survive the round trip unchanged.
        with io.BytesIO() as buffer:
            torch.save([first, second], buffer)
            buffer.seek(0)
            first_loaded, second_loaded = torch.load(buffer)
        self.assertEqual(first, first_loaded)
        self.assertEqual(second, second_loaded)

    device_dtype_pairs = product(
        devices,
        all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    for device, dtype in device_dtype_pairs:
        empty = torch.tensor([], dtype=dtype, device=device)
        for other_dtype in all_types_and_complex_and(
                torch.half, torch.bfloat16, torch.bool):
            # View the (empty) underlying bytes as a storage of a different dtype.
            viewed = torch._TypedStorage(
                wrap_storage=empty.storage()._untyped(),
                dtype=other_dtype)
            save_load_check(empty, viewed)
            save_load_check(empty.storage(), viewed)
            other = torch.tensor([], dtype=other_dtype, device=device)
            save_load_check(empty, other)
def test_save_different_dtype_error(self):
    """torch.save must refuse to serialize objects that view the same
    underlying data as different types, raising RuntimeError.
    """
    error_msg = r"Cannot save multiple tensors or storages that view the same data as different types"
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')

    for device in devices:
        complex_tensor = torch.randn(10, dtype=torch.complex128, device=device)
        f = io.BytesIO()
        # A complex tensor and its imaginary-part view alias the same memory
        # but expose different dtypes — every combination must be rejected.
        aliasing_pairs = (
            [complex_tensor, complex_tensor.imag],
            [complex_tensor.storage(), complex_tensor.imag],
            [complex_tensor, complex_tensor.imag.storage()],
            [complex_tensor.storage(), complex_tensor.imag.storage()],
        )
        for pair in aliasing_pairs:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                torch.save(pair, f)

        # Same check for a float tensor whose bytes are re-viewed as uint8.
        float_tensor = torch.randn(10, device=device)
        byte_view = torch._TypedStorage(
            wrap_storage=float_tensor.storage()._untyped(),
            dtype=torch.uint8)
        for pair in ([float_tensor, byte_view],
                     [float_tensor.storage(), byte_view]):
            with self.assertRaisesRegex(RuntimeError, error_msg):
                torch.save(pair, f)
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    """Reconstruct a filename-backed shared-memory storage in a child process.

    Checks the shared-storage cache first; on a hit, returns the cached
    storage after dropping the extra shared reference taken by the sender.
    Otherwise attaches to the shared file, wraps it in a typed storage when
    a ``dtype`` is given (``size`` is then an element count, not bytes),
    registers a weak reference in the cache, and returns it.
    """
    cached: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle)
    if cached is not None:
        # Another rebuild already attached this handle — reuse it.
        return cached._shared_decref()

    if dtype is None:
        # Untyped: size is already a byte count.
        storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        # Typed: convert the element count to bytes before attaching.
        nbytes = size * torch._utils._element_size(dtype)
        raw: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(
            manager, handle, nbytes)
        storage = torch._TypedStorage(wrap_storage=raw, dtype=dtype)

    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()