def safe_unpickle_fh(fh, fix_imports=True, encoding="ASCII", errors="strict"):
    '''Safely unpickle untrusted data from *fh*

    The stream is scanned once to vet every opcode, rewound to the
    position it had on entry, and only then handed to `SafeUnpickler`.

    *fh* must be seekable, otherwise `TypeError` is raised. An opcode
    that is not whitelisted (or belongs to a protocol newer than 2) and
    a malformed stream both raise `pickle.UnpicklingError`.
    '''

    if not fh.seekable():
        raise TypeError('*fh* must be seekable')

    start_offset = fh.tell()

    # Pass 1: refuse anything that uses an opcode we do not know to be safe.
    try:
        for (op, _arg, _pos) in pickletools.genops(fh):
            if op.proto <= 2 and op.name in SAFE_UNPICKLE_OPCODES:
                continue
            raise pickle.UnpicklingError('opcode %s is unsafe' % op.name)
    except (ValueError, EOFError):
        raise pickle.UnpicklingError('corrupted data')

    fh.seek(start_offset)

    # Pass 2: load through the custom unpickler, which restricts the
    # globals that may be referenced. With the opcodes accepted above
    # there is no way to trigger attribute access, so "brachiating" from
    # a whitelisted object to __builtins__ is not possible.
    unpickler = SafeUnpickler(fh, fix_imports=fix_imports,
                              encoding=encoding, errors=errors)
    return unpickler.load()
def _correctly_load_bbox(bbox, path, is_zipped=False):
    """ Helper method for loading old version of pickled BBox object

    The pickle opcode stream is scanned for the ``sentinelhub.constants CRS``
    global; the argument of the opcode two positions after it is used as the
    CRS value for the rebuilt bounding box.

    :param bbox: BBox object which was incorrectly loaded with pickle
    :type bbox: sentinelhub.BBox
    :param path: Path to file where BBox object is stored
    :type path: str
    :param is_zipped: `True` if file is zipped and `False` otherwise
    :type is_zipped: bool
    :return: Correctly loaded BBox object
    :rtype: sentinelhub.BBox
    :raises ValueError: if no CRS value could be located in the pickle
    """
    warnings.warn(
        "Bounding box of your EOPatch is saved in old format which in the future won't be supported "
        "anymore. Please save bounding box again, you can overwrite the existing one",
        DeprecationWarning,
        stacklevel=4)

    opener = gzip.open(path) if is_zipped else open(path, 'rb')
    with opener as pickle_file:
        # Index of the most recent CRS marker opcode, if any was seen.
        marker_index = None
        for index, (_, arg, _) in enumerate(pickletools.genops(pickle_file)):
            if arg == 'sentinelhub.constants CRS':
                marker_index = index
            if marker_index is not None and index - marker_index == 2:
                return sentinelhub.BBox(tuple(bbox), sentinelhub.CRS(arg))

    raise ValueError(
        'Failed to correctly load BBox object, try downgrading sentinelhub package to <=2.4.7'
    )
def import_name(obj):
    """Return the full import name for a Python object

    Note that this does not always exist (e.g. for dynamically
    generated functions).  This function does it's best, using Pickle
    for the heavy lifting.  For example:

    >>> import_name(import_name)
    'rss2email.util import_name'

    Note the space between the module (``rss2email.util``) and the
    function within the module (``import_name``).

    Some objects can't be pickled:

    >>> import_name(lambda x: 'muahaha')
    Traceback (most recent call last):
      ...
    _pickle.PicklingError: Can't pickle <class 'function'>: attribute lookup builtins.function failed

    Some objects don't have a global scope:

    >>> import_name('abc')
    Traceback (most recent call last):
      ...
    ValueError: abc
    """
    # Protocol 3 is pinned on purpose: from protocol 4 on (the default
    # since Python 3.8) globals are written as STACK_GLOBAL, which carries
    # no inline "module name" argument, so no GLOBAL opcode would be found.
    data = _pickle.dumps(obj, protocol=3)
    for opcode, arg, _pos in _pickletools.genops(data):
        if opcode.name == 'GLOBAL':
            return arg
    raise ValueError(obj)
def optimize_pickle(p):
    """Optimize a pickle string by removing unused PUT opcodes.

    *p* is the pickled data (``bytes`` on Python 3, ``str`` on Python 2).
    The result has the same type as *p* and unpickles to an equal object;
    only PUT opcodes whose memo id is never read by a GET are dropped.
    """
    gets = set()      # set of args used by a GET opcode
    puts = []         # (arg, startpos, stoppos) for the PUT opcodes
    prev_put = None   # (arg, startpos) if the previous opcode was a PUT
    for opcode, arg, pos in genops(p):
        if prev_put is not None:
            # The current opcode's position is where the previous PUT ends.
            puts.append(prev_put + (pos,))
            prev_put = None
        if 'PUT' in opcode.name:
            prev_put = (arg, pos)
        elif 'GET' in opcode.name:
            gets.add(arg)

    # Copy the pickle string except for PUTS without a corresponding GET
    chunks = []
    i = 0
    for arg, start, stop in puts:
        chunks.append(p[i:stop if arg in gets else start])
        i = stop
    chunks.append(p[i:])
    # Join with an empty separator of the input's own type so that both
    # bytes and str pickles are supported.
    return p[:0].join(chunks)
def optimize(obj, d=dumps, p=protocol, s=set, q=deque, g=genops):
    '''
    Pickle *obj* and optimize the result by removing unused PUT opcodes.

    Raymond Hettinger Python cookbook recipe # 545418

    The keyword parameters only pre-bind the callables used here: *d*
    serializes ``obj`` with protocol *p*, *s* and *q* build the working
    set/queue, and *g* iterates over the opcodes of the pickled data.
    '''
    this = d(obj, p)
    # set of args used by a GET opcode
    gets = s()
    gadd = gets.add
    # (arg, startpos, stoppos) for the PUT opcodes
    puts = q()
    pappend = puts.append
    # set to pos if previous opcode was a PUT
    prevpos, prevarg = None, None
    # Use the injected opcode generator rather than the module global so
    # the ``g`` parameter is actually honoured.
    for opcode, arg, pos in g(this):
        if prevpos is not None:
            pappend((prevarg, prevpos, pos))
            prevpos = None
        if 'PUT' in opcode.name:
            prevarg, prevpos = arg, pos
        elif 'GET' in opcode.name:
            gadd(arg)
    # Copy the pickle string except for PUTS without a corresponding GET
    chunks = q()
    cappend = chunks.append
    i = 0
    for arg, start, stop in puts:
        cappend(this[i:stop if (arg in gets) else start])
        i = stop
    cappend(this[i:])
    return b('').join(chunks)
def messages(self):
    """
    Iterate over all available messages.

    Each frame in ``self._buffer`` is a pickled integer (the document
    length) immediately followed by that many bytes of pickled document.
    Consumed frames are dropped from the buffer once iteration finishes;
    an incomplete trailing frame is left in place.

    Yields:
        object: The next decoded message.

    Raises:
        ValueError: if a header does not decode to a non-negative int.
    """
    # NOTE(review): pickle.loads is applied to buffer contents; this is
    # only safe if the buffer is filled from a trusted peer.
    offset = 0
    while self._buffer.find(b'.', offset, offset + self.HEADER_MAX_LEN) >= 0:
        header = self._buffer[offset:offset + self.HEADER_MAX_LEN]
        # The header pickle ends one byte after the position of its last
        # opcode (genops stops at STOP).
        for _op, _arg, last_pos in pickletools.genops(header):
            pass
        header_len = last_pos + 1
        doc_len = pickle.loads(header[:header_len])
        if not isinstance(doc_len, int) or doc_len < 0:
            raise ValueError('Document length must be a positive integer')
        doc_start = offset + header_len
        doc_end = doc_start + doc_len
        if doc_end > len(self._buffer):
            # Document not fully received yet.
            break
        yield pickle.loads(self._buffer[doc_start:doc_end])
        offset = doc_end
    self._buffer = self._buffer[offset:]
def import_name(obj):
    """Return the full import name for a Python object

    Note that this does not always exist (e.g. for dynamically
    generated functions).  This function does it's best, using Pickle
    for the heavy lifting.  For example:

    >>> import_name(import_name)
    'rss2email.util import_name'

    Note the space between the module (``rss2email.util``) and the
    function within the module (``import_name``).

    Some objects can't be pickled:

    >>> import_name(lambda x: 'muahaha')
    Traceback (most recent call last):
      ...
    _pickle.PicklingError: Can't pickle <class 'function'>: attribute lookup builtins.function failed

    Some objects don't have a global scope:

    >>> import_name('abc')
    Traceback (most recent call last):
      ...
    ValueError: abc
    """
    # Protocol 3 is pinned on purpose: protocol 4 and later (the default
    # since Python 3.8) encode globals with STACK_GLOBAL, whose argument
    # is not inline, so scanning for GLOBAL would never match.
    pickled = _pickle.dumps(obj, protocol=3)
    for opcode, arg, _pos in _pickletools.genops(pickled):
        if opcode.name == 'GLOBAL':
            return arg
    raise ValueError(obj)
def check_opcodes(pickle_data):
    """
    Blacklist pickle opcodes that would make this too easy.

    Raises ``ValueError`` as soon as a STRING, INST or REDUCE opcode is
    found in *pickle_data*; returns ``None`` otherwise.
    """
    # Compare by opcode *name*: ``opcode.code`` is a str while the
    # ``pickle.STRING``/``INST``/``REDUCE`` constants are bytes on
    # Python 3, so comparing those never matches there.
    banned_opcodes = frozenset(('STRING', 'INST', 'REDUCE'))
    for opcode, _arg, _pos in pickletools.genops(pickle_data):
        if opcode.name in banned_opcodes:
            raise ValueError('Opcode not allowed: "%s".' % opcode.name)
def loads(self, string):
    """Unpickle *string* after rejecting blacklisted opcodes.

    Raises ``pickle.UnpicklingError`` if any opcode name appears in
    ``self.OPCODE_BLACKLIST``. Otherwise the data is loaded with an
    unpickler whose ``find_global`` hook is set to ``self.find_class``.
    """
    if any(entry[0].name in self.OPCODE_BLACKLIST
           for entry in pickletools.genops(string)):
        raise pickle.UnpicklingError('Potentially unsafe pickle')

    stream = StringIO(string)
    unpickler = pickle.Unpickler(stream)
    unpickler.find_global = self.find_class
    return unpickler.load()
def test_resolve_serializer(self):
    """Ensure function resolve_serializer works correctly"""
    default = resolve_serializer(None)
    self.assertIsNotNone(default)
    self.assertEqual(default, DefaultSerializer)

    # Round trip with the default serializer; its output must start
    # with an opcode whose argument is the highest pickle protocol.
    payload = {'test': 'data'}
    blob = default.dumps(payload)
    self.assertEqual(default.loads(blob), payload)
    first_op = next(pickletools.genops(blob))
    self.assertEqual(first_op[1], pickle.HIGHEST_PROTOCOL)

    # A module exposing dumps/loads (json) is accepted as-is.
    json_serializer = resolve_serializer(json)
    self.assertIsNotNone(json_serializer)
    for method_name in ('dumps', 'loads'):
        self.assertTrue(hasattr(json_serializer, method_name))

    # Objects without the serializer interface raise NotImplementedError.
    with self.assertRaises(NotImplementedError):
        resolve_serializer(object)

    # Any other unsuitable object raises some exception.
    with self.assertRaises(Exception):
        resolve_serializer(queue.Queue())

    # A dotted path string is resolved to a serializer.
    from_path = resolve_serializer('tests.fixtures.Serializer')
    self.assertIsNotNone(from_path)
def extract_codes_from_model(model_file, dest):
    """Extract the code bundled in *model_file* into directory *dest*.

    After the script and importer/source sections are read from the file,
    the next two pickle opcodes are inspected: the first is skipped (the
    protocol marker) and the second must be a GLOBAL (code ``"c"``) naming
    the entrance module and class. The bundled modules are written via
    ``extract_codes`` and a ``parser_.py`` launcher is generated.
    """
    with open(model_file, "rb") as model_fh:
        read_script(model_fh)
        importer_namespace, envir_info = read_importer_and_source(model_fh)
        # Skip the pickle protocol opcode.
        next(pickletools.genops(model_fh))
        # The following opcode names the entrance class and module.
        class_op, class_op_args, _ = next(pickletools.genops(model_fh))

    assert class_op.code == "c", "Invalid model file"
    entrance_module, entrance_class_name = class_op_args.split(" ")

    # Write the bundled code.
    extract_codes(envir_info["modules"], dest, importer_namespace)

    # Write the main script.
    launcher_path = os.path.join(dest, "parser_.py")
    with open(launcher_path, "w") as launcher:
        launcher.write("from {} import {}\n".format(entrance_module, entrance_class_name))
        launcher.write("{}.main()\n".format(entrance_class_name))
def extract_pickled_filename(pickle_path):
    """Return the longest ``/data...`` string stored in a pickle file.

    The file at *pickle_path* is scanned opcode by opcode (nothing is
    unpickled); every string argument starting with ``'/data'`` is a
    candidate and the longest one is returned.

    Raises ``ValueError`` if the pickle contains no such string.
    """
    # Pickles are binary data: text mode breaks opcode parsing on Python 3.
    with open(pickle_path, 'rb') as f:
        paths = [
            arg for _op, arg, _pos in pickletools.genops(f)
            if isinstance(arg, str) and arg.startswith('/data')
        ]
    return max(paths, key=len)
def save_pickle(self, package: str, resource: str, obj: Any, dependencies: bool = True):
    """Pickle ``obj`` into the archive as ``package``/``resource``.

    Works like :func:`torch.save`, except that the data is written into
    the archive instead of a stand-alone file. Pickle stores objects, not
    code; when ``dependencies`` is true the pickled stream is scanned for
    the modules it references and those modules are required as well.

    For an object whose ``type(obj).__name__`` is ``my_module.MyObject``,
    ``my_module.MyObject`` must resolve to the object's class according to
    the ``importer`` order. Objects that were previously packaged need the
    importer's ``import_module`` method in the ``importer`` list.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    filename = self._filename(package, resource)

    # Serialize `obj` into an in-memory buffer.
    buffer = io.BytesIO()
    pickler = create_pickler(buffer, self.importer)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = buffer.getvalue()

    node_name = f"<{package}.{resource}>"
    self.dependency_graph.add_node(
        node_name,
        action=_ModuleProviderAction.INTERN,
        provided=True,
        is_pickle=True,
    )

    if dependencies:
        # Modules referenced by GLOBAL opcodes, in first-seen order.
        required_modules = []
        for opcode, arg, _pos in pickletools.genops(data_value):
            if opcode.name != "GLOBAL":
                continue
            assert isinstance(arg, str)
            module, field = arg.split(" ")
            if module in required_modules:
                continue
            required_modules.append(module)

        if self.verbose:
            dep_string = "".join(f" {dep}\n" for dep in required_modules)
            print(f"{resource} depends on:\n{dep_string}\n")

        for module_name in required_modules:
            self.dependency_graph.add_edge(node_name, module_name)
            self.require_module_if_not_provided(module_name)

    self._write(filename, data_value)
def protocol_version(file_object):
    """
    Return the highest protocol number required by any opcode in a pickle.

    :param file_object: pickled data accepted by ``pickletools.genops``
    :return: the maximum ``opcode.proto`` seen (starting value is -1)
    """
    return max(
        (opcode.proto for opcode, _arg, _pos in pickletools.genops(file_object)),
        default=-1,
    )
def is_pickle(data):
    """Heuristically tell whether *data* parses as a pickle.

    At most the first 102 opcodes are decoded; any exception raised
    while decoding means "not a pickle".
    """
    try:
        decoded = 0
        for _ in pickletools.genops(data):
            decoded += 1
            if decoded > 101:
                # Enough evidence; do not walk huge pickles to the end.
                break
    except Exception:
        return False
    return True
def _import_lack_modules(c_bin):
    """Register a stand-in for top-level modules a pickle needs but that are not loaded.

    Every GLOBAL opcode in *c_bin* is inspected. Dotted module names are
    skipped; a top-level module name that is absent from ``sys.modules``
    is mapped to ``Contract`` there.
    """
    for opcode, arg, _pos in pickletools.genops(c_bin):
        if opcode.name != 'GLOBAL':
            continue
        module, name = arg.split(' ')
        logging.debug("_import_lack_modules => {}, {}".format(module, name))
        if '.' in module:
            continue
        if module not in sys.modules:
            # import_module(name=here, package='dummy_module')
            sys.modules[module] = Contract
def DecodeOldReggieInfo(data, validKeys):
    """
    Decode the provided level info data into a dictionary, which will
    have only the keys specified. Raises an exception if the data can't
    be parsed.
    """
    # Rather than unpickling (which would need PyQt4), we interpret just
    # enough of pickle protocol 2 to learn the order in which strings
    # reach the stack. The memo has to be emulated as well, because the
    # default level info memoizes the '-' string instead of encoding it
    # repeatedly. After dropping 'PyQt4.QtCore' and 'QString' (assumed
    # never to be real field values) exactly 12 strings must remain:
    # 6 field names interleaved with their 6 values, which we pair up
    # the same way the SETITEMS instruction would.

    string_ops = ('SHORT_BINSTRING', 'BINSTRING', 'BINUNICODE')
    pushed = []
    remembered = {}
    for op, value, _ in pickletools.genops(data):
        kind = op.name
        if kind in string_ops:
            pushed.append(value)
        elif kind == 'GLOBAL':
            # In practice this pushes sip._unpickle_type, which is then
            # BINGET'd over and over. Track it with a placeholder so the
            # memo does not hand back some unrelated string instead.
            pushed.append(None)
        elif kind == 'BINPUT':
            if pushed:
                remembered[value] = pushed[-1]
        elif kind == 'BINGET':
            if value in remembered:
                pushed.append(remembered[value])

    # Drop the uninteresting entries and verify the count
    ignored = {'PyQt4.QtCore', 'QString', None}
    strings = [item for item in pushed if item not in ignored]
    if len(strings) != 12:
        raise ValueError('Wrong number of strings in level metadata (%d)' % len(strings))

    # Pair up alternating entries: [a, b, c, d, e, f] -> {a: b, c: d, e: f}
    levelinfo = dict(zip(strings[0::2], strings[1::2]))

    # Double-check that the keys are as expected, and return
    if set(levelinfo) != validKeys:
        raise ValueError('Wrong keys in level metadata: ' + str(set(levelinfo)))
    return levelinfo
def pickle_filtered(value):
    """
    Unpickle a string, but reject certain inputs.

    Specifically, only built-in class objects and functions are allowed:
    every GLOBAL/INST opcode must reference the ``__builtin__`` module,
    otherwise ``ValueError`` is raised before anything is unpickled.

    Returns the unpickled object.
    """
    # NOTE(review): this is a filter in front of pickle.loads on caller
    # supplied data, not a sandbox. Built-ins alone are enough to do harm,
    # and protocol-4 STACK_GLOBAL opcodes (which carry no inline argument)
    # are not inspected here.
    for opcode, arg, _ in pickletools.genops(value):
        # Match on the opcode name: ``opcode.code`` is a str, whereas
        # ``pickle.GLOBAL``/``pickle.INST`` are bytes on Python 3 and
        # would never compare equal to it.
        if opcode.name not in ('GLOBAL', 'INST'):
            continue
        # ``arg`` is "<module> <name>"; require the module to be exactly
        # ``__builtin__`` rather than merely start with that text.
        if not arg.startswith('__builtin__ '):
            raise ValueError('Only built-in functions/classes allowed.')
    return pickle.loads(value)
def save_pickle(self, package: str, resource: str, obj: Any, dependencies: bool = True):
    """Pickle ``obj`` into the archive as ``package``/``resource``.

    Works like :func:`torch.save`, except that the data is written into
    the archive instead of a stand-alone file. Pickle stores objects, not
    code; when `dependencies` is true the pickled stream is scanned for
    the modules it references and those modules are required as well.

    For an object whose `type(obj).__name__` is `my_module.MyObject`,
    `my_module.MyObject` must resolve to the object's class according to
    the `importer` order. Objects that were previously packaged need the
    importer's `import_module` method in the `importer` list.

    Args:
        package (str): The name of module package this resource should go in (e.g. "my_package.my_subpackage")
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If True, we scan the source for dependencies (see :ref:`Dependencies`).
    """
    filename = self._filename(package, resource)

    # Serialize `obj` into an in-memory buffer.
    buffer = io.BytesIO()
    pickler = self._create_pickler(buffer)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = buffer.getvalue()

    if dependencies:
        # Modules referenced by GLOBAL opcodes, in first-seen order.
        required_modules = []
        for opcode, arg, _pos in pickletools.genops(data_value):
            if opcode.name != 'GLOBAL':
                continue
            assert isinstance(arg, str)
            module, field = arg.split(' ')
            if module in required_modules:
                continue
            required_modules.append(module)

        resource_name = package + '.' + resource
        for dep in required_modules:
            self.debug_deps.append((resource_name, dep))

        if self.verbose:
            dep_string = ''.join(f' {dep}\n' for dep in required_modules)
            print(f"{resource} depends on:\n{dep_string}\n")

        for module_name in required_modules:
            self.require_module_if_not_provided(module_name)

    self._write(filename, data_value)
def optimize_puts(p):
    """
    Optimize a pickle by assigning the low 256 BINPUT's to the most used gets.

    Memo ids are renumbered so that the most frequently fetched entries
    get the smallest ids (and therefore the 2-byte BINPUT/BINGET forms).
    PUT opcodes whose id is never fetched are dropped. Works on ``bytes``
    (Python 3) as well as ``str`` (Python 2) pickles and returns the same
    type as *p*.

    Should only be used for pickle protocol 1 - 3; a MEMOIZE opcode
    (protocol 4+) raises ``Exception``.
    """
    counter = {}     # memo id -> number of GETs reading it
    process = []     # (end position, (opcode name, memo id, start position))
    prevnode = None  # pending GET/PUT waiting for its end position
    for opcode, arg, pos in pickletools.genops(p):
        if prevnode is not None:
            # The current opcode starts where the pending one ends.
            process.append((pos, prevnode))
            prevnode = None
        if "GET" in opcode.name:
            counter[arg] = counter.get(arg, 0) + 1
            prevnode = opcode.name, arg, pos
        elif "PUT" in opcode.name:
            prevnode = opcode.name, arg, pos
        elif "MEMOIZE" in opcode.name:
            raise Exception(
                "Memoize opcode detected, pickle version not supported")

    # Most used ids first, so they receive the smallest new ids.
    replmap = dict((key, i) for i, key in enumerate(
        sorted(counter.keys(), key=lambda x: -counter[x])))

    rv = []
    i = 0
    for newpos, (name, arg, pos) in process:
        rv.append(p[i:pos])
        i = newpos
        newarg = replmap.get(arg)
        if newarg is None:
            # PUT of an id nobody reads: omit it instead of failing.
            continue
        short = newarg < 256
        if "GET" in name:
            op = pickle.BINGET if short else pickle.LONG_BINGET
        else:
            op = pickle.BINPUT if short else pickle.LONG_BINPUT
        # struct.pack yields the native byte-string type on both Python 2
        # and 3 (chr() would produce text on Python 3).
        rv.append(op + pack("<B" if short else "<I", newarg))
    rv.append(p[i:])
    # Empty separator of the input's own type.
    return p[:0].join(rv)
def _create_model(self, model_path):
    """Load the pickled network(s) stored at *model_path*.

    An interactive TensorFlow session is created first. The leading
    PROTO opcode is consumed from the file to report the pickle protocol
    version, then the remainder is unpickled. A tuple result is unpacked
    into ``self._G``, ``self._D`` and ``self._Gs``; anything else is
    stored as ``self._Gs``.
    """
    tf.InteractiveSession()
    with open(model_path, 'rb') as model_file:
        print(model_file)
        first_op, proto_ver, _ = next(pickletools.genops(model_file))
        assert first_op.name == 'PROTO'
        print("Pickled with version", proto_ver)
        loaded_networks = pickle.load(model_file)

    if isinstance(loaded_networks, tuple):
        # Original model saves a tuple (Generator, Discriminator, weighted avg Generator)
        self._G, self._D, self._Gs = loaded_networks
    else:
        self._Gs = loaded_networks
def optimize_puts(p):
    """
    Optimize a pickle by assigning the low 256 BINPUT's to the most used gets.

    Memo ids are renumbered so that the most frequently fetched entries
    get the smallest ids (and therefore the 2-byte BINPUT/BINGET forms).
    PUT opcodes whose id is never fetched are dropped. Works on ``bytes``
    (Python 3) as well as ``str`` (Python 2) pickles and returns the same
    type as *p*.

    Should only be used for pickle protocol 1 - 3; a MEMOIZE opcode
    (protocol 4+) raises ``Exception``.
    """
    counter = {}     # memo id -> number of GETs reading it
    process = []     # (end position, (opcode name, memo id, start position))
    prevnode = None  # pending GET/PUT waiting for its end position
    for opcode, arg, pos in pickletools.genops(p):
        if prevnode is not None:
            # The current opcode starts where the pending one ends.
            process.append((pos, prevnode))
            prevnode = None
        if "GET" in opcode.name:
            counter[arg] = counter.get(arg, 0) + 1
            prevnode = opcode.name, arg, pos
        elif "PUT" in opcode.name:
            prevnode = opcode.name, arg, pos
        elif "MEMOIZE" in opcode.name:
            raise Exception("Memoize opcode detected, pickle version not supported")

    # Most used ids first, so they receive the smallest new ids.
    replmap = dict((key, i) for i, key in enumerate(sorted(counter.keys(), key=lambda x: -counter[x])))

    rv = []
    i = 0
    for newpos, (name, arg, pos) in process:
        rv.append(p[i:pos])
        i = newpos
        newarg = replmap.get(arg)
        if newarg is None:
            # PUT of an id nobody reads: omit it instead of failing.
            continue
        short = newarg < 256
        if "GET" in name:
            op = pickle.BINGET if short else pickle.LONG_BINGET
        else:
            op = pickle.BINPUT if short else pickle.LONG_BINPUT
        # struct.pack yields the native byte-string type on both Python 2
        # and 3 (chr() would produce text on Python 3).
        rv.append(op + pack("<B" if short else "<I", newarg))
    rv.append(p[i:])
    # Empty separator of the input's own type.
    return p[:0].join(rv)
def _just_the_instructions(pickle):
    """
    Get the instruction stream of a pickle.

    Behaves like genops, but a pickle that simply ends before its STOP
    opcode terminates the stream quietly instead of raising, so that as
    much as possible can be learned from a structurally broken pickle.
    Any other ``ValueError`` from genops is propagated.
    """
    truncated = ("pickle exhausted before seeing STOP",)
    stream = pt.genops(pickle)
    while True:
        try:
            yield next(stream)
        except StopIteration:
            return
        except ValueError as err:
            if err.args != truncated:
                raise
            return
def test_pickle_protocol(self):
    # Pickled files should use protocol 3 (a compromise between
    # efficiency and wide applicability).
    model = Model(count=37)
    filename = os.path.join(self.tmpdir, "nonexistent.pkl")
    self.assertFalse(os.path.exists(filename))

    view_patch = mock.patch.object(self.toolkit, "view_application")
    with view_patch, self.assertWarns(DeprecationWarning):
        model.configure_traits(filename=filename)
    self.assertTrue(os.path.exists(filename))

    with open(filename, "rb") as pickled_object_file:
        pickled_object = pickled_object_file.read()

    # The very first opcode must announce protocol 3.
    first_opcode, first_arg, _ = next(pickletools.genops(pickled_object))
    self.assertEqual(first_opcode.name, "PROTO")
    self.assertEqual(first_arg, 3)
def load(pickled: Union[ByteString, BinaryIO]) -> "Pickled":
    """Parse pickled data into a ``Pickled`` sequence of ``Opcode`` objects.

    *pickled* may be a bytes-like object or a readable binary stream (a
    stream is read completely first). Each opcode gets the raw bytes that
    encode it; when the byte length is not known up front, it is filled in
    from the start position of the following opcode (or the end of the
    data for the last one).
    """
    is_raw = isinstance(pickled, (bytes, bytearray))
    if not is_raw and hasattr(pickled, "read"):
        pickled = pickled.read()

    opcodes: List[Opcode] = []
    for info, arg, pos in genops(pickled):
        arg_spec = info.arg
        if arg_spec is None or arg_spec.n == 0:
            # No argument bytes: the opcode is its single code byte.
            data = info.code if pos is None else pickled[pos:pos + 1]
        elif arg_spec.n > 0 and pos is not None:
            # Fixed-size argument: code byte plus n argument bytes.
            data = pickled[pos:pos + 1 + arg_spec.n]
        else:
            # Size not known here; resolved once the next opcode is seen.
            data = None

        if pos is not None and opcodes:
            previous = opcodes[-1]
            if previous.pos is not None and not previous.has_data():
                previous.data = pickled[previous.pos:pos]

        opcodes.append(
            Opcode(info=info, argument=arg, data=data, position=pos))

    if opcodes:
        last = opcodes[-1]
        if not last.has_data() and last.pos is not None:
            last.data = pickled[last.pos:]

    return Pickled(opcodes)
def protocol_version(file):
    """Return the highest protocol number used by any opcode in the pickle file at path *file*."""
    with open(file, 'rb') as file_object:
        protos = [opcode.proto for opcode, _arg, _pos in pickletools.genops(file_object)]
    return max(protos, default=-1)
def protocol_version(file_object):
    """Return the highest protocol number used by any opcode in the pickle (starting value -1)."""
    highest = -1
    for entry in pickletools.genops(file_object):
        proto = entry[0].proto
        if proto > highest:
            highest = proto
    return highest
def set_pickle(self, pickle):
    """Store the raw *pickle* bytes and a lazy opcode iterator over them.

    ``self.ops`` is a generator from ``pickletools.genops``; it can be
    consumed only once.
    """
    self.pickle = pickle
    stream = BytesIO(pickle)
    self.ops = pickletools.genops(stream)
def update_event(self, inp=-1):
    """Set output 0 to the ``pickletools.genops`` iterator over the value of input 0."""
    source = self.input(0)
    opcode_iter = pickletools.genops(source)
    self.set_output_val(0, opcode_iter)
def opcode_in_pickle(code, pickle):
    """Return True if the opcode whose byte code is *code* (bytes) occurs in *pickle*."""
    return any(
        op.code == code.decode("latin-1")
        for op, _arg, _pos in pickletools.genops(pickle)
    )
def opcode_in_pickle(code, pickle):
    """Return True if an opcode whose ``code`` equals *code* occurs in *pickle*."""
    return any(
        entry[0].code == code for entry in pickletools.genops(pickle)
    )
import pickle, pickletools import numpy as np import tensorflow as tf from timeit import default_timer as timer import PIL.Image # This finds the tfutils.py as a file # Initialize TensorFlow session. tf.InteractiveSession() path = "aerials512vectors1024px_snapshot-010200.pkl" with open(path, 'rb') as file: print(file) proto_op = next(pickletools.genops(file)) assert proto_op[0].name == 'PROTO' proto_ver = proto_op[1] print("Pickled with version", proto_ver) G, D, Gs = pickle.load(file) """ def save_imageONE_from_latentONE(latents, name="img", idx_over=0): # Generate dummy labels (not used by the official networks). labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:]) # Run the generator to produce a set of images. images = Gs.run(latents, labels) # Convert images to PIL-compatible format. images = np.clip(np.rint((images + 1.0) / 2.0 * 255.0), 0.0, 255.0).astype(np.uint8) # [-1,1] => [0,255] images = images.transpose(0, 2, 3, 1) # NCHW => NHWC
9) with open(path.join(pack_folder, "un.rpyc"), "wb") as f: f.write(unrpyc) if args.debug: print("File length = {0}".format(len(unrpyc))) import pickletools data = zlib.decompress(unrpyc) with open(path.join(pack_folder, "un.dis"), "wb" if p.PY2 else "w") as f: pickletools.dis(data, f) for com, arg, _ in pickletools.genops(data): if arg and (isinstance(arg, str) or p.PY3 and isinstance(arg, bytes)) and len(arg) > 1000: if p.PY3 and isinstance(arg, str): arg = arg.encode("latin1") data = zlib.decompress(arg) break else: raise Exception("didn't find the gzipped blob inside") with open(path.join(pack_folder, "un.dis2"), "wb" if p.PY2 else "w") as f: pickletools.dis(data, f) with open(path.join(pack_folder, "un.dis3"), "wb" if p.PY2 else "w") as f:
unrpyc = zlib.compress(p.optimize(p.dumps(decompiler, protocol), protocol), 9) with open(path.join(pack_folder, "un.rpyc"), "wb") as f: f.write(unrpyc) if args.debug: print("File length = {0}".format(len(unrpyc))) import pickletools data = zlib.decompress(unrpyc) with open(path.join(pack_folder, "un.dis"), "wb" if p.PY2 else "w") as f: pickletools.dis(data, f) for com, arg, _ in pickletools.genops(data): if arg and (isinstance(arg, str) or p.PY3 and isinstance(arg, bytes)) and len(arg) > 1000: if p.PY3 and isinstance(arg, str): arg = arg.encode("latin1") data = zlib.decompress(arg) break else: raise Exception("didn't find the gzipped blob inside") with open(path.join(pack_folder, "un.dis2"), "wb" if p.PY2 else "w") as f: pickletools.dis(data, f) with open(path.join(pack_folder, "un.dis3"), "wb" if p.PY2 else "w") as f:
def count_opcode(code, pickle):
    """Return how many opcodes in *pickle* have a ``code`` equal to *code*."""
    return sum(
        1 for entry in pickletools.genops(pickle) if entry[0].code == code
    )
def save_pickle(
    self,
    package: str,
    resource: str,
    obj: Any,
    dependencies: bool = True,
    pickle_protocol: int = 3,
):
    """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
    the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
    If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
    to reconstruct them and save the relevant code.

    To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
    ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
    have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
    for this to work.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        pickle_protocol (int, optional): Pickle protocol to use; only 3 and 4 are accepted.
    """

    assert (pickle_protocol == 4) or (
        pickle_protocol == 3), "torch.package only supports pickle protocols 3 and 4"

    filename = self._filename(package, resource)
    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    mocked_modules = defaultdict(list)
    name_in_dependency_graph = f"<{package}.{resource}>"
    self.dependency_graph.add_node(
        name_in_dependency_graph,
        action=_ModuleProviderAction.INTERN,
        provided=True,
        is_pickle=True,
    )

    def _check_mocked_error(module: Optional[str], field: Optional[str]):
        """
        checks if an object (field) comes from a mocked module and then adds
        the pair to mocked_modules which contains mocked modules paired with their
        list of mocked objects present in the pickle.

        We also hold the invariant that the first user defined rule that applies
        to the module is the one we use.
        """
        assert isinstance(module, str)
        assert isinstance(field, str)
        if self._can_implicitly_extern(module):
            return
        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module):
                if pattern_info.action == _ModuleProviderAction.MOCK:
                    mocked_modules[module].append(field)
                # First matching rule wins, whatever its action is.
                return

    if dependencies:
        all_dependencies = []
        module = None
        field = None
        memo: DefaultDict[int, str] = defaultdict(None)
        memo_count = 0
        # pickletools.dis(data_value)
        for opcode, arg, pos in pickletools.genops(data_value):
            if pickle_protocol == 4:
                # Protocol 4 has no inline "module name" argument; the two
                # most recently pushed strings are tracked instead, so that
                # at STACK_GLOBAL `module` and `field` hold its operands.
                if (opcode.name == "SHORT_BINUNICODE"
                        or opcode.name == "BINUNICODE8"):
                    assert isinstance(arg, str)
                    module = field
                    field = arg
                    memo[memo_count] = arg
                elif (opcode.name == "LONG_BINGET"  # pickletools' name for the 4-byte GET
                        or opcode.name == "BINGET"
                        or opcode.name == "GET"):
                    assert isinstance(arg, int)
                    module = field
                    field = memo.get(arg, None)
                elif opcode.name == "MEMOIZE":
                    memo_count += 1
                elif opcode.name == "STACK_GLOBAL":
                    assert isinstance(module, str)
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            elif (pickle_protocol == 3 and opcode.name == "GLOBAL"):  # a global reference
                assert isinstance(arg, str)
                module, field = arg.split(" ")
                if module not in all_dependencies:
                    all_dependencies.append(module)
                _check_mocked_error(module, field)

        for module_name in all_dependencies:
            self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

            # If an object happens to come from a mocked module, then we
            # collect these errors and spit them out with the other errors
            # found by package exporter. The check must use the module of
            # this iteration, not the `module` left over from the scan above.
            if module_name in mocked_modules:
                assert isinstance(module_name, str)
                fields = mocked_modules[module_name]
                self.dependency_graph.add_node(
                    module_name,
                    action=_ModuleProviderAction.MOCK,
                    error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                    error_context=
                    f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                    f"but is being used in resource - `{resource}` in package `{package}`. ",
                    provided=True,
                )
            else:
                self.add_dependency(module_name)

    self._write(filename, data_value)
def _pickle_dump_fix(self, obj, memo_start_idx=0):
    '''
    Pickle an object (protocol 4) and rewrite the opcode stream:
    MEMOIZE becomes an explicit PUT, PUT/MEMOIZE entries that are never
    read are removed, GET opcodes are renumbered to match, and the PROTO,
    FRAME and STOP opcodes are stripped.

    Memo ids in the output are allocated consecutively starting at
    *memo_start_idx*.

    Returns the pickle bytes, and the end memo index.
    '''
    raw = pickle.dumps(obj, 4)
    parsed = list(pickletools.genops(raw))
    # Every opcode ends where the next one starts; the last ends at EOF.
    ends = [following[2] for following in parsed[1:]]
    ends.append(len(raw))

    KIND_PUT, KIND_GET, KIND_RAW = 0, 1, 2
    put_ids = set()
    remap = {}   # old memo id -> new memo id (None until its PUT is emitted)
    plan = []    # (kind, a, b): memo id in `a`, or a raw slice raw[a:b]
    for (opcode, arg, start), end in zip(parsed, ends):
        name = opcode.name
        if name in ('FRAME', 'STOP'):
            # Dropped from the output.
            continue
        if name == 'PROTO':
            # Dropped too, but the version is verified.
            assert arg == 4, "Pickle version should be 4"
        elif 'PUT' in name:
            put_ids.add(arg)
            plan.append((KIND_PUT, arg, None))
        elif name == 'MEMOIZE':
            # MEMOIZE implicitly stores under the next free memo id.
            implicit_id = len(put_ids)
            put_ids.add(implicit_id)
            plan.append((KIND_PUT, implicit_id, None))
        elif 'GET' in name:
            remap[arg] = None
            plan.append((KIND_GET, arg, None))
        else:
            plan.append((KIND_RAW, start, end))
    del put_ids

    pieces = []
    next_idx = memo_start_idx
    for kind, first, second in plan:
        if kind == KIND_PUT:
            if first not in remap:
                # Nobody reads this memo entry.
                continue
            remap[first] = next_idx
            if next_idx < 256:
                piece = pickle.BINPUT + struct.pack('<B', next_idx)
            else:
                piece = pickle.LONG_BINPUT + struct.pack('<I', next_idx)
            next_idx += 1
        elif kind == KIND_GET:
            target = remap[first]
            if target < 256:
                piece = pickle.BINGET + struct.pack('<B', target)
            else:
                piece = pickle.LONG_BINGET + struct.pack('<I', target)
        else:
            piece = raw[first:second]
        pieces.append(piece)

    return b''.join(pieces), next_idx