def pickle_load(cls, filepath, disable_signals=False):
    """
    Read a flow from its pickle database and perform the initial setup.

    Args:
        filepath: Path of the pickle file or of a directory. If filepath is
            a directory, the tree rooted there is walked with ``os.walk``
            and the first file named ``cls.PICKLE_FNAME`` is loaded.
        disable_signals: If True, ``connect_signals`` is not called on the
            loaded flow. Usually used to read a flow in read-only mode
            while avoiding possible side effects.

    Returns:
        The object deserialized from the pickle file.

    Raises:
        ValueError: if filepath is a directory and no file named
            ``cls.PICKLE_FNAME`` is found below it.
        RuntimeError: if a single directory listing contains more than one
            entry named ``cls.PICKLE_FNAME``.
    """
    if os.path.isdir(filepath):
        # Scan the tree top-down and stop at the first directory that
        # holds a pickle database.
        for root, _subdirs, entries in os.walk(filepath):
            matches = [name for name in entries if name == cls.PICKLE_FNAME]
            if not matches:
                continue

            if len(matches) != 1:
                raise RuntimeError("Found multiple databases:\n %s" % str(matches))

            filepath = os.path.join(root, matches[0])
            break  # Exit os.walk
        else:
            # The walk finished without hitting `break`: nothing was found.
            not_found = "Cannot find %s inside directory %s" % (cls.PICKLE_FNAME, filepath)
            raise ValueError(not_found)

    # Hold the file lock while the database is being read.
    with FileLock(filepath):
        with open(filepath, "rb") as stream:
            flow = pmg_pickle_load(stream)

    # A version mismatch only produces a warning; the flow is still returned.
    if flow.VERSION != cls.VERSION:
        template = ("File flow version %s != latest version %s\n."
                    "Regenerate the flow to solve the problem ")
        warnings.warn(template % (flow.VERSION, cls.VERSION))

    if not disable_signals:
        flow.connect_signals()

    # Tasks submitted in a previous session might have completed in the
    # meantime, so recompute the status of each task before returning.
    flow.check_status()

    return flow
def serialize_with_pickle(self, objects, protocols=None, test_eq=True):
    """
    Test whether the object(s) can be serialized and deserialized with pickle.

    The objects are dumped to a temporary file with each of the requested
    protocols and read back. If test_eq is True, every deserialized object
    is compared with the original one via ``self.assertEqual``.

    Args:
        objects: Object or list/tuple of objects.
        protocols: List of pickle protocols to test. If protocols is None,
            only ``pickle.HIGHEST_PROTOCOL`` is tested.
        test_eq: If True, compare original and deserialized objects.

    Returns:
        One entry per protocol that was serialized and deserialized
        successfully. If a list/tuple was given, each entry is the
        deserialized sequence (nested list); if a single object was given,
        each entry is the deserialized object itself.

    Raises:
        ValueError: if dumping or loading failed for at least one protocol.
            The message collects the errors of all failing protocols.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import pickle
    import tempfile

    from pymatgen.serializers.pickle_coders import pmg_pickle_load, pmg_pickle_dump

    # Build a list even when we receive a single object.
    got_single_object = not isinstance(objects, (list, tuple))
    if got_single_object:
        objects = [objects]

    if protocols is None:
        protocols = [pickle.HIGHEST_PROTOCOL]

    # This list will contain the objects deserialized with the different protocols.
    objects_by_protocol, errors = [], []

    for protocol in protocols:
        # mkstemp returns an already-open OS-level descriptor and leaves
        # the deletion of the file to the caller: we must close the former
        # and remove the latter, otherwise both leak at every iteration.
        fd, tmpfile = tempfile.mkstemp()
        try:
            try:
                # Wrapping fd makes the `with` block close the descriptor.
                with os.fdopen(fd, "wb") as fh:
                    pmg_pickle_dump(objects, fh, protocol=protocol)
            except Exception as exc:
                errors.append("pickle.dump with protocol %s raised:\n%s" % (protocol, str(exc)))
                continue

            try:
                with open(tmpfile, "rb") as fh:
                    new_objects = pmg_pickle_load(fh)
            except Exception as exc:
                errors.append("pickle.load with protocol %s raised:\n%s" % (protocol, str(exc)))
                continue
        finally:
            # Best-effort cleanup: a failure to delete the temporary file
            # must not hide the outcome of the serialization test.
            try:
                os.remove(tmpfile)
            except OSError:
                pass

        # Test for equality.
        if test_eq:
            for old_obj, new_obj in zip(objects, new_objects):
                self.assertEqual(old_obj, new_obj)

        # Save the deserialized objects.
        objects_by_protocol.append(new_objects)

    if errors:
        raise ValueError("\n".join(errors))

    # Return nested list so that client code can perform additional tests.
    if got_single_object:
        return [o[0] for o in objects_by_protocol]
    return objects_by_protocol