def testLoadDefaultSettings(self):
    """ Check that a missing config file falls back to the default settings. """
    missing_fname = "definitely_doesnt_exist.yml"
    loaded = load_custom_settings(config_fname=missing_fname, no_quickstart=True)
    self.assertEqual(loaded.settings, DEFAULT_SETTINGS)
    loaded.reset()
def get_element_colours():
    """ Read element colours from VESTA file. The colours file can be
    specified in the matadorrc. If unspecified, the default
    ../config/vesta_elements.ini will be used.

    Returns:
        dict: element symbol mapped to a list of three RGB floats.

    """
    import os
    from matador.config import load_custom_settings, SETTINGS

    # check if element_colours has been given as an absolute path,
    # preferring the already-loaded global settings over re-reading config
    if SETTINGS:
        colours_fname = SETTINGS.get('plotting', {}).get('element_colours')
    else:
        colours_fname = load_custom_settings().get('plotting', {}).get('element_colours')

    # default colours file shipped alongside the package; computed once
    # (the original duplicated this path construction in two branches)
    default_fname = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), '..', 'config', 'vesta_elements.ini')

    if colours_fname is None:
        colours_fname = default_fname
    elif not os.path.isfile(colours_fname):
        print(
            'Could not find {}, please specify an absolute path. Falling back to default...'
            .format(colours_fname))
        # otherwise fallback to ../config/vesta_elements.ini
        colours_fname = default_fname

    element_colours = dict()
    with open(colours_fname, 'r') as f:
        for line in f:
            fields = line.split()
            # expect at least <index> <symbol> ... <r> <g> <b>; skip blank
            # or malformed lines instead of raising IndexError
            if len(fields) < 5:
                continue
            elem = fields[1]
            colour = list(map(float, fields[-3:]))
            element_colours[elem] = colour

    return element_colours
def testLoadUserDefaultSettings(self):
    """ Test that a config placed at ~/.matadorrc is picked up by default.

    Backs up any real user config first and always restores it, even when
    the assertion or the loader raises. The original version duplicated the
    restore logic in an except-branch and, when no backup existed, left the
    temporary ~/.matadorrc behind on failure.
    """
    rc_path = os.path.expanduser("~/.matadorrc")
    bak_path = os.path.expanduser("~/.matadorrc_bak")
    exists = os.path.isfile(rc_path)
    try:
        if exists:
            # preserve the user's real config before clobbering it
            shutil.copy(rc_path, bak_path)
        shutil.copy(REAL_PATH + "data/custom_config.yml", rc_path)
        settings = load_custom_settings(no_quickstart=True)
        self.assertEqual(settings.settings, DUMMY_SETTINGS)
    finally:
        # single cleanup path: remove the temporary rc and restore any backup
        if os.path.isfile(rc_path):
            os.remove(rc_path)
        if exists:
            shutil.copy(bak_path, rc_path)
            os.remove(bak_path)
def testLoadNamedCustomSettings(self):
    """ Load a config by explicit path and check the module-level SETTINGS
    object was populated with the same values.
    """
    custom_fname = REAL_PATH + "data/custom_config.yml"
    loaded = load_custom_settings(config_fname=custom_fname, no_quickstart=True)
    self.assertEqual(loaded.settings, DUMMY_SETTINGS)

    from matador.config import SETTINGS
    self.assertEqual(SETTINGS.settings, DUMMY_SETTINGS)
    SETTINGS.reset()
def setUp(self):
    """ Load the test config, point the mongo settings at the test database,
    and move into a fresh output directory.
    """
    from matador.config import load_custom_settings, SETTINGS
    SETTINGS.reset()
    load_custom_settings(config_fname=CONFIG_FNAME)

    mongo_conf = SETTINGS["mongo"]
    mongo_conf["default_collection"] = DB_NAME
    mongo_conf["default_collection_file_path"] = "/data/"
    mongo_conf["host"] = "mongo_test.com"
    mongo_conf["port"] = 27017
    self.settings = SETTINGS

    # fail loudly if a previous run left the directory behind
    os.makedirs(OUTPUT_DIR, exist_ok=False)
    os.chdir(OUTPUT_DIR)
def wrapped_plot_function(*args, **kwargs):
    """ Wrap and return the plotting function, switching to the Agg
    backend first whenever the call is going to save a figure rather
    than display it, to avoid X-forwarding errors on headless machines.

    Returns:
        The wrapped function's result, or None if a Tcl/display error
        was caught and the plot skipped.

    """
    saving = False
    result = None
    # if we're going to be saving a figure, switch to Agg to avoid X-forwarding.
    # Check each argument independently: the original try/except AttributeError
    # around the whole loop aborted the scan at the first argument that had
    # no `savefig` attribute, silently skipping any later args.
    for arg in args:
        if getattr(arg, 'savefig', False):
            import matplotlib
            # don't warn as backend might have been set externally by e.g. Jupyter
            matplotlib.use('Agg', force=False)
            saving = True
            break

    if not saving:
        if any(kwargs.get(ext) for ext in SAVE_EXTS):
            import matplotlib
            matplotlib.use('Agg', force=False)
            saving = True

    settings = load_custom_settings(kwargs.get('config_fname'), quiet=True, no_quickstart=True)
    try:
        # resolve the style: explicit kwarg > config default > 'matador'
        style = settings.get('plotting', {}).get('default_style')
        if kwargs.get('style'):
            style = kwargs['style']
        if style is not None and not isinstance(style, list):
            style = [style]
        if style is None:
            style = ['matador']
        # substitute the packaged stylesheet path for the 'matador' alias
        if 'matador' in style:
            for ind, styles in enumerate(style):
                if styles == 'matador':
                    style[ind] = MATADOR_STYLE

        # now actually call the function
        set_style(style)
        result = function(*args, **kwargs)

    except Exception as exc:
        # TclError indicates a missing/unusable display; anything else re-raises
        if 'TclError' not in type(exc).__name__:
            raise exc

        print_failure('Caught exception: {}'.format(type(exc).__name__))
        print_warning('Error message was: {}'.format(exc))
        print_warning('This is probably an X-forwarding error')
        print_failure('Skipping plot...')

    return result
def __init__(self, *args, **kwargs):
    """ Initialise the query with command line arguments and return results. """
    # read args
    self.kwargs = kwargs
    self.args = vars(args[0])
    self.args['no_quickstart'] = self.kwargs.get('no_quickstart')
    self.argstr = kwargs.get('argstr')

    # any requested output extension triggers an export after the query
    file_exts = ['cell', 'res', 'pdb', 'markdown', 'latex', 'param', 'xsf']
    self.export = any([self.args.get(ext) for ext in file_exts])

    self.subcommand = self.args.pop("subcmd")

    # every subcommand except 'import' connects to the database here;
    # 'import' makes its own connection inside Spatula
    if self.subcommand != 'import':
        self.settings = load_custom_settings(
            config_fname=self.args.get('config'),
            debug=self.args.get('debug'),
            no_quickstart=self.args.get('no_quickstart'))
        result = make_connection_to_collection(
            self.args.get('db'),
            check_collection=(self.subcommand != "stats"),
            mongo_settings=self.settings)
        self.client, self.db, self.collections = result

    if self.subcommand == 'stats':
        self.stats()

    try:
        # dispatch on subcommand; each branch sets self.cursor where relevant
        if self.subcommand == 'import':
            from matador.db import Spatula
            self.importer = Spatula(self.args)

        if self.subcommand == 'query':
            self.query = DBQuery(self.client, self.collections, **self.args)
            self.cursor = self.query.cursor

        if self.subcommand == 'swaps':
            from matador.swaps import AtomicSwapper
            self.query = DBQuery(self.client, self.collections, **self.args)
            if self.args.get('hull_cutoff') is not None:
                # restrict swaps to structures on/near the hull
                self.hull = QueryConvexHull(query=self.query, **self.args)
                self.swapper = AtomicSwapper(self.hull.hull_cursor, **self.args)
            else:
                self.swapper = AtomicSwapper(self.query.cursor, **self.args)
            self.cursor = self.swapper.cursor

        if self.subcommand == 'refine':
            from matador.db import Refiner
            self.query = DBQuery(self.client, self.collections, **self.args)
            if self.args.get('hull_cutoff') is not None:
                self.hull = QueryConvexHull(self.query, **self.args)
                self.refiner = Refiner(self.hull.cursor, self.query.repo, **self.args)
            else:
                self.refiner = Refiner(self.query.cursor, self.query.repo, **self.args)
            self.cursor = self.refiner.cursor

        if self.subcommand == 'hull' or self.subcommand == 'voltage':
            self.hull = QueryConvexHull(
                **self.args,
                voltage=self.subcommand == 'voltage',
                client=self.client,
                collections=self.collections)
            self.cursor = self.hull.hull_cursor

        if self.subcommand == 'changes':
            from matador.db import DatabaseChanges
            if len(self.collections) != 1:
                raise SystemExit(
                    'Cannot view changes of more than one collection at once.'
                )
            if self.args.get('undo'):
                action = 'undo'
            else:
                action = 'view'
            changeset = self.args.get('changeset')
            if changeset is None:
                changeset = 0
            DatabaseChanges([key for key in self.collections][0],
                            changeset_ind=changeset,
                            action=action,
                            mongo_settings=self.settings,
                            override=kwargs.get('no_quickstart'))

        if self.subcommand == 'hulldiff':
            from matador.hull.hull_diff import diff_hulls
            if self.args.get('compare') is None:
                raise SystemExit(
                    'Please specify which hulls to query with --compare.')
            diff_hulls(self.client, self.collections, **self.args)

        if self.export and self.cursor:
            from matador.export import query2files
            if self.args.get('write_n') is not None:
                # only export structures with the requested number of species
                self.cursor = [
                    doc for doc in self.cursor
                    if len(doc['stoichiometry']) == self.args.get('write_n')
                ]
            if not self.cursor:
                print_failure('No structures left to export.')
            query2files(self.cursor,
                        **self.args,
                        argstr=self.argstr,
                        subcmd=self.subcommand,
                        hash_dupe=True)

        if self.args.get('view'):
            from matador.utils.viz_utils import viz
            if self.args.get('top') is None:
                self.top = len(self.cursor)
            else:
                self.top = self.args.get('top')
            # warn (and give a chance to abort) before opening many GUI windows
            if len(self.cursor[:self.top]) > 10:
                from time import sleep
                print_warning(
                    'WARNING: opening {} files with ase-gui...'.format(
                        len(self.cursor)))
                print_warning(
                    'Please kill script within 3 seconds if undesired...')
                sleep(3)
                if len(self.cursor[:self.top]) > 20:
                    print_failure(
                        'You will literally be opening that many windows, ' +
                        'I\'ll give you another 5 seconds to reconsider...')
                    sleep(5)
                    print_notify('It\'s your funeral...')
                    sleep(1)
            for doc in self.cursor[:self.top]:
                viz(doc)

        if self.subcommand != 'import':
            self.client.close()

    except (RuntimeError, SystemExit, KeyboardInterrupt) as oops:
        if isinstance(oops, RuntimeError):
            print_failure(oops)
        elif isinstance(oops, SystemExit):
            print_warning(oops)
        # best-effort close: the client may never have been created
        # (e.g. failure during the 'import' branch)
        try:
            self.client.close()
        except AttributeError:
            pass
        raise oops
def __init__(self, *args, settings=None):
    """ Set up arguments and initialise DB client.

    Notes:
        Several arguments can be passed to this class from the command-line,
        and here are interpreted through *args:

    Parameters:
        db (str): the name of the collection to import to.
        scan (bool): whether or not to just scan the directory, rather
            than importing (automatically sets dryrun to true).
        dryrun (bool): perform whole process, up to actually importing to
            the database.
        tags (str): apply this tag to each structure added to database.
        force (bool): override rules about which folders can be imported
            into main database.
        recent_only (bool): if true, sort file lists by modification date
            and stop scanning when a file that already exists in database
            is found.

    """
    # unpack the CLI argument dict and set bookkeeping defaults
    self.args = args[0]
    self.dryrun = self.args.get('dryrun')
    self.scan = self.args.get('scan')
    self.recent_only = self.args.get('recent_only')
    if self.scan:
        # scanning implies no writes at all
        self.dryrun = True
    self.debug = self.args.get('debug')
    self.verbosity = self.args.get('verbosity') or 0
    self.config_fname = self.args.get('config')
    self.tags = self.args.get('tags')
    self.prototype = self.args.get('prototype')
    self.tag_dict = dict()
    self.tag_dict['tags'] = self.tags
    self.import_count = 0
    self.skipped = 0
    self.exclude_patterns = ['bad_castep', 'input']
    self.errors = 0
    self.struct_list = []
    self.path_list = []

    # configure a stdout logger; verbosity 0-3 maps to ERROR..DEBUG
    self.log = logging.getLogger('spatula')
    loglevel = {3: "DEBUG", 2: "INFO", 1: "WARN", 0: "ERROR"}
    self.log.setLevel(loglevel.get(self.verbosity, "WARN"))
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    self.log.addHandler(handler)

    # I/O files
    if not self.dryrun:
        self.num_words = len(WORDS)
        self.num_nouns = len(NOUNS)
    if not self.scan:
        # error log and manifest live in temp files until the run completes
        self.logfile = tempfile.NamedTemporaryFile(mode='w+t', delete=False)
        self.manifest = tempfile.NamedTemporaryFile(mode='w+t', delete=False)

    if settings is None:
        self.settings = load_custom_settings(
            config_fname=self.config_fname,
            debug=self.debug,
            no_quickstart=self.args.get('no_quickstart'))
    else:
        self.settings = settings

    result = make_connection_to_collection(
        self.args.get('db'),
        check_collection=False,
        import_mode=True,
        override=self.args.get('no_quickstart'),
        mongo_settings=self.settings)
    self.client, self.db, self.collections = result
    # perform some relevant collection-dependent checks
    assert len(self.collections) == 1, 'Can only import to one collection.'
    self.repo = list(self.collections.values())[0]

    # if trying to import to the default repo, without doing a dryrun or forcing it, then
    # check if we're in the protected directory, i.e. the only one that is allowed to import
    # to the default collection
    default_collection = recursive_get(self.settings, ['mongo', 'default_collection'])
    try:
        default_collection_file_path = recursive_get(
            self.settings, ['mongo', 'default_collection_file_path'])
    except KeyError:
        default_collection_file_path = None

    if self.args.get('db') is None or self.args.get(
            'db') == default_collection:
        if not self.dryrun and not self.args.get('force'):
            # if using default collection, check we are in the correct path
            if default_collection_file_path is not None:
                if not os.getcwd().startswith(
                        os.path.expanduser(default_collection_file_path)):
                    print(80 * '!')
                    print(
                        'You shouldn\'t be importing to the default database from this folder! '
                        'Please use --db <YourDBName> to create a new collection, '
                        'or copy these files to the correct place!')
                    print(80 * '!')
                    raise RuntimeError(
                        'Failed to import to default collection from '
                        'current directory, import must be called from {}'.
                        format(default_collection_file_path))
    else:
        if self.args.get('db') is not None:
            # protect special collections and forbid multi-collection import
            if any(['oqmd' in db for db in self.args.get('db')]):
                exit('Cannot import directly to oqmd repo')
            elif len(self.args.get('db')) > 1:
                exit('Can only import to one collection.')

    # prototype and DFT data must never be mixed in one collection
    num_prototypes_in_db = self.repo.count_documents({'prototype': True})
    num_objects_in_db = self.repo.count_documents({})
    self.log.info(
        "Found {} existing objects in database.".format(num_objects_in_db))
    if self.args.get('prototype'):
        if num_prototypes_in_db != num_objects_in_db:
            raise SystemExit(
                'I will not import prototypes to a non-prototype database!'
            )
    else:
        if num_prototypes_in_db != 0:
            raise SystemExit(
                'I will not import DFT calculations into a prototype database!'
            )

    if not self.dryrun:
        # either drop and recreate or create spatula report collection
        self.db.spatula.drop()
        self.report = self.db.spatula

    # scan directory on init
    self.file_lists = self._scan_dir()
    # if import, as opposed to rebuild, scan for duplicates and remove from list
    if not self.args.get('subcmd') == 'rebuild':
        self.file_lists, skipped = self._scan_dupes(self.file_lists)
        self.skipped += skipped

    # print number of files found
    self._display_import()

    # only create dicts if not just scanning
    if not self.scan:
        # convert to dict and db if required
        self._files2db(self.file_lists)
    if self.import_count == 0:
        print('No new structures imported!')
    if not self.dryrun and self.import_count > 0:
        print('Successfully imported', self.import_count, 'structures!')
        # index by enthalpy for faster/larger queries
        count = 0
        for _ in self.repo.list_indexes():
            count += 1
        # ignore default id index
        if count > 1:
            self.log.info('Index found, rebuilding...')
            self.repo.reindex()
        else:
            self.log.info('Building index...')
            self.repo.create_index([('enthalpy_per_atom', pm.ASCENDING)])
            self.repo.create_index([('stoichiometry', pm.ASCENDING)])
            self.repo.create_index([('cut_off_energy', pm.ASCENDING)])
            self.repo.create_index([('species_pot', pm.ASCENDING)])
            self.repo.create_index([('kpoints_mp_spacing', pm.ASCENDING)])
            self.repo.create_index([('xc_functional', pm.ASCENDING)])
            self.repo.create_index([('elems', pm.ASCENDING)])
            # index by source for rebuilds
            self.repo.create_index([('source', pm.ASCENDING)])
            self.log.info('Done!')
    elif self.dryrun:
        self.log.info('Dryrun complete!')

    # summarise errors collected in the temp logfile
    if not self.scan:
        self.logfile.seek(0)
        errors = sum(1 for line in self.logfile)
        self.errors += errors
        if errors == 1:
            self.log.warning('There is 1 error to view in {}'.format(
                self.logfile.name))
        elif errors == 0:
            self.log.warning('There are no errors to view in {}'.format(
                self.logfile.name))
        elif errors > 1:
            self.log.warning('There are {} errors to view in {}'.format(
                errors, self.logfile.name))

    # best-effort close of the temp files (may not exist in scan mode)
    try:
        self.logfile.close()
    except Exception:
        pass
    try:
        self.manifest.close()
    except Exception:
        pass

    if not self.dryrun:
        # construct dictionary in spatula_report collection to hold info
        report_dict = dict()
        report_dict['last_modified'] = datetime.datetime.utcnow().replace(
            microsecond=0)
        report_dict['num_success'] = self.import_count
        report_dict['num_errors'] = errors
        report_dict['version'] = __version__
        self.report.insert_one(report_dict)
def make_connection_to_collection(coll_names, check_collection=False, allow_changelog=False,
                                  mongo_settings=None, override=False, import_mode=False,
                                  quiet=True, debug=False):
    """ Connect to database of choice.

    Parameters:
        coll_names (str): name of collection.

    Keyword Arguments:
        check_collection (bool): check whether collections exist (forces connection)
        allow_changelog (bool): allow queries to collections with names prefixed by __
        mongo_settings (dict): dict containing mongo and related config
        override (bool): don't ask for user input from stdin and assume all is well
        import_mode (bool): allow interactive creation of new collections
        quiet (bool): don't print very much.
        debug (bool): currently unused.

    Returns:
        client (MongoClient): the connection to the database
        db (Database): the database to query
        collections (dict): Collection objects indexed by name

    """
    if mongo_settings is None:
        settings = load_custom_settings(no_quickstart=override)
    else:
        settings = mongo_settings

    if not quiet:
        print('Trying to connect to {host}:{port}/{db}'.format(**settings['mongo']))

    client = pm.MongoClient(
        host=settings['mongo']['host'],
        port=settings['mongo']['port'],
        connect=False,
        maxIdleTimeMS=600000,            # disconnect after 10 minutes idle
        socketTimeoutMS=3600000,         # give up on database after 1 hr without results
        serverSelectionTimeoutMS=10000,  # give up on server after 10 seconds without results
        connectTimeoutMS=10000)          # give up trying to connect to new database after 10 seconds

    try:
        database_names = client.list_database_names()
        if not quiet:
            print('Success!')
    except pm.errors.ServerSelectionTimeoutError as exc:
        print('{}: {}'.format(type(exc).__name__, exc))
        raise SystemExit('Unable to connect to {host}:{port}/{db}, exiting...'.format(**settings['mongo']))

    if settings['mongo']['db'] not in database_names:
        if override:
            response = 'y'
        else:
            response = input('Database {db} does not exist at {host}:{port}/{db}, '
                             'would you like to create it? (y/n) '
                             .format(**settings['mongo']))

        if response.lower() != 'y':
            raise SystemExit('Exiting...')
        else:
            # mongo creates the database lazily on first write
            print('Creating database {}'.format(settings['mongo']['db']))

    db = client[settings['mongo']['db']]
    possible_collections = [name for name in db.list_collection_names() if not name.startswith('__')]

    collections = dict()
    # allow lists of collections for backwards-compat, though normally
    # we only want to connect to one at a time
    if coll_names is not None:
        if not isinstance(coll_names, list):
            coll_names = [coll_names]
        if len(coll_names) > 1:
            raise NotImplementedError("Querying multiple collections is no longer supported.")
        for collection in coll_names:
            if not allow_changelog:
                if collection.startswith('__'):
                    raise SystemExit('Queries to collections prefixed with __ are VERBOTEN!')
                if collection not in possible_collections:
                    # suggest close matches before failing or creating
                    options = fuzzy_collname_match(collection, possible_collections)
                    if not options and check_collection:
                        client.close()
                        raise SystemExit('Collection {} not found!'.format(collection))
                    else:
                        print('Collection {} not found, did you mean one of these?'.format(collection))
                        for ind, value in enumerate(options[:10]):
                            print('({}):\t{}'.format(ind, value))
                    if check_collection:
                        try:
                            choice = int(input('Please enter your choice: '))
                            collection = options[choice]
                        except Exception:
                            raise SystemExit('Invalid choice. Exiting...')
                    elif import_mode:
                        if override:
                            choice = 'y'
                        else:
                            choice = input('Are you sure you want to make a new collection called {}? (y/n) '
                                           .format(collection))
                        # BUG FIX: the original wrote `choice.lower != 'yes'`,
                        # comparing the bound method itself to a string (always
                        # True), so answering "yes" never confirmed creation
                        if choice.lower() not in ('y', 'yes'):
                            try:
                                choice = int(input('Then please enter your choice from above: '))
                                collection = options[choice]
                            except Exception:
                                raise SystemExit('Invalid choice. Exiting...')
            collections[collection] = db[collection]
    else:
        # no collection requested: fall back to the configured default
        default_collection = settings['mongo']['default_collection']
        if default_collection not in possible_collections:
            if check_collection:
                client.close()
                raise SystemExit('Default collection {} not found!'.format(default_collection))
            else:
                print('Creating new collection {}...'.format(default_collection))
        collections['repo'] = db[default_collection]

    return client, db, collections
def __init__(
        self, client=False, collections=False, subcmd='query', debug=False,
        quiet=False, mongo_settings=None, **kwargs
):
    """ Parse arguments from matador or API call before calling query.

    Keyword arguments:
        client (pm.MongoClient): the MongoClient to connect to.
        collections (dict of pm.collections.Collection): dictionary of
            pymongo Collections.
        subcmd (str): either 'query' or 'hull', 'voltage', 'hulldiff'.
            These will decide whether calculation accuracies are matched
            in the final results.

    """
    # read args and set housekeeping
    self.args = kwargs
    self.debug = debug
    if self.args.get('subcmd') is None:
        self.args['subcmd'] = subcmd
    if self.args.get('testing') is None:
        self.args['testing'] = False
    if self.args.get('as_crystal') is None:
        self.args['as_crystal'] = False

    # hull-type subcommands are meaningless without a composition
    if subcmd in ['hull', 'hulldiff', 'voltage'] and self.args.get('composition') is None:
        raise RuntimeError('{} requires composition query'.format(subcmd))

    self._create_hull = (self.args.get('subcmd') in ['hull', 'hulldiff', 'voltage']
                         or self.args.get('hull_cutoff') is not None)

    # public attributes
    self.cursor = EmptyCursor()
    self.query_dict = None
    self.calc_dict = None
    self.repo = None

    # private attributes to be set later
    self._empty_query = None
    self._gs_enthalpy = None
    self._non_elemental = None
    self._chempots = None
    self._num_to_display = None

    if debug:
        print(self.args)

    if quiet:
        # silence all prints for the duration of the query; restored at the end
        f = open(devnull, 'w')
        sys.stdout = f

    # if testing keyword is used, all database operations are ignored
    if not self.args.get('testing'):
        # connect to db or use passed client
        if client:
            self._client = client
            self._db = client.crystals
        if collections is not False:
            _collections = collections

        if (not collections or not client):
            # use passed settings or load from config file
            if mongo_settings:
                self.mongo_settings = mongo_settings
            else:
                self.mongo_settings = load_custom_settings(
                    config_fname=self.args.get('config'), debug=self.args.get('debug')
                )

            result = make_connection_to_collection(
                self.args.get('db'),
                mongo_settings=self.mongo_settings
            )
            # ideally this would be rewritten to use a context manager to ensure
            # that connections are _always_ cleaned up
            self._client, self._db, _collections = result

        if len(_collections) > 1:
            raise NotImplementedError("Querying multiple collections is no longer supported.")
        else:
            # grab the single collection from the dict
            for collection in _collections:
                self._collection = _collections[collection]
                break

    # define some periodic table macros
    self._periodic_table = get_periodic_table()

    # set default top value to 10
    if self.args.get('summary') or self.args.get('subcmd') in ['swaps', 'polish']:
        self.top = None
    else:
        self.top = self.args.get('top') if self.args.get('top') is not None else 10

    # create the dictionary to pass to MongoDB
    self._construct_query()

    if not self.args.get('testing'):

        if self.args.get('id') is not None and (self._create_hull or self.args.get('calc_match')):
            # if we've requested and ID and hull/calc_match, do the ID query
            self.perform_id_query()

        self.perform_query()

        if self._create_hull and self.args.get('id') is None:
            # if we're making a normal hull, find the sets of calculations to use
            self.perform_hull_query()

        if not self._create_hull:
            # only filter for uniqueness if not eventually making a hull
            if self.args.get('uniq'):
                from matador.utils.cursor_utils import filter_unique_structures
                print_notify('Filtering for unique structures...')

                if isinstance(self.cursor, pm.cursor.Cursor):
                    raise RuntimeError("Unable to filter pymongo cursor for uniqueness directly.")

                if self.args.get('top') is not None:
                    top = self.args['top']
                else:
                    top = len(self.cursor)

                self.cursor = filter_unique_structures(
                    self.cursor[:top],
                    debug=self.args.get('debug'),
                    sim_tol=self.args.get('uniq'),
                    energy_tol=1e20
                )

        if self.args.get('available_values') is not None:
            print('Querying available values...')
            self._query_available_values(self.args.get('available_values'), self.cursor)

    # if no client was passed, then we need to close the one we made
    if not client and not self.args.get('testing'):
        self._client.close()

    if quiet:
        # restore stdout after the quiet run
        f.close()
        sys.stdout = sys.__stdout__