예제 #1
0
    def test_save_different_dtype_unallocated(self):
        """Round-trip zero-element data viewed under mismatched dtypes.

        For every available device and every pair of dtypes, an empty tensor
        is saved together with (1) a ``_TypedStorage`` of the other dtype that
        wraps the tensor's own untyped storage, and also (2) the tensor's
        storage together with that same typed storage, and (3) a fresh empty
        tensor of the other dtype. Each pair is written to an in-memory
        buffer, loaded back, and compared with the originals.
        """
        devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
        # Materialize once so the same sequence can feed both loops below.
        dtypes = tuple(all_types_and_complex_and(torch.half, torch.bfloat16,
                                                 torch.bool))

        def roundtrip_and_compare(first, second):
            # Both objects go through a single torch.save call so they are
            # serialized side by side.
            with io.BytesIO() as buffer:
                torch.save([first, second], buffer)
                buffer.seek(0)
                first_loaded, second_loaded = torch.load(buffer)
            self.assertEqual(first, first_loaded)
            self.assertEqual(second, second_loaded)

        for device, dtype in product(devices, dtypes):
            empty = torch.tensor([], dtype=dtype, device=device)

            for other_dtype in dtypes:
                reinterpreted = torch._TypedStorage(
                    wrap_storage=empty.storage()._untyped(),
                    dtype=other_dtype)
                roundtrip_and_compare(empty, reinterpreted)
                roundtrip_and_compare(empty.storage(), reinterpreted)
                roundtrip_and_compare(
                    empty,
                    torch.tensor([], dtype=other_dtype, device=device))
예제 #2
0
    def test_save_different_dtype_error(self):
        """Saving data aliased under two different dtypes must raise.

        For each available device, two aliasing setups are checked:
        a complex128 tensor alongside its ``imag`` view, and a float tensor
        alongside a uint8 ``_TypedStorage`` wrapping the tensor's untyped
        storage. Every tensor/storage combination of each setup is passed
        to ``torch.save`` and must raise a RuntimeError matching
        ``error_msg``.
        """
        error_msg = r"Cannot save multiple tensors or storages that view the same data as different types"

        devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

        def assert_save_rejected(objs, buffer):
            # torch.save is expected to refuse this list outright.
            with self.assertRaisesRegex(RuntimeError, error_msg):
                torch.save(objs, buffer)

        for device in devices:
            complex_src = torch.randn(10, dtype=torch.complex128, device=device)
            # One buffer per device is shared by every attempted save below.
            buffer = io.BytesIO()

            imag_view = complex_src.imag
            for objs in (
                    [complex_src, imag_view],
                    [complex_src.storage(), imag_view],
                    [complex_src, imag_view.storage()],
                    [complex_src.storage(), imag_view.storage()]):
                assert_save_rejected(objs, buffer)

            real_src = torch.randn(10, device=device)
            byte_view = torch._TypedStorage(
                wrap_storage=real_src.storage()._untyped(),
                dtype=torch.uint8)
            for objs in ([real_src, byte_view],
                         [real_src.storage(), byte_view]):
                assert_save_rejected(objs, buffer)
예제 #3
0
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    """Rebuild a shared CPU storage identified by ``handle``.

    If ``storage_from_cache(cls, handle)`` yields a storage, the result of
    calling ``_shared_decref()`` on it is returned immediately and nothing
    new is created.

    Otherwise a storage is opened with
    ``torch._UntypedStorage._new_shared_filename_cpu(manager, handle, ...)``:

    * ``dtype is None``: ``size`` is passed straight through as the byte
      count and the untyped storage itself is used.
    * otherwise: ``size`` counts elements of ``dtype``. The untyped storage
      is opened with ``size * element_size(dtype)`` bytes and wrapped in a
      ``torch._TypedStorage`` of ``dtype``.

    The new storage is recorded as ``shared_cache[handle] =
    StorageWeakRef(storage)`` before ``storage._shared_decref()`` is
    returned.
    """
    cached = storage_from_cache(cls, handle)
    if cached is not None:
        return cached._shared_decref()

    storage: Union[torch._TypedStorage, torch._UntypedStorage]
    if dtype is not None:
        # Callers give an element count here; the shared allocation is sized in bytes.
        nbytes = size * torch._utils._element_size(dtype)
        raw: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, nbytes)
        storage = torch._TypedStorage(wrap_storage=raw, dtype=dtype)
    else:
        storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)

    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()