def get_jit_class_def(cls, self_name):
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__)

    def is_classmethod(fn):
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    methods = [get_jit_def(obj,
                           name,
                           self_name=self_name,
                           is_classmethod=is_classmethod(obj))
               for (name, obj) in methods]

    properties = get_class_properties(cls, self_name)

    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack())
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(
        dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno,
                              leading_whitespace_len, False)
    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, methods, properties, self_name, assigns)
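
# A minimal, self-contained sketch (not part of this module) of the dedent
# bookkeeping in `get_jit_class_def` above: `get_source_lines_and_file` may
# return source that is indented because the class was defined in a nested
# scope, and the number of columns stripped by `dedent` must be remembered so
# that error messages can be mapped back to the original file.
# `_leading_whitespace_demo` is a made-up name for illustration.
def _leading_whitespace_demo():
    from textwrap import dedent  # already imported at module level here

    source = "    class Inner(object):\n        pass\n"  # as found in the file
    dedent_src = dedent(source)  # "class Inner(object):\n    pass\n"
    # Comparing the first line before and after dedenting recovers the
    # stripped column offset (4 here), exactly as computed above.
    return len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])

assert _leading_whitespace_demo() == 4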
def get_signature(fn, rcb, loc, is_method):
    signature = try_real_annotations(fn, loc)
    if signature is not None and is_method:
        # If this is a method, then the signature will include a type for
        # `self`, but type comments do not contain a `self`. So strip it
        # away here so everything is consistent (`inspect.ismethod` does
        # not work here since `fn` is unbound at this point)
        param_types, return_type = signature
        param_types = param_types[1:]
        signature = (param_types, return_type)

    if signature is None:
        type_line, source = None, None
        try:
            source = dedent(''.join(get_source_lines_and_file(fn)[0]))
            type_line = get_type_line(source)
        except TypeError:
            pass
        # This might happen either because we failed to get the source of fn,
        # or because it didn't have any annotations.
        if type_line is not None:
            signature = parse_type_line(type_line, rcb, loc)

    return signature
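
# A hedged illustration (not part of this module) of the fallback path in
# `get_signature`: when `fn` carries no real annotations, the signature is
# recovered from a mypy-style `# type:` comment in its source. `scale` is a
# made-up example function; for it, `get_type_line` would find the `# type:`
# line and `parse_type_line` would resolve `Tensor` and `float` through the
# resolution callback `rcb`. Note the `self`-stripping above only applies to
# real annotations, since type comments never include `self`.
def scale(x, factor):
    # type: (Tensor, float) -> Tensor
    return x * factor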
def persistent_id(obj: Any) -> Optional[Tuple]:
    # FIXME: the docs say that persistent_id should only return a string,
    # but this function returns tuples. That only works with the binary
    # protocol; see
    # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
    # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
    if isinstance(obj, type) and issubclass(obj, nn.Module):
        if obj in serialized_container_types:
            return None
        serialized_container_types[obj] = True
        source_file = source = None
        try:
            source_lines, _, source_file = get_source_lines_and_file(obj)
            source = ''.join(source_lines)
        except Exception:
            # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + obj.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
        return ('module', obj, source_file, source)

    if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
        if isinstance(obj, torch.storage._TypedStorage):
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            storage = obj._storage
            storage_dtype = obj.dtype
            storage_type_str = obj.pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            dtype = obj.dtype
            storage_numel = obj.size()
        else:
            storage = obj
            storage_dtype = storage.dtype
            storage_type = normalize_storage_type(type(obj))
            dtype = torch.uint8
            storage_numel = cast(Storage, storage).nbytes()

        # If storage is allocated, ensure that any other saved storages
        # pointing to the same data all have the same dtype. If storage is
        # not allocated, don't perform this check
        if storage.data_ptr() != 0:
            if storage.data_ptr() in storage_dtypes:
                if storage_dtype != storage_dtypes[storage.data_ptr()]:
                    raise RuntimeError(
                        'Cannot save multiple tensors or storages that '
                        'view the same data as different types')
            else:
                storage_dtypes[storage.data_ptr()] = storage_dtype

        view_metadata: Optional[Tuple[str, int, int]]
        storage = cast(Storage, storage)

        # Offset is always 0, but we keep it for backwards compatibility
        # with the old serialization format (which supported storage views)
        offset = 0
        storage_key = str(storage._cdata)
        location = location_tag(storage)

        # TODO: There's an issue here with FC. It might be impossible to
        # solve, but it's worth noting. Imagine we save a list `[storage,
        # tensor]`, where `tensor.storage()` is the same as `storage`, and
        # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
        # torch.float`. The storage will be serialized with element size
        # of 1, since we're choosing to serialize the first occurrence of
        # a duplicate storage. Since this legacy serialization format saves
        # the numel of the storage, rather than nbytes directly, we'll be
        # effectively saving nbytes in this case. We'll be able to load it
        # and the tensor back up with no problems in _this_ and future
        # versions of pytorch, but in older versions, here's the problem:
        # the storage will be loaded up as an _UntypedStorage, and then the
        # FloatTensor will be loaded and the _UntypedStorage will be
        # assigned to it. Since the storage dtype does not match the tensor
        # dtype, this will cause an error. If we reverse the list, like
        # `[tensor, storage]`, then we will save the `tensor.storage()` as
        # a faked `FloatStorage`, and the saved size will be the correct
        # dtype-specific numel count that old versions expect. `tensor`
        # will be able to load up properly in old versions, pointing to
        # a FloatStorage. However, `storage` is still being translated to
        # an _UntypedStorage, and it will try to resolve to the same
        # FloatStorage that `tensor` contains. This will also cause an
        # error. It doesn't seem like there's any way around this.
        # Probably, we just cannot maintain FC for the legacy format if the
        # saved list contains both a tensor and a storage that point to the
        # same data. We should still be able to maintain FC for lists of
        # just tensors, as long as all views share the same dtype as the
        # tensor they are viewing.

        if storage_key not in serialized_storages:
            serialized_storages[storage_key] = (storage, dtype)
        # This comparison is always False (storage views are no longer
        # supported), so `view_metadata` is always None; the branch is kept
        # only for compatibility with the old serialization format.
        is_view = storage._cdata != storage._cdata
        if is_view:
            view_metadata = (str(storage._cdata), offset, storage.nbytes())
        else:
            view_metadata = None

        res = ('storage',
               storage_type,
               storage_key,
               location,
               storage_numel,
               view_metadata)
        return res

    return None
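
# A minimal sketch, under made-up names, of how a `persistent_id` hook like
# the one above plugs into pickle: the pickler records the returned tuple in
# the stream instead of pickling the object body, and a matching
# `persistent_load` on the reading side reverses it. Plain `bytes` objects
# stand in for torch storages here; everything else pickles normally.
import io
import pickle

class _DemoPickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, bytes):            # stand-in for "obj is a storage"
            return ('blob', obj.decode())     # tuple ids need a binary protocol
        return None                           # pickle everything else normally

class _DemoUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        kind, payload = pid
        assert kind == 'blob'
        return payload.encode()               # rebuild the object from its id

buf = io.BytesIO()
_DemoPickler(buf, protocol=2).dump({'data': b'storage-bytes'})
buf.seek(0)
assert _DemoUnpickler(buf).load() == {'data': b'storage-bytes'}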