def test_delete(self):
    """A model removed via ``H5File.delete`` is no longer listed in the file."""
    from pyemma._base.serialization.h5file import H5File

    # store a single model under the name 'a' ...
    container = np_container(None)
    container.save(self.fn, model_name='a')

    # ... and check it vanishes from the listing once deleted.
    with H5File(self.fn) as h5:
        h5.delete('a')
        self.assertNotIn('a', h5.models_descriptive.keys())
def test_rename(self):
    """Renaming a stored model exposes the new name and hides the old one."""
    container = np_container(None)
    container.save(self.fn, model_name='a')

    with H5File(self.fn, mode='a') as h5:
        h5.rename('a', 'b')
        stored_names = h5.models_descriptive.keys()
        self.assertIn('b', stored_names)
        self.assertNotIn('a', stored_names)
def save(self, file_name, model_name='default', overwrite=False, save_streaming_chain=False):
    r""" saves the current state of this object to given file and name.

    Parameters
    ----------
    file_name: str
        path to desired output file
    model_name: str, default='default'
        creates a group named 'model_name' in the given file, which will contain all of the data.
        If the name already exists, and overwrite is False (default) will raise a RuntimeError.
    overwrite: bool, default=False
        Should overwrite existing model names?
    save_streaming_chain : boolean, default=False
        if True, the data_producer(s) of this object will also be saved in the given file.

    Raises
    ------
    NotImplementedError
        if called on Python 2.
    Exception
        any error occurring while writing is logged and then re-raised unchanged.

    Examples
    --------
    >>> import pyemma, numpy as np
    >>> from pyemma.util.contexts import named_temporary_file
    >>> m = pyemma.msm.MSM(P=np.array([[0.1, 0.9], [0.9, 0.1]]))

    >>> with named_temporary_file() as file: # doctest: +SKIP
    ...    m.save(file, 'simple') # doctest: +SKIP
    ...    inst_restored = pyemma.load(file, 'simple') # doctest: +SKIP
    >>> np.testing.assert_equal(m.P, inst_restored.P) # doctest: +SKIP
    """
    import six
    if six.PY2:
        raise NotImplementedError('This feature is only available on Python3. Consider upgrading.')

    from pyemma._base.serialization.h5file import H5File
    try:
        with H5File(file_name=file_name) as f:
            f.add_serializable(model_name, obj=self, overwrite=overwrite,
                               save_streaming_chain=save_streaming_chain)
    except Exception as e:
        # The message previously carried a stray '")' after the object repr.
        msg = ('During saving the object {obj} '
               'the following error occurred: {error}'.format(obj=self, error=e))
        # Prefer the instance logger when the object provides one, otherwise
        # fall back to the module-level logger.
        if isinstance(self, Loggable):
            self.logger.exception(msg)
        else:
            logger.exception(msg)
        # Logging is informational only; the caller still sees the original error.
        raise
def load(cls, file_name, model_name='default'):
    """ Restore a previously saved PyEMMA object from a file.

    Parameters
    ----------
    file_name : str or file like object (has to provide read method).
        The file like object tried to be read for a serialized object.
    model_name: str, default='default'
        name of the model to restore, in case the file holds several of them.
        Use :func:`pyemma.list_models` to get a representation of all stored models.

    Returns
    -------
    obj : the de-serialized object
    """
    from .h5file import H5File

    # The file is opened read-only and closed again as soon as the model
    # attribute has been read.
    with H5File(file_name, model_name=model_name, mode='r') as h5:
        return h5.model
def load(cls, file_name, model_name='default'):
    """ Restore a previously saved object of this class from a file.

    Parameters
    ----------
    file_name : str or file like object (has to provide read method).
        The file like object tried to be read for a serialized object.
    model_name: str, default='default'
        name of the model to restore, in case the file holds several of them.
        Use :func:`pyemma.list_models` to get a representation of all stored models.

    Returns
    -------
    obj : the de-serialized object

    Raises
    ------
    NotImplementedError
        if called on Python 2.
    """
    import six
    if six.PY2:
        raise NotImplementedError('This feature is only available on Python3. Consider upgrading.')

    from .h5file import H5File

    # The file is opened read-only and closed again as soon as the model
    # attribute has been read.
    with H5File(file_name, model_name=model_name, mode='r') as h5:
        return h5.model
def _load_cmp(self, pdb):
    """Store the topology of ``pdb`` in the test file, restore it and compare both.

    Parameters
    ----------
    pdb : str
        path handed to ``mdtraj.load``; only the resulting topology is used.
    """
    from unittest import mock

    top = mdtraj.load(pdb).top
    with H5File(self.f, mode='a') as fh:
        fh.add_object('top', top)
        restored = fh.model

    assert top == restored
    assert tuple(top.atoms) == tuple(restored.atoms)
    assert tuple(top.bonds) == tuple(restored.bonds)

    # mdtraj (1.9.1) does not impl eq for Residue, so a field-wise comparison
    # is patched in for the duration of the residue check.
    def residue_eq(self, other):
        from mdtraj.core.topology import Residue
        if not isinstance(other, Residue):
            return False
        return (self.index == other.index
                and self.resSeq == other.resSeq
                and other.name == self.name
                and tuple(other.atoms) == tuple(self.atoms))

    with mock.patch('mdtraj.core.topology.Residue.__eq__', residue_eq):
        self.assertEqual(tuple(top.residues), tuple(restored.residues))
def main(argv=None):
    """Command line entry point: list the PyEMMA models stored in the given files.

    Parameters
    ----------
    argv : list of str, optional
        command line arguments, handed to ``ArgumentParser.parse_args``.

    Returns
    -------
    int
        0 on success; 1 as soon as one of the files could not be inspected
        (remaining files are not processed and nothing is printed for them).
    """
    import six
    if six.PY2:
        print('This tool is only available for Python3.')
        sys.exit(1)

    import argparse
    from collections import defaultdict

    from pyemma import load
    from pyemma._base.serialization.h5file import H5File

    parser = argparse.ArgumentParser()
    parser.add_argument('--json', action='store_true', default=False)
    parser.add_argument('files', metavar='files', nargs='+', help='files to inspect')
    # NOTE(review): --recursive is parsed but never consulted below; the input
    # chain is gathered whenever a model was saved with its streaming chain.
    parser.add_argument('--recursive', action='store_true', default=False,
                        help='If the pipeline of the stored estimator was stored, '
                             'gather these information as well. This will require to load the model, '
                             'so it could take a while, if the pipeline contains lots of data.')
    parser.add_argument('-v', '--verbose', action='store_true', default=False)
    args = parser.parse_args(argv)

    # store found models by filename
    models = defaultdict(dict)

    for f in args.files:
        try:
            with H5File(f) as fh:
                descriptions = fh.models_descriptive
            # The descriptions are consumed after the file has been closed
            # (the report below does the same), so the file is not held open
            # while load() opens it again.
            for model_name, attrs in descriptions.items():
                models[f][model_name] = attrs
                if attrs['saved_streaming_chain']:
                    # Restore exactly the model being described; without the
                    # name, load() falls back to the model called 'default'.
                    restored = load(f, model_name=model_name)
                    attrs['input_chain'] = [repr(x) for x in restored._data_flow_chain()]
        except Exception as e:
            # Exception (not BaseException): Ctrl-C and sys.exit must not be
            # reported as an invalid model file.
            print('{} did not contain a valid PyEMMA model. Error was {err}. '
                  'If you are sure, that it does, please post an issue on Github'.format(f, err=e))
            if args.verbose:
                import traceback
                traceback.print_exc()
            return 1

    if args.json:
        import json
        json.dump(models, fp=sys.stdout)
        return 0

    from io import StringIO
    buff = StringIO()
    title = 'PyEMMA models'
    buff.write(title + '\n')
    buff.write('=' * len(title))
    buff.write('\n' * 2)
    rule = '-' * 80 + '\n'
    for f, model_file in models.items():
        buff.write('file: {}'.format(f))
        buff.write('\n')
        buff.write(rule)
        for i, (model_name, attrs) in enumerate(model_file.items(), start=1):
            buff.write('{index}. name: {key}\n'
                       'created: {created}\n'
                       '{repr}\n'.format(key=model_name, index=i,
                                         created=attrs['created_readable'],
                                         repr=attrs['class_str']))
            if attrs['saved_streaming_chain']:
                buff.write('\n---------Input chain---------\n')
                for j, x in enumerate(attrs['input_chain'], start=1):
                    buff.write('{index}. {repr}\n'.format(index=j, repr=x))
        buff.write(rule)
    print(buff.getvalue())
    return 0
def test_model_not_existant(self):
    """Opening a file with an unknown model name raises a descriptive ValueError."""
    container = np_container(None)
    container.save(self.fn, 'foo')

    # only 'foo' has been stored, so selecting 'bar' has to fail.
    with self.assertRaises(ValueError) as caught:
        H5File(self.fn, model_name='bar')

    self.assertIn('"bar" not found', caught.exception.args[0])
def test_delete(self):
    """A model removed via ``H5File.delete`` is no longer listed in the file."""
    # store a single model under the name 'a' ...
    container = np_container(None)
    container.save(self.fn, model_name='a')

    # ... and check it vanishes from the listing once deleted.
    with H5File(self.fn, mode='a') as h5:
        h5.delete('a')
        self.assertNotIn('a', h5.models_descriptive.keys())