예제 #1
0
    def persistent_id(obj):
        # Pickle persistent-id hook. Typed and untyped storages are swapped for
        # a ``('storage', storage_type, key, location, numel)`` record and the
        # underlying storage is stashed in ``serialized_storages[key]``; ``key``
        # is a sequential string id assigned once per ``_cdata`` via ``id_map``.
        # Every other object returns None so pickle handles it normally.
        #
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        is_typed = isinstance(obj, torch.storage.TypedStorage)
        if not (is_typed or torch.is_storage(obj)):
            return None

        if is_typed:
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            backing = obj._storage
            storage_type = getattr(torch, obj.pickle_storage_type())
            numel = obj.size()
        else:
            backing = obj
            storage_type = normalize_storage_type(type(obj))
            # For an untyped storage the recorded "numel" is its byte count.
            numel = backing.nbytes()

        backing = cast(Storage, backing)
        key = id_map.setdefault(backing._cdata, str(len(id_map)))
        location = location_tag(backing)
        serialized_storages[key] = backing
        return ('storage', storage_type, key, location, numel)
예제 #2
0
    def __setitem__(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        if torch.is_storage(value):
            raise RuntimeError(
                f'cannot set item with value type {type(value)}')
        if self.dtype in [
                torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32,
                torch.qint8
        ]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            tmp_dtype = interpret_dtypes[self.dtype]
            tmp_tensor = torch.tensor(
                [], dtype=tmp_dtype, device=self.device).set_(
                    _TypedStorage(wrap_storage=self._storage, dtype=tmp_dtype))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype,
                                      device=self.device).set_(self)

        tmp_tensor[idx] = value
예제 #3
0
    def _persistent_id(self, obj):
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            name = f".data/{obj_key}.storage"

            if name not in self.serialized_storages:
                # check to see if storage was previously serialized
                serialized_files = self.zip_file.get_all_written_records()
                if name not in serialized_files:
                    if obj.device.type != "cpu":
                        obj = obj.cpu()
                    num_bytes = obj.size() * obj.element_size()
                    self.zip_file.write_record(name, obj.data_ptr(), num_bytes)
                self.serialized_storages.add(name)
            return ("storage", storage_type, obj_key, location, obj.size())

        if hasattr(obj, "__reduce_package__"):
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = ("reduce_package", id(obj), *obj.__reduce_package__(self))

            return self.serialized_reduces[id(obj)]

        return None
예제 #4
0
    def persistent_id(obj):
        """Pickle persistent-id hook for modules classes and storages.

        * ``nn.Module`` subclasses: the first occurrence returns
          ``('module', cls, source_file, source)`` (source may be ``None`` if
          it cannot be retrieved, in which case a warning is issued); later
          occurrences return ``None``.
        * storages: the root storage is recorded in ``serialized_storages``
          and ``('storage', storage_type, root_key, location, root_size,
          view_metadata)`` is returned, where ``view_metadata`` is
          ``(key, offset, size)`` when ``obj`` differs from its root, else
          ``None``.
        * anything else: ``None``.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional, so we can ignore any errors
                # ``Exception`` rather than a bare ``except:`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            root, offset = obj._root_storage()
            root_key = str(root._cdata)
            location = location_tag(obj)
            serialized_storages[root_key] = root
            is_view = obj._cdata != root._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage', storage_type, root_key, location, root.size(),
                    view_metadata)

        return None
예제 #5
0
    def _persistent_id(self, obj):
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            location = location_tag(obj)

            # serialize storage if not already written
            storage_present = self.storage_context.has_storage(obj)
            storage_id = self.storage_context.get_or_add_storage(obj)
            if not storage_present:
                if obj.device.type != "cpu":
                    obj = obj.cpu()
                num_bytes = obj.size() * obj.element_size()
                self.zip_file.write_record(f".data/{storage_id}.storage",
                                           obj.data_ptr(), num_bytes)
            return ("storage", storage_type, storage_id, location, obj.size())

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                    obj, torch.jit.RecursiveScriptModule):
                raise Exception(
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        return None
예제 #6
0
    def _persistent_id(self, obj):
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            name = f".data/{obj_key}.storage"

            if name not in self.serialized_storages:
                # check to see if storage was previously serialized
                serialized_files = self.zip_file.get_all_written_records()
                if name not in serialized_files:
                    if obj.device.type != "cpu":
                        obj = obj.cpu()
                    num_bytes = obj.size() * obj.element_size()
                    self.zip_file.write_record(name, obj.data_ptr(), num_bytes)
                self.serialized_storages.add(name)
            return ("storage", storage_type, obj_key, location, obj.size())

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                    obj, torch.jit.RecursiveScriptModule):
                raise Exception(
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package", id(obj), *obj.__reduce_package__(self))

            return self.serialized_reduces[id(obj)]

        return None
    def persistent_id(obj):
        """Pickle persistent-id hook for module classes and storages.

        * ``nn.Module`` subclasses: the first occurrence returns
          ``('module', cls, source_file, source)`` (source parts are ``None``
          with a warning if they cannot be read); repeats return ``None``.
        * storages: recorded in ``serialized_storages`` under
          ``str(obj._cdata)`` and replaced by
          ``('storage', storage_type, key, location, size, None)``.
        * anything else: ``None``.
        """
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # The old serialization format supported storage views (with an
            # offset). The view-metadata slot is kept for backwards
            # compatibility, but is always None here: the former check
            # ``obj._cdata != obj._cdata`` could never be true, so the view
            # branch was unreachable.
            view_metadata = None

            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None
예제 #8
0
    def persistent_id(obj):
        """Pickle persistent-id hook for ``torch.deploy``.

        * storages (untyped or ``TypedStorage``): appended to
          ``serialized_storages`` with a matching entry in
          ``serialized_dtypes`` (``torch.uint8`` for untyped storages) and
          replaced by ``("storage", index)``.
        * objects defining ``__reduce_deploy__``: reduced once with
          ``importers`` and memoized by ``id(obj)``.
        * anything else: ``None``.
        """
        if torch.is_storage(obj) or isinstance(obj,
                                               torch.storage.TypedStorage):
            # The previous version also bound an unused ``storage`` local in
            # both branches; only the dtype is needed here.
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                dtype = obj.dtype
            else:
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None
예제 #9
0
 def persistent_id(obj):
     """Record tensors/storages by ``_cdata`` and return that id as a string.

     Tensors go to ``serialized_tensors``, storages to
     ``serialized_storages``; any other object returns ``None``.
     """
     if torch.is_tensor(obj):
         registry = serialized_tensors
     elif torch.is_storage(obj):
         registry = serialized_storages
     else:
         return None
     registry[obj._cdata] = obj
     return str(obj._cdata)
def storage_ptr(obj):
    """Return the data pointer of the storage behind ``obj``.

    Accepts a tensor, a storage (``torch.is_storage``), or, when
    ``is_new_api()`` is true, an instance of ``typed_storage_class``.

    Raises:
        AssertionError: for any other type. Raised explicitly instead of via
            ``assert False`` so the check is not stripped under ``python -O``
            (which previously made this silently return ``None``).
    """
    if torch.is_tensor(obj):
        return obj.storage().data_ptr()
    if torch.is_storage(obj):
        return obj.data_ptr()
    if is_new_api() and isinstance(obj, typed_storage_class):
        return obj._storage.data_ptr()
    raise AssertionError(f'type {type(obj)} is not supported')
예제 #11
0
def test1():
    """Demonstrate ``torch.is_tensor``, ``torch.is_storage`` and ``torch.numel``.

    Prints both checks for a plain Python list and for a random 3x2 tensor,
    then prints the tensor's element count. Returns ``None``.
    """
    def report(label, value):
        print(f"whether {label} is tensor :", torch.is_tensor(value))
        print(f"whether {label} is storage :", torch.is_storage(value))

    report("x", [0, 1, 2])  # a list: False / False

    y = torch.randn(3, 2)  # shape=(3, 2) / torch.zeros(3, 2)
    report("y", y)  # a tensor: True / False

    count = torch.numel(y)
    print("the total number of elements in the input Tensor is : {}".format(
        count))
예제 #12
0
파일: _deploy.py 프로젝트: zyx5256/pytorch
 def persistent_id(obj):
     """Replace storages with ``('storage', index)``; other objects -> ``None``.

     Each storage is appended to ``serialized_storages`` and its dtype to
     ``serialized_dtypes``; ``index`` is its position in those lists.
     """
     # FIXME: the docs say that persistent_id should only return a string
     # but torch store returns tuples. This works only in the binary protocol
     # see
     # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
     # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
     if not torch.is_storage(obj):
         return None
     serialized_storages.append(obj)
     serialized_dtypes.append(obj.dtype)
     index = len(serialized_storages) - 1
     return ('storage', index)
예제 #13
0
    def _persistent_id(self, obj):
        """Pickle persistent-id hook backed by ``self.storage_context``.

        * Typed and untyped storages: the underlying untyped storage is
          written to ``.data/<storage_id>.storage`` the first time the storage
          context sees it (copied to CPU if needed), and the object is replaced
          by ``("storage", storage_type, storage_id, location, numel)``.
        * Objects defining ``__reduce_package__``: reduced once and memoized
          by ``id(obj)``; ScriptModules are rejected while
          ``_gate_torchscript_serialization`` is set.
        * Anything else: ``None``.

        Raises:
            RuntimeError: for a storage that is neither ``_TypedStorage`` nor
                ``_UntypedStorage``.
        """
        if torch.is_storage(obj) or isinstance(obj,
                                               torch.storage._TypedStorage):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                untyped_storage = obj._storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj.size()

            elif isinstance(obj, torch._UntypedStorage):
                untyped_storage = obj
                # Fix: these lines used to read ``storage``, which is only
                # bound further down, so every untyped storage raised
                # UnboundLocalError.
                storage_type = normalize_storage_type(type(untyped_storage))
                storage_numel = untyped_storage.nbytes()
            else:
                raise RuntimeError(f"storage type not recognized: {type(obj)}")

            storage: Storage = cast(Storage, untyped_storage)
            location = location_tag(storage)

            # serialize storage if not already written
            storage_present = self.storage_context.has_storage(storage)
            storage_id = self.storage_context.get_or_add_storage(storage)
            if not storage_present:
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(f".data/{storage_id}.storage",
                                           storage.data_ptr(), num_bytes)
            return ("storage", storage_type, storage_id, location,
                    storage_numel)

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                    obj, torch.jit.RecursiveScriptModule):
                raise Exception(
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        return None
예제 #14
0
    def _persistent_id(self, obj):
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            self.serialized_storages[obj_key] = obj

            return ("storage", storage_type, obj_key, location, obj.size())
        if hasattr(obj, "__reduce_package__"):
            return ("reduce_package", *obj.__reduce_package__(self))

        return None
예제 #15
0
    def persistent_id(obj):
        """Replace storages with a ``('storage', ...)`` record; else ``None``.

        The storage is kept in ``serialized_storages`` under
        ``str(obj._cdata)`` and the record is
        ``('storage', storage_type, key, location, size)``.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if not torch.is_storage(obj):
            return None
        storage_type = normalize_storage_type(type(obj))
        key = str(obj._cdata)
        location = location_tag(obj)
        serialized_storages[key] = obj
        return ('storage', storage_type, key, location, obj.size())
예제 #16
0
 def safe_call(*args, **kwargs):
     """Call the wrapped cffi ``function`` through ``torch._C._safe_call``.

     Tensor/storage arguments are cast to a cffi pointer of the C type listed
     for their Python type in ``_torch_to_cffi`` (``void*`` if absent). A
     pointer result whose C type name is in ``_cffi_to_torch`` is wrapped with
     the registered constructor; any other result is returned unchanged.
     """
     converted = []
     for arg in args:
         if torch.is_tensor(arg) or torch.is_storage(arg):
             c_type = _torch_to_cffi.get(type(arg), 'void')
             arg = ffi.cast(c_type + '*', arg._cdata)
         converted.append(arg)
     call_args = (function, ) + tuple(converted)
     result = torch._C._safe_call(*call_args, **kwargs)

     if not isinstance(result, ffi.CData):
         return result
     result_type = ffi.typeof(result)
     if result_type.kind != 'pointer':
         return result
     address = int(ffi.cast('uintptr_t', result))
     cname = result_type.item.cname
     if cname in _cffi_to_torch:
         return _cffi_to_torch[cname](cdata=address)
     return result
예제 #17
0
 def persistent_id(obj):
     """Persistent-id hook for container classes, tensors and storages.

     * ``nn.Container`` subclasses: the first occurrence returns
       ``(cls, source_file, source)``; repeats return ``None``.
     * tensors / storages: stored by ``_cdata`` in ``serialized_tensors`` /
       ``serialized_storages`` and replaced by ``str(obj._cdata)``.
     * anything else: ``None``.
     """
     if isinstance(obj, type) and issubclass(obj, nn.Container):
         if obj in serialized_container_types:
             return None
         # Marked before reading the source, as in the original ordering.
         serialized_container_types[obj] = True
         src_path = inspect.getsourcefile(obj)
         src_text = inspect.getsource(obj)
         return (obj, src_path, src_text)

     if torch.is_tensor(obj):
         registry = serialized_tensors
     elif torch.is_storage(obj):
         registry = serialized_storages
     else:
         return None
     registry[obj._cdata] = obj
     return str(obj._cdata)
예제 #18
0
 def safe_call(*args, **kwargs):
     """Invoke the wrapped cffi ``function`` via ``torch._C._safe_call``.

     Tensor and storage arguments become cffi pointers (C type looked up by
     Python type in ``_torch_to_cffi``, defaulting to ``void``). When the
     result is a pointer whose C type name is registered in
     ``_cffi_to_torch``, it is wrapped back into that torch type.
     """
     def as_c_pointer(value):
         if not (torch.is_tensor(value) or torch.is_storage(value)):
             return value
         type_name = _torch_to_cffi.get(type(value), 'void')
         return ffi.cast(type_name + '*', value._cdata)

     prepared = tuple(map(as_c_pointer, args))
     result = torch._C._safe_call(*((function,) + prepared), **kwargs)
     if isinstance(result, ffi.CData):
         info = ffi.typeof(result)
         if info.kind == 'pointer':
             raw_address = int(ffi.cast('uintptr_t', result))
             c_name = info.item.cname
             if c_name in _cffi_to_torch:
                 return _cffi_to_torch[c_name](cdata=raw_address)
     return result
예제 #19
0
파일: common.py 프로젝트: Jsmilemsj/pytorch
def to_gpu(obj, type_map=None):
    """Recursively convert ``obj`` to GPU types.

    Args:
        obj: a tensor, storage, list, tuple, or any other object (which is
            deep-copied unchanged).
        type_map: optional mapping from a tensor type string (as returned by
            ``Tensor.type()``) to the target type; types missing from it fall
            back to ``get_gpu_type``. ``None`` (the default) means no
            overrides.

    Returns:
        The converted object. Lists and tuples are rebuilt with converted
        elements. Tensors keep their ``requires_grad`` flag.
    """
    if type_map is None:
        # Fresh dict per call instead of a shared mutable default argument.
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        t = type_map.get(obj.type(), get_gpu_type(obj.type()))
        with torch.no_grad():
            res = obj.clone().type(t)
            res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        # NOTE(review): ``obj.new()`` appears to allocate a storage of the same
        # type as ``obj``, so this may copy rather than move to the GPU --
        # confirm intent.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
예제 #20
0
def to_gpu(obj, type_map=None):
    """Recursively convert ``obj`` to GPU types.

    Args:
        obj: a tensor, storage, ``Variable``, list, tuple, or any other
            object (which is deep-copied unchanged).
        type_map: optional mapping from a tensor class to the target class;
            classes missing from it fall back to ``get_gpu_type``. ``None``
            (the default) means no overrides.

    Returns:
        The converted object. Lists and tuples are rebuilt with converted
        elements. ``Variable`` copies keep ``requires_grad``.
    """
    if type_map is None:
        # Fresh dict per call instead of a shared mutable default argument.
        type_map = {}
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        # NOTE(review): ``obj.new()`` appears to allocate a storage of the same
        # type as ``obj``, so this may copy rather than move to the GPU --
        # confirm intent.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
예제 #21
0
def to_gpu(obj, type_map=None):
    """Recursively convert ``obj`` to GPU types.

    Args:
        obj: a tensor, storage, list, tuple, or any other object (which is
            deep-copied unchanged).
        type_map: optional mapping from a tensor type string (as returned by
            ``Tensor.type()``) to the target type; types missing from it fall
            back to ``get_gpu_type``. ``None`` (the default) means no
            overrides.

    Returns:
        The converted object. Lists and tuples are rebuilt with converted
        elements. Tensors keep their ``requires_grad`` flag.
    """
    if type_map is None:
        # Fresh dict per call instead of a shared mutable default argument.
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        t = type_map.get(obj.type(), get_gpu_type(obj.type()))
        with torch.no_grad():
            res = obj.clone().type(t)
            res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        # NOTE(review): ``obj.new()`` appears to allocate a storage of the same
        # type as ``obj``, so this may copy rather than move to the GPU --
        # confirm intent.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
예제 #22
0
def to_gpu(obj, type_map=None):
    """Recursively convert ``obj`` to GPU types.

    Args:
        obj: a tensor, storage, ``Variable``, list, tuple, or any other
            object (which is deep-copied unchanged).
        type_map: optional mapping from a tensor class to the target class;
            classes missing from it fall back to ``get_gpu_type``. ``None``
            (the default) means no overrides.

    Returns:
        The converted object. Lists and tuples are rebuilt with converted
        elements. ``Variable`` copies keep ``requires_grad``.
    """
    if type_map is None:
        # Fresh dict per call instead of a shared mutable default argument.
        type_map = {}
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        # NOTE(review): ``obj.new()`` appears to allocate a storage of the same
        # type as ``obj``, so this may copy rather than move to the GPU --
        # confirm intent.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
예제 #23
0
    def persistent_id(obj):
        """Pickle persistent-id hook for (typed or untyped) storages.

        A storage is replaced by ``('storage', storage_type, key, location,
        numel)`` and stashed in ``serialized_storages[key]``, where ``key`` is
        a sequential string id assigned once per ``_cdata`` via ``id_map``.
        Non-storage objects return ``None``.

        Raises:
            RuntimeError: if two saved storages share an allocated data
                pointer but have different dtypes.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        is_typed = isinstance(obj, torch.storage._TypedStorage)
        if not is_typed and not torch.is_storage(obj):
            return None

        if is_typed:
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            backing = obj._storage
            dtype = obj.dtype
            storage_type = getattr(torch, obj.pickle_storage_type())
            numel = obj.size()
        else:
            backing = obj
            dtype = backing.dtype
            storage_type = normalize_storage_type(type(obj))
            numel = backing.nbytes()

        backing = cast(Storage, backing)

        # When the storage is allocated (non-zero data pointer), every saved
        # storage viewing the same data must share one dtype; the first dtype
        # seen for a pointer is remembered in ``storage_dtypes``.
        data_ptr = backing.data_ptr()
        if data_ptr != 0 and storage_dtypes.setdefault(data_ptr, dtype) != dtype:
            raise RuntimeError(
                'Cannot save multiple tensors or storages that '
                'view the same data as different types')

        key = id_map.setdefault(backing._cdata, str(len(id_map)))
        location = location_tag(backing)
        serialized_storages[key] = backing
        return ('storage', storage_type, key, location, numel)
예제 #24
0
 def safe_call(*args, **kwargs):
     """Call the wrapped cffi ``function`` through ``torch._C._safe_call``.

     Tensor/storage arguments are cast to cffi pointers whose C type is
     looked up by ``arg.type()`` in ``_torch_to_cffi`` (``void`` if absent).
     A pointer result whose C type name is in ``_cffi_to_torch`` is wrapped
     by evaluating the registered name and calling it with ``cdata=``.
     """
     converted = []
     for arg in args:
         if isinstance(arg, torch.Tensor) or torch.is_storage(arg):
             pointer_type = _torch_to_cffi.get(arg.type(), 'void') + '*'
             arg = ffi.cast(pointer_type, arg._cdata)
         converted.append(arg)
     result = torch._C._safe_call(*((function, ) + tuple(converted)), **kwargs)

     if isinstance(result, ffi.CData):
         result_type = ffi.typeof(result)
         if result_type.kind == 'pointer':
             address = int(ffi.cast('uintptr_t', result))
             cname = result_type.item.cname
             if cname in _cffi_to_torch:
                 # TODO: Maybe there is a less janky way to eval
                 # off of this
                 # NOTE(review): eval() is only safe while _cffi_to_torch
                 # holds trusted, module-defined names -- never external data.
                 return eval(_cffi_to_torch[cname])(cdata=address)
     return result
예제 #25
0
 def safe_call(*args, **kwargs):
     """Invoke the wrapped cffi ``function`` via ``torch._C._safe_call``.

     Tensor and storage arguments become cffi pointers (C type found by
     ``arg.type()`` in ``_torch_to_cffi``, default ``void``). A pointer result
     whose C type name is registered in ``_cffi_to_torch`` is wrapped by
     evaluating that registered name and calling it with ``cdata=``.
     """
     def as_c_pointer(value):
         if not (isinstance(value, torch.Tensor) or torch.is_storage(value)):
             return value
         return ffi.cast(_torch_to_cffi.get(value.type(), 'void') + '*',
                         value._cdata)

     prepared = tuple(map(as_c_pointer, args))
     result = torch._C._safe_call(*((function,) + prepared), **kwargs)
     if not isinstance(result, ffi.CData):
         return result
     info = ffi.typeof(result)
     if info.kind != 'pointer':
         return result
     raw_address = int(ffi.cast('uintptr_t', result))
     c_name = info.item.cname
     if c_name not in _cffi_to_torch:
         return result
     # TODO: Maybe there is a less janky way to eval
     # off of this
     # NOTE(review): eval() is only safe while _cffi_to_torch holds trusted,
     # module-defined names -- never external data.
     return eval(_cffi_to_torch[c_name])(cdata=raw_address)
예제 #26
0
    def persistent_id(obj: Any) -> Optional[Tuple]:
        """Pickle persistent-id hook for module classes and storages.

        * ``nn.Module`` subclasses: the first occurrence returns
          ``('module', cls, source_file, source)`` (source parts are ``None``
          with a warning if they cannot be read); repeats return ``None``.
        * storages: recorded in ``serialized_storages`` under
          ``str(obj._cdata)`` and replaced by
          ``('storage', storage_type, key, location, size, None)``.
        * anything else: ``None``.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj):
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # The old serialization format supported storage views (with an
            # offset). The view-metadata slot is kept for backwards
            # compatibility, but is always None here: the former check
            # ``obj._cdata != obj._cdata`` could never be true.
            view_metadata: Optional[Tuple[str, int, int]] = None

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size(),
                    view_metadata)
        return None
예제 #27
0
def tensor():
    """Print a tour of basic ``torch`` tensor creation/manipulation calls.

    Each result is printed in turn; random constructors produce different
    output on every run. Returns ``None``.
    """
    print('Hello, PyTorch.tensor')

    eye = torch.eye(3, 4)
    for fact in (eye, torch.is_tensor(eye), torch.is_storage(eye),
                 torch.numel(eye)):
        print(fact)

    values = [12, 23, 34, 45, 56, 67, 78]
    print(torch.from_numpy(np.array(values)))

    # Built lazily so each result is printed before the next one is created.
    builders = (
        lambda: torch.linspace(2, 10, steps=25),
        lambda: torch.logspace(-10, 10, steps=15),
        lambda: torch.rand(10),
        lambda: torch.rand(4, 5),
        lambda: torch.randn(10),
        lambda: torch.randperm(10),
        lambda: torch.arange(10, 40, 2),
    )
    for build in builders:
        print(build())

    scores = torch.randn(4, 5)
    print(scores)
    print(torch.argmin(scores))
    print(torch.argmin(scores, dim=1))
    print(torch.zeros(4, 5))

    block = torch.randn(4, 5)
    print(block)
    print(torch.cat((block, block)))
    print(torch.cat((block, block, block), 1))

    square = torch.randn(4, 4)
    print(square)
    print(torch.chunk(square, 2))
    print(torch.chunk(square, 2, 1))

    source = torch.tensor([[11, 12], [23, 24]])
    index = torch.LongTensor([[0, 0], [1, 0]])
    print(torch.gather(source, 1, index))

    matrix = torch.randn(4, 5)
    print(matrix)
    print(matrix.t())
    print(matrix.transpose(1, 0))
def check_regular_serialization(loaded_list, check_list):
    """Assert that ``loaded_list`` faithfully reproduces ``check_list``.

    For each position, checks that the loaded value has the same type and a
    matching device and contents. Supported types are tensors, storages,
    ``nn.Module`` (compared parameter by parameter) and, when
    ``is_new_api()`` holds, ``typed_storage_class``. Then, for values with a
    data pointer, any storage sharing between two ``check_list`` entries must
    also hold between the corresponding ``loaded_list`` entries.

    Raises:
        AssertionError: on the first mismatch or for an unsupported type.
    """
    for idx0, check_val0 in enumerate(check_list):
        loaded_val0 = loaded_list[idx0]

        # Check that loaded values are what they should be
        assert type(check_val0) == type(loaded_val0), (
            f'type should be {type(check_val0)} but got {type(loaded_val0)}')

        if torch.is_tensor(check_val0):
            assert check_val0.device == loaded_val0.device
            assert check_val0.eq(loaded_val0).all()

        elif torch.is_storage(check_val0):
            assert check_val0.device == loaded_val0.device
            assert check_val0.tolist() == loaded_val0.tolist()

        elif issubclass(type(check_val0), torch.nn.Module):
            # Materialized as a list: a bare zip() is a one-shot iterator, so
            # the second all() used to run over nothing and pass vacuously.
            param_pairs = list(zip(check_val0.parameters(),
                                   loaded_val0.parameters()))
            assert all(p0.device == p1.device for p0, p1 in param_pairs)
            assert all(p0.eq(p1).all() for p0, p1 in param_pairs)

        elif is_new_api() and isinstance(check_val0, typed_storage_class):
            assert check_val0._storage.device == loaded_val0._storage.device
            assert check_val0.dtype == loaded_val0.dtype
            assert check_val0._storage.tolist() == loaded_val0._storage.tolist(
            )

        else:
            # Explicit raise so the failure is not stripped under ``python -O``.
            raise AssertionError(f'type {type(check_val0)} not supported')

        if not has_data_ptr(check_val0):
            continue

        # Check that storage sharing is preserved
        check_ptr0 = storage_ptr(check_val0)
        for idx1 in range(idx0 + 1, len(check_list)):
            check_val1 = check_list[idx1]
            # Was ``loaded_list[idx0]``, which compared loaded_val0 with itself
            # and therefore could never detect lost sharing.
            loaded_val1 = loaded_list[idx1]

            if check_ptr0 == storage_ptr(check_val1):
                assert storage_ptr(loaded_val0) == storage_ptr(loaded_val1)
예제 #29
0
def to_gpu(obj, type_map=None):
    """Return a converted copy of ``obj``, recursing into lists and tuples.

    Args:
        obj: A tensor, storage, ``Variable``, or a (possibly nested)
            ``list``/``tuple`` of such values. Anything else is deep-copied.
        type_map: Optional mapping from a source tensor type to the target
            type. Types that are not in the map fall back to
            ``get_gpu_type(type(...))``. ``get_gpu_type`` is defined
            elsewhere and presumably yields the CUDA counterpart; confirm
            against its definition. Defaults to an empty mapping.

    Returns:
        The converted copy. Lists and tuples keep their container type.
        ``Variable`` results are detached but keep ``requires_grad``.

    Raises:
        AssertionError: if a ``Variable`` reaching its branch is not a leaf.
    """
    # A fresh dict per call. A ``{}`` default would be one shared object
    # reused across every call.
    if type_map is None:
        type_map = {}
    if torch.is_tensor(obj):
        t = type_map.get(type(obj), get_gpu_type(type(obj)))
        # Workaround since torch.HalfTensor doesn't support clone()
        if type(obj) == torch.HalfTensor:
            return obj.new().resize_(obj.size()).copy_(obj).type(t)
        return obj.clone().type(t)
    elif torch.is_storage(obj):
        # NOTE(review): this copies into a new storage of the *same* type.
        # Unlike tensors, it is not converted via type_map/get_gpu_type.
        # Confirm this is intended.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, Variable):
        assert obj.is_leaf
        t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
        # Convert, drop the autograd history, then restore requires_grad.
        o = obj.type(t).detach()
        o.requires_grad = obj.requires_grad
        return o
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
예제 #30
0
 def persistent_id(obj):
     """Return a persistent ID for module classes, tensors and storages.

     ``nn.Module`` subclasses (the class objects themselves) are recorded in
     ``serialized_container_types`` the first time they are seen. They are
     then identified by ``(cls, source_file, source)``, where the source
     fields stay ``None`` if they cannot be retrieved. Repeat occurrences
     return ``None`` so they are pickled normally.

     Tensors go into ``serialized_tensors`` and storages into
     ``serialized_storages``, keyed by ``obj._cdata``. Each is referenced by
     ``str(obj._cdata)``. Anything else returns ``None``.
     """
     if isinstance(obj, type) and issubclass(obj, nn.Module):
         if obj in serialized_container_types:
             return None
         serialized_container_types[obj] = True
         source_file, source = None, None
         try:
             source_file = inspect.getsourcefile(obj)
             source = inspect.getsource(obj)
         except (TypeError, IOError):
             # Source retrieval is best-effort; failures only emit a warning.
             warnings.warn(
                 "Couldn't retrieve source code for container of type "
                 f"{obj.__name__}. It won't be checked for correctness "
                 "upon loading.")
         return (obj, source_file, source)

     # Pick the registry this object belongs to, if any.
     if torch.is_tensor(obj):
         registry = serialized_tensors
     elif torch.is_storage(obj):
         registry = serialized_storages
     else:
         return None
     registry[obj._cdata] = obj
     return str(obj._cdata)
예제 #31
0
파일: _deploy.py 프로젝트: xsacha/pytorch
    def persistent_id(obj):
        """Return a persistent ID for storages and ``__reduce_deploy__`` objects.

        Storages are appended to ``serialized_storages``, and their dtype to
        ``serialized_dtypes``. They are referenced as
        ``('storage', index_in_serialized_storages)``.

        Objects that define ``__reduce_deploy__`` are reduced once per
        ``id(obj)`` with ``obj.__reduce_deploy__(importers)``. The cached
        ``('reduce_deploy', id(obj), *reduction)`` tuple is returned on every
        occurrence. Everything else returns ``None``.
        """
        # FIXME: pickle documents persistent IDs as strings, but tuples are
        # returned here. That only works with the binary protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if torch.is_storage(obj):
            index = len(serialized_storages)
            serialized_storages.append(obj)
            serialized_dtypes.append(obj.dtype)
            return ('storage', index)

        if not hasattr(obj, "__reduce_deploy__"):
            return None

        obj_id = id(obj)
        if _serialized_reduces.get(obj_id) is None:
            reduction = obj.__reduce_deploy__(importers)
            _serialized_reduces[obj_id] = ("reduce_deploy", obj_id, *reduction)
        return _serialized_reduces[obj_id]
예제 #32
0
    def persistent_id(obj):
        """Return a persistent ID for module classes and storages.

        ``nn.Module`` subclasses are recorded in ``serialized_container_types``
        on first sight and identified by
        ``('module', cls, source_file, source)``. The source fields stay
        ``None`` when they cannot be retrieved. Repeat occurrences return
        ``None``.

        A storage is resolved to its root storage via ``_root_storage()``.
        The root is recorded in ``serialized_storages`` under
        ``str(root._cdata)``, and the result is
        ``('storage', storage_type, root_key, location, root.size(),
        view_metadata)``. ``view_metadata`` is
        ``(str(obj._cdata), offset, obj.size())`` when ``obj`` differs from
        its root, else ``None``. Anything else returns ``None``.
        """
        # FIXME: pickle documents persistent IDs as strings, but tuples are
        # returned here. That only works with the binary protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        is_module_class = isinstance(obj, type) and issubclass(obj, nn.Module)
        if not is_module_class and not torch.is_storage(obj):
            return None

        if is_module_class:
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file, source = None, None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn(
                    "Couldn't retrieve source code for container of type "
                    f"{obj.__name__}. It won't be checked for correctness "
                    "upon loading.")
            return ('module', obj, source_file, source)

        # From here on obj is a storage.
        kind = normalize_storage_type(type(obj))
        root_storage, root_offset = obj._root_storage()
        root_key = str(root_storage._cdata)
        device_tag = location_tag(obj)
        serialized_storages[root_key] = root_storage

        view_metadata = None
        if obj._cdata != root_storage._cdata:
            # obj is a view into root_storage: remember where it lives.
            view_metadata = (str(obj._cdata), root_offset, obj.size())

        return ('storage',
                kind,
                root_key,
                device_tag,
                root_storage.size(),
                view_metadata)
예제 #33
0
    def persistent_id(obj: Any) -> Optional[Tuple]:
        """Map ``obj`` to a persistent ID for the legacy pickle-based format.

        Handled cases:

        * ``nn.Module`` subclasses (the class object itself). The first time a
          class is seen it is recorded in ``serialized_container_types``, and
          ``('module', cls, source_file, source)`` is returned. The source
          fields stay ``None`` if they cannot be retrieved. Later occurrences
          return ``None``.
        * ``torch.storage._TypedStorage`` instances and untyped storages. The
          backing storage is recorded in ``serialized_storages`` as
          ``(storage, dtype)``, and the first occurrence of a key wins. The
          return value is ``('storage', storage_type, storage_key, location,
          storage_numel, view_metadata)``.
        * Anything else returns ``None``.

        Raises:
            RuntimeError: if storages sharing the same allocated data pointer
                are saved with different dtypes.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj,
                      torch.storage._TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage._TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj.size()

            else:
                storage = obj
                storage_dtype = storage.dtype
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = cast(Storage, storage).nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated (data pointer 0), don't perform this check.
            data_ptr = storage.data_ptr()
            if data_ptr != 0:
                if data_ptr in storage_dtypes:
                    if storage_dtype != storage_dtypes[data_ptr]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[data_ptr] = storage_dtype

            storage = cast(Storage, storage)
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`.  The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case.  We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a _UntypedStorage, and then the
            # FloatTensor will be loaded and the _UntypedStorage will be
            # assigned to it. Since the storage dtype does not match the tensor
            # dtype, this will cause an error.  If we reverse the list, like
            # `[tensor, storage]`, then we will save the `tensor.storage()` as
            # a faked `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a _UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data.  We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)

            # The old serialization format supported storage views and stored
            # ``(view_key, offset, nbytes)`` in this slot, with offset always
            # 0. The storage recorded here is the saved storage itself, never
            # a view of another one. The slot is therefore always ``None``, but
            # it is kept so the tuple layout stays backwards compatible.
            view_metadata: Optional[Tuple[str, int, int]] = None

            return ('storage', storage_type, storage_key, location,
                    storage_numel, view_metadata)
        return None
예제 #34
0
def basic_operation():
    """Print a tour of basic tensor creation, arithmetic and inspection APIs.

    All results go to stdout. The function also changes global torch state.
    It sets the default floating-point dtype to ``torch.float32`` and the
    default tensor type to ``torch.FloatTensor``.
    """
    # REF [site] >>
    #   https://pytorch.org/docs/stable/tensors.html
    #   https://pytorch.org/docs/stable/tensor_attributes.html

    # --- Creation -----------------------------------------------------------
    # torch.empty does not initialize memory, so its values are arbitrary.
    x = torch.empty(5, 3)
    print('x =', x)
    print('x.shape = {}, x.dtype = {}.'.format(x.shape, x.dtype))

    for x in (torch.rand(2, 3), torch.randn(2, 3), torch.randn(2, 3),
              torch.randperm(5)):
        print('x =', x)

    x = torch.FloatTensor(10, 12, 3, 3)
    shape = x.size()
    print('x =', shape)
    print('x =', shape[:])

    for y in (torch.zeros(2, 3), torch.ones(2, 3),
              torch.arange(0, 3, step=0.5)):
        print('y =', y)

    # An explicit device can also be given, e.g. device='cuda:1'.
    for x in (torch.tensor(1, dtype=torch.int32), torch.tensor([5.5, 3])):
        print('x =', x)

    x = x.new_ones(5, 3, dtype=torch.double)  # new_* methods take sizes.
    print('x =', x)
    x = torch.randn_like(x, dtype=torch.float)  # Override the dtype.
    print('x =', x)

    # --- Arithmetic ---------------------------------------------------------
    y = torch.rand(5, 3)
    print('x + y =', x + y)
    print('x + y =', torch.add(x, y))

    summed = torch.empty(5, 3)
    torch.add(x, y, out=summed)
    print('x + y =', summed)

    # Methods with a trailing underscore (x.copy_(y), x.t_(), ...) mutate
    # their tensor in place.
    y.add_(x)
    print('y =', y)

    # NumPy-style indexing works on tensors.
    print(x[:, 1])

    # .item() turns a one-element tensor into a Python number.
    x = torch.randn(1)
    print('x =', x)
    print('x.item() =', x.item())

    # --- Tensor inspection --------------------------------------------------
    x = torch.randn(2, 2)
    # Each query is evaluated lazily, right before its line is printed.
    inspections = (
        ('x.is_cuda =', lambda t: t.is_cuda),
        ('x.is_complex() =', lambda t: t.is_complex()),
        ('x.is_contiguous() =', lambda t: t.is_contiguous()),
        ('x.is_distributed() =', lambda t: t.is_distributed()),
        ('x.is_floating_point() =', lambda t: t.is_floating_point()),
        ('x.is_pinned() =', lambda t: t.is_pinned()),
        ('x.is_quantized =', lambda t: t.is_quantized),
        ('x.is_shared() =', lambda t: t.is_shared()),
        ('x.is_signed() =', lambda t: t.is_signed()),
        ('x.is_sparse =', lambda t: t.is_sparse),
        ('x.contiguous() =', lambda t: t.contiguous()),
        ('x.storage() =', lambda t: t.storage()),
    )
    for label, query in inspections:
        print(label, query(x))

    x = torch.randn(2, 2)
    predicates = (
        ('torch.is_tensor(x) =', torch.is_tensor),
        ('torch.is_storage(x) =', torch.is_storage),
        ('torch.is_complex(x) =', torch.is_complex),
        ('torch.is_floating_point(x) =', torch.is_floating_point),
    )
    for label, predicate in predicates:
        print(label, predicate(x))

    # --- Global defaults ----------------------------------------------------
    # The default floating-point dtype drives type inference in torch.tensor().
    torch.set_default_dtype(torch.float32)
    print('torch.get_default_dtype() =', torch.get_default_dtype())
    # NOTE(review): recent PyTorch docs mark set_default_tensor_type as
    # deprecated (in favor of set_default_dtype / set_default_device).
    # Confirm for the targeted version.
    torch.set_default_tensor_type(torch.FloatTensor)

    # --- Views --------------------------------------------------------------
    # REF [site] >> https://pytorch.org/docs/stable/tensor_view.html
    # A view shares its data with the base tensor, so reshaping and slicing
    # avoid copies.
    x = torch.randn(4, 4)
    y = x.view(16)
    z = x.view(-1, 8)  # -1 is inferred from the remaining dimensions.
    print('x.size() = {}, y.size() = {}, z.size() = {}.'.format(
        x.size(), y.size(), z.size()))

    t = torch.rand(4, 4)
    b = t.view(2, 8)
    shares_storage = t.storage().data_ptr() == b.storage().data_ptr()
    print('t.storage().data_ptr() == b.storage().data_ptr()?', shares_storage)
def has_data_ptr(obj):
    """Return whether ``obj`` is a tensor, a storage, or a typed storage.

    Tensors and storages always qualify. Instances of ``typed_storage_class``
    qualify only when ``is_new_api()`` is truthy.
    """
    if torch.is_tensor(obj) or torch.is_storage(obj):
        return True
    return is_new_api() and isinstance(obj, typed_storage_class)