Esempio n. 1
0
 def symmetry(self, symprec=1e-3):
     """ Compute space group with spglib.

     Every document in ``self.cursor`` has its space group recomputed.
     Documents whose new space group differs from the stored
     ``space_group`` are appended to ``self.diff_cursor``, counted in
     ``self.changed_count`` and updated in place. Any exception raised
     while handling a document increments ``self.failed_count`` and the
     loop carries on with the next document.

     Keyword arguments:
         symprec (float): symmetry tolerance passed to spglib.

     """
     from matador.utils.cell_utils import doc2spg
     import spglib as spg
     print('Refining symmetries...')
     # shared column layout for the header and every table row
     row = "{:^36}{:^16}{:^16}"
     if self.mode == 'display':
         print_warning('{}'.format('At symprec: ' + str(symprec)))
         print_warning(row.format('text_id', 'new sg', 'old sg'))
     for doc in self.cursor:
         try:
             spg_cell = doc2spg(doc)
             # keep only the first whitespace-separated token of spglib's answer
             sg = spg.get_spacegroup(spg_cell, symprec=symprec).split(' ')[0]
             if sg != doc['space_group']:
                 self.changed_count += 1
                 self.diff_cursor.append(doc)
                 if self.mode == 'display':
                     # print before overwriting so the old value is still shown
                     print_notify(row.format(doc['text_id'][0] + ' ' + doc['text_id'][1],
                                             sg, doc['space_group']))
                 doc['space_group'] = sg
             elif self.mode == 'display':
                 print(row.format(doc['text_id'][0] + ' ' + doc['text_id'][1],
                                  sg, doc['space_group']))
         except Exception:
             # best effort: a single bad structure must not stop the refinement
             self.failed_count += 1
             if self.args.get('debug'):
                 print_exc()
                 print_failure('Failed for ' + ' '.join(doc['text_id']))
Esempio n. 2
0
 def plot_hull(self, **kwargs):
     """ Hull plot helper function.

     Dispatches to the ternary plotter when ``self._dimension`` is 3 and
     to the 2D plotter when it is 2. For any other dimension a
     notification is printed and nothing is plotted.

     Keyword arguments:
         **kwargs: forwarded unchanged to the plotting function.

     Returns:
         whatever the chosen plotting function returns, or None if the
         dimension is not supported.

     """
     from matador import plotting
     if self._dimension == 3:
         return plotting.plot_ternary_hull(self, **kwargs)
     if self._dimension == 2:
         return plotting.plot_2d_hull(self, **kwargs)
     print_notify('Unable to plot phase diagram of dimension {}.'.format(self._dimension))
     # explicit None: there is no axis to hand back for unsupported dimensions
     return None
Esempio n. 3
0
    def create_hull(self):
        """ Build the hull, then run the post-processing that the
        initial arguments asked for.

        Steps, in order: optional uniqueness filtering of the cursor,
        phase diagram construction, a summary of the hull, optional
        voltage and volume curves and, unless ``no_plot`` is set, the
        corresponding plots followed by the hull plot itself.

        """
        uniq = self.args.get('uniq')
        if uniq:
            from matador.utils.cursor_utils import filter_unique_structures
            # a literal True selects the tolerance of 0.1, any other truthy
            # value is used as the tolerance itself
            sim_tol = 0.1 if uniq is True else uniq
            print_notify('Filtering for unique structures...')
            self.cursor = filter_unique_structures(
                self.cursor,
                args=self.args,
                quiet=True,
                sim_tol=sim_tol,
                hull=True,
                energy_key=self.energy_key,
            )

        self.construct_phase_diagram()

        if self.hull_cursor:
            summary = '{} structures found within {} eV of the hull, including chemical potentials.'
            print_notify(summary.format(len(self.hull_cursor), self.hull_cutoff))
        else:
            print_warning('No structures on hull with chosen chemical potentials.')

        display_results(self.hull_cursor, hull=True, energy_key=self.energy_key, **self.args)

        if self.compute_voltages:
            active_ion = self.species[0]
            print("Constructing electrode system with active ion: {}".format(active_ion))
            self.voltage_curve()

        if self.compute_volumes:
            self.volume_curve()

        # everything below is plotting only
        if self.args.get('no_plot'):
            return

        if self.compute_voltages and self.voltage_data:
            self.plot_voltage_curve(show=False)
        if self.compute_volumes and self.volume_data:
            self.plot_volume_curve(show=False)

        # the hull plot is drawn last and is the only call made with show=True
        self.plot_hull(**self.args['plot_kwargs'],
                       debug=self.args.get('debug'),
                       show=True)
Esempio n. 4
0
 def spawn(self):
     """ Spawn processes to perform PDF fitting.

     Creates ``self.nprocesses`` processes that each run
     ``self.perform_fits`` and starts them one after another.

     Raises:
         SystemExit: if starting the processes is interrupted by a
             KeyboardInterrupt, SystemExit or RuntimeError; processes
             that are already running are terminated first.

     """
     print_notify('Performing ' + str(self.nprocesses) +
                  ' concurrent fits.')
     procs = [mp.Process(target=self.perform_fits)
              for _ in range(self.nprocesses)]
     try:
         for proc in procs:
             proc.start()
     except (KeyboardInterrupt, SystemExit, RuntimeError):
         for proc in procs:
             # terminate() fails on a process that was never started, so
             # only touch the ones that are actually running
             if proc.is_alive():
                 proc.terminate()
         # raise directly rather than relying on the site-provided exit()
         raise SystemExit('Killing running jobs and exiting...')
Esempio n. 5
0
    def update_docs(self):
        """ Updates documents in database with correct priority.

        One ``UpdateOne`` request is built per document in
        ``self.diff_cursor`` and all requests are submitted together with
        ``bulk_write``. In 'set' mode the field is only written to
        documents where it does not exist yet; in 'overwrite' mode any
        existing value is replaced. Any other mode builds no requests.

        """
        if self.mode == 'set':
            # if in "set" mode, do not overwrite, just apply
            requests = [
                pm.UpdateOne({'_id': doc['_id'], self.field: {'$exists': False}},
                             {'$set': {self.field: doc[self.field]}})
                for doc in self.diff_cursor
            ]
        elif self.mode == 'overwrite':
            # else if in overwrite mode, overwrite previous field
            requests = [
                pm.UpdateOne({'_id': doc['_id']}, {'$set': {self.field: doc[self.field]}})
                for doc in self.diff_cursor
            ]
        else:
            requests = []

        if self.args.get('debug'):
            for request in requests:
                print(request)

        # bulk_write refuses an empty list of operations, so report the
        # trivial result instead of letting it raise
        if not requests:
            print_notify('0 docs modified.')
            return

        result = self.collection.bulk_write(requests)
        print_notify(str(result.modified_count) + ' docs modified.')
Esempio n. 6
0
    def stats(self):
        """ Print some useful stats about the database.

        The behaviour is selected by ``self.args``:

        * ``list``: print every collection with its document count,
          largest first (names starting with ``__`` are hidden).
        * ``delete``: after validation and confirmation, drop the single
          collection named in ``args['db']`` and its ``__changelog_``
          companion collection.
        * otherwise: print aggregate size statistics for
          ``self.collections`` and, if ``ascii_graph`` is importable, a
          histogram of the compositions they contain.

        Raises:
            SystemExit: if a deletion request is invalid or is declined.

        """
        if self.args.get('list'):
            collection_names = self.db.list_collection_names()
            print_notify(
                str(len(collection_names)) +
                ' collections found in database:\n')
            collstats_list = []
            for name in collection_names:
                collstats_list.append(self.db.command('collstats', name))
                collstats_list[-1]['name'] = name
            collstats_list = sorted(collstats_list,
                                    key=lambda k: k['count'],
                                    reverse=True)
            print("\t{:^20}\t{:^20}".format('Name', 'Number of structures'))
            for collection in collstats_list:
                if not collection['name'].startswith('__'):
                    print("\t{:<20}\t{:>20d}".format(collection['name'],
                                                     collection['count']))
            print('\n')
        elif self.args.get('delete'):
            target = self.args.get('db')
            # unwrap a one-element list; anything else that is not None is
            # rejected, while None falls through to its own error message
            if isinstance(target, list):
                if len(target) != 1:
                    raise SystemExit(
                        'I will only delete one collection at a time...')
                target = target[0]
            elif target is not None:
                raise SystemExit(
                    'I will only delete one collection at a time...')
            if target is None:
                raise SystemExit('Please specify a collection to delete.')
            if target not in self.db.list_collection_names():
                raise SystemExit(
                    'No collection named {} was found'.format(target))

            from getpass import getuser
            user = getuser()
            # NOTE(review): this is a substring test although the message
            # talks about the name *starting* with the username - confirm
            # which of the two is intended
            if user not in target:
                raise SystemExit(
                    'I cannot delete a collection that\'s name does not start with '
                    'your username, {}'.format(user))
            stats = self.db.command('collstats', target)

            if self.args.get('no_quickstart'):
                # no interactive prompt in this mode
                answer = 'y'
            else:
                answer = input(
                    'Are you sure you want to delete collection {} containing {} '
                    'structures? [y/n]\n'.format(target, stats['count']))
            if answer.lower() == 'y':
                if target == 'repo':
                    raise SystemExit(
                        'I\'m sorry Dave, I\'m afraid I can\'t do that...')
                print('Deleting collection {}...'.format(target))
                self.db[target].drop()
                print('and its changelog...')
                self.db['__changelog_{}'.format(target)].drop()
            else:
                raise SystemExit('Nevermind then!')

        else:
            # NOTE(review): avgObjSize is summed over collections, so with
            # more than one collection the printed "kB each" is a sum of
            # averages - confirm whether that is intended
            stats_dict = {
                'count': 0,
                'avgObjSize': 0,
                'storageSize': 0,
                'totalIndexSize': 0,
            }
            for collection in self.collections:
                db_stats_dict = self.db.command('collstats', collection)
                for key in stats_dict:
                    stats_dict[key] += db_stats_dict[key]
            print((
                "The collection(s) queried in {} contain {} structures at {:.1f} kB each "
                "totalling {:.1f} MB with a further {:.1f} MB of indexes."
            ).format(self.db.name, stats_dict['count'],
                     stats_dict['avgObjSize'] / (1024),
                     stats_dict['storageSize'] / (1024**2),
                     stats_dict['totalIndexSize'] / (1024**2)))

            # count how many structures share each set of species, keyed
            # by the sorted species joined with '+'
            comp_counts = dict()
            for collname in self.collections:
                for doc in self.collections[collname].find():
                    label = '+'.join(
                        str(elem[0]) for elem in sorted(doc['stoichiometry']))
                    comp_counts[label] = comp_counts.get(label, 0) + 1
            comp_list = sorted(comp_counts.items(),
                               key=lambda t: t[1],
                               reverse=True)

            # lump every composition with fewer than `cutoff` structures
            # into one "others" entry; at most the first 1000 are kept
            small_count = 0
            first_ind = 1000
            cutoff = 100
            for ind, comp in enumerate(comp_list):
                if comp[1] < cutoff:
                    first_ind = min(first_ind, ind)
                    small_count += comp[1]
            comp_list = comp_list[:first_ind]
            comp_list.append(['others < ' + str(cutoff), small_count])
            comp_list.sort(key=lambda t: t[1], reverse=True)
            try:
                from ascii_graph import Pyasciigraph
                from ascii_graph.colors import Gre, Blu, Red
                from ascii_graph.colordata import hcolor
            except ImportError:
                print(
                    "ascii_graph dependency not found, not creating histogram."
                )
            else:
                graph = Pyasciigraph(line_length=80, multivalue=False)
                thresholds = {
                    int(stats_dict['count'] / 40): Gre,
                    int(stats_dict['count'] / 10): Blu,
                    int(stats_dict['count'] / 4): Red
                }
                data = hcolor(comp_list, thresholds)
                for line in graph.graph(label=None, data=data):
                    print(line)
                print('\n')

                for comp in comp_list:
                    print(comp)
Esempio n. 7
0
    def __init__(self, *args, **kwargs):
        """ Initialise the query with command line arguments and return
        results.

        The constructor does all of the work: it reads the parsed
        arguments, connects to the database (for every subcommand except
        'import'), runs the requested subcommand and finally performs
        any requested export or visualisation of the resulting cursor.

        Parameters:
            *args: the first positional argument is converted to a dict
                with ``vars()`` and stored as ``self.args``.

        Keyword arguments:
            no_quickstart: copied into ``self.args['no_quickstart']``.
            argstr: stored as ``self.argstr`` and passed on to the export.

        Raises:
            RuntimeError, SystemExit, KeyboardInterrupt: re-raised after
                an attempt has been made to close the database client.

        """
        # read args
        self.kwargs = kwargs
        self.args = vars(args[0])
        self.args['no_quickstart'] = self.kwargs.get('no_quickstart')
        self.argstr = kwargs.get('argstr')

        # an export is requested if any of these file-format flags is set
        file_exts = ['cell', 'res', 'pdb', 'markdown', 'latex', 'param', 'xsf']
        self.export = any([self.args.get(ext) for ext in file_exts])

        # the subcommand is removed from args, which are later forwarded
        # as keyword arguments to the individual workers
        self.subcommand = self.args.pop("subcmd")

        # every subcommand except 'import' gets a database connection here
        if self.subcommand != 'import':
            self.settings = load_custom_settings(
                config_fname=self.args.get('config'),
                debug=self.args.get('debug'),
                no_quickstart=self.args.get('no_quickstart'))
            result = make_connection_to_collection(
                self.args.get('db'),
                check_collection=(self.subcommand != "stats"),
                mongo_settings=self.settings)
            self.client, self.db, self.collections = result

        # note that 'stats' runs outside the try block below, so its
        # errors bypass the client clean-up in the except clause
        if self.subcommand == 'stats':
            self.stats()

        try:
            if self.subcommand == 'import':
                from matador.db import Spatula
                self.importer = Spatula(self.args)

            if self.subcommand == 'query':
                self.query = DBQuery(self.client, self.collections,
                                     **self.args)
                self.cursor = self.query.cursor

            if self.subcommand == 'swaps':
                from matador.swaps import AtomicSwapper
                self.query = DBQuery(self.client, self.collections,
                                     **self.args)
                # with a hull cutoff, swaps start from the hull structures
                # instead of the raw query results
                if self.args.get('hull_cutoff') is not None:
                    self.hull = QueryConvexHull(query=self.query, **self.args)
                    self.swapper = AtomicSwapper(self.hull.hull_cursor,
                                                 **self.args)
                else:
                    self.swapper = AtomicSwapper(self.query.cursor,
                                                 **self.args)
                self.cursor = self.swapper.cursor

            if self.subcommand == 'refine':
                from matador.db import Refiner
                self.query = DBQuery(self.client, self.collections,
                                     **self.args)
                # NOTE(review): this branch passes the query positionally and
                # uses hull.cursor, while 'swaps' uses query= and
                # hull.hull_cursor - confirm the difference is intended
                if self.args.get('hull_cutoff') is not None:
                    self.hull = QueryConvexHull(self.query, **self.args)
                    self.refiner = Refiner(self.hull.cursor, self.query.repo,
                                           **self.args)
                else:
                    self.refiner = Refiner(self.query.cursor, self.query.repo,
                                           **self.args)

                self.cursor = self.refiner.cursor

            if self.subcommand == 'hull' or self.subcommand == 'voltage':
                # the same class serves both; voltage=True only for 'voltage'
                self.hull = QueryConvexHull(**self.args,
                                            voltage=self.subcommand ==
                                            'voltage',
                                            client=self.client,
                                            collections=self.collections)
                self.cursor = self.hull.hull_cursor

            if self.subcommand == 'changes':
                from matador.db import DatabaseChanges
                if len(self.collections) != 1:
                    raise SystemExit(
                        'Cannot view changes of more than one collection at once.'
                    )
                if self.args.get('undo'):
                    action = 'undo'
                else:
                    action = 'view'
                # a missing changeset is passed on as 0
                changeset = self.args.get('changeset')
                if changeset is None:
                    changeset = 0
                DatabaseChanges([key for key in self.collections][0],
                                changeset_ind=changeset,
                                action=action,
                                mongo_settings=self.settings,
                                override=kwargs.get('no_quickstart'))

            if self.subcommand == 'hulldiff':
                from matador.hull.hull_diff import diff_hulls
                if self.args.get('compare') is None:
                    raise SystemExit(
                        'Please specify which hulls to query with --compare.')
                diff_hulls(self.client, self.collections, **self.args)

            # NOTE(review): self.cursor is only assigned in the query, swaps,
            # refine, hull and voltage branches of this method; for the other
            # subcommands reading it here (or in the 'view' block below)
            # would raise AttributeError unless it is defined elsewhere -
            # confirm
            if self.export and self.cursor:
                from matador.export import query2files
                # optionally keep only structures with write_n species
                if self.args.get('write_n') is not None:
                    self.cursor = [
                        doc for doc in self.cursor if len(doc['stoichiometry'])
                        == self.args.get('write_n')
                    ]
                # NOTE(review): the export call below still runs after this
                # failure message, with an empty cursor
                if not self.cursor:
                    print_failure('No structures left to export.')
                query2files(self.cursor,
                            **self.args,
                            argstr=self.argstr,
                            subcmd=self.subcommand,
                            hash_dupe=True)

            if self.args.get('view'):
                from matador.utils.viz_utils import viz
                # without --top, every structure in the cursor is opened
                if self.args.get('top') is None:
                    self.top = len(self.cursor)
                else:
                    self.top = self.args.get('top')
                if len(self.cursor[:self.top]) > 10:
                    from time import sleep
                    # NOTE(review): the message reports len(self.cursor), not
                    # the number of structures left after slicing with top
                    print_warning(
                        'WARNING: opening {} files with ase-gui...'.format(
                            len(self.cursor)))
                    print_warning(
                        'Please kill script within 3 seconds if undesired...')
                    sleep(3)
                # sleep is imported in the block above, which always runs
                # first because more than 20 implies more than 10
                if len(self.cursor[:self.top]) > 20:
                    print_failure(
                        'You will literally be opening that many windows, ' +
                        'I\'ll give you another 5 seconds to reconsider...')

                    sleep(5)
                    print_notify('It\'s your funeral...')
                    sleep(1)
                for doc in self.cursor[:self.top]:
                    viz(doc)

            if self.subcommand != 'import':
                self.client.close()

        except (RuntimeError, SystemExit, KeyboardInterrupt) as oops:
            # report the problem, close the client if there is one, and
            # re-raise so that the caller still sees the original exception
            if isinstance(oops, RuntimeError):
                print_failure(oops)
            elif isinstance(oops, SystemExit):
                print_warning(oops)
            try:
                self.client.close()
            except AttributeError:
                # no client attribute exists for the 'import' subcommand
                pass
            raise oops
Esempio n. 8
0
    def set_chempots(self, energy_key=None):
        """ Search for chemical potentials that match the structures in
        the query cursor and add them to the cursor. Also set the concentration
        of chemical potentials in :attr:`cursor`, if not already set.

        The chemical potentials come from one of three places: they are
        faked when ``args['chempots']`` is given, taken from the existing
        cursor when ``self.from_cursor`` is set, or otherwise searched
        for in the database. The result is stored in
        ``self.chempot_cursor``; ``self.elements`` and
        ``self.num_elements`` are set from the elements found in it.

        Keyword arguments:
            energy_key (str): key used to rank candidate chemical
                potentials; defaults to ``self.energy_key``.

        Raises:
            RuntimeError: if a required chemical potential cannot be found.

        """
        if energy_key is None:
            energy_key = self.energy_key
        query = self._query
        query_dict = dict()
        # one sorted stoichiometry per requested species, in species order
        species_stoich = [sorted(get_stoich_from_formula(spec, sort=False)) for spec in self.species]
        self.chempot_cursor = []

        if self.args.get('chempots') is not None:
            # presumably fills self.chempot_cursor, which is indexed below -
            # TODO confirm against fake_chempots
            self.fake_chempots(custom_elem=self.species)

        elif self.from_cursor:
            # candidates are the cursor entries matching a requested
            # stoichiometry, sorted by ascending energy
            chempot_cursor = sorted([doc for doc in self.cursor if doc['stoichiometry'] in species_stoich],
                                    key=lambda doc: recursive_get(doc, energy_key))

            # take the first (lowest-energy) candidate for each species
            for species in species_stoich:
                for doc in chempot_cursor:
                    if doc['stoichiometry'] == species:
                        self.chempot_cursor.append(doc)
                        break

            if len(self.chempot_cursor) != len(self.species):
                raise RuntimeError('Found {} of {} required chemical potentials'
                                   .format(len(self.chempot_cursor), len(self.species)))

            if self.args.get('debug'):
                print([mu['stoichiometry'] for mu in self.chempot_cursor])

        else:
            print(60 * '─')
            self.chempot_cursor = len(self.species) * [None]
            # scan for suitable chem pots in database
            for ind, elem in enumerate(self.species):

                print('Scanning for suitable', elem, 'chemical potential...')
                # start from a copy of the calculation constraints of the
                # original query so that they are not modified
                query_dict['$and'] = deepcopy(list(query.calc_dict['$and']))

                if not self.args.get('ignore_warnings'):
                    query_dict['$and'].append(query.query_quality())

                # single-entry stoichiometries are matched by composition,
                # anything else by exact stoichiometry
                if len(species_stoich[ind]) == 1:
                    query_dict['$and'].append(query.query_composition(custom_elem=[elem]))
                else:
                    query_dict['$and'].append(query.query_stoichiometry(custom_stoich=[elem]))

                # if oqmd, only query composition, not parameters
                if query.args.get('tags') is not None:
                    query_dict['$and'].append(query.query_tags())

                # NOTE(review): Cursor.count() is not available in recent
                # PyMongo releases - confirm against the version in use
                mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)
                if mu_cursor.count() == 0:
                    print_notify('Failed... searching without spin polarization field...')
                    # drop the 'spin_polarized' key from every sub-query; the
                    # pass over the list restarts until one reaches the end
                    scanned = False
                    while not scanned:
                        for idx, dicts in enumerate(query_dict['$and']):
                            for key in dicts:
                                if key == 'spin_polarized':
                                    del query_dict['$and'][idx][key]
                                    # stop iterating the dict that was just modified
                                    break
                            if idx == len(query_dict['$and']) - 1:
                                scanned = True
                    mu_cursor = query.repo.find(SON(query_dict)).sort(energy_key, pm.ASCENDING)

                if mu_cursor.count() == 0:
                    raise RuntimeError('No chemical potentials found for {}...'.format(elem))

                # results are sorted by ascending energy, so take the first
                self.chempot_cursor[ind] = mu_cursor[0]
                if self.chempot_cursor[ind] is not None:
                    print('Using', ''.join([self.chempot_cursor[ind]['text_id'][0], ' ',
                                            self.chempot_cursor[ind]['text_id'][1]]), 'as chem pot for', elem)
                    print(60 * '─')
                else:
                    raise RuntimeError('No possible chem pots available for {}.'.format(elem))

            # store the energy under the '<extensive key>_per_b' name and
            # initialise num_a to 0 for each chemical potential
            for i, mu in enumerate(self.chempot_cursor):
                self.chempot_cursor[i][self._extensive_energy_key + '_per_b'] = mu[energy_key]
                self.chempot_cursor[i]['num_a'] = 0

            # only the first chemical potential gets an infinite num_a
            self.chempot_cursor[0]['num_a'] = float('inf')

        # don't check for IDs if we're loading from cursor
        if not self.from_cursor:
            # the first chemical potential goes to the front of the cursor,
            # the rest to the back, skipping any already present by _id
            ids = [doc['_id'] for doc in self.cursor]
            if self.chempot_cursor[0]['_id'] is None or self.chempot_cursor[0]['_id'] not in ids:
                self.cursor.insert(0, self.chempot_cursor[0])
            for match in self.chempot_cursor[1:]:
                if match['_id'] is None or match['_id'] not in ids:
                    self.cursor.append(match)

        # add faked chempots to overall cursor
        # (only reached when loading from a cursor, because of the elif)
        elif self.args.get('chempots') is not None:
            self.cursor.insert(0, self.chempot_cursor[0])
            self.cursor.extend(self.chempot_cursor[1:])

        # find all elements present in the chemical potentials
        elements = []
        for mu in self.chempot_cursor:
            for elem, _ in mu['stoichiometry']:
                if elem not in elements:
                    elements.append(elem)
        self.elements = elements
        self.num_elements = len(elements)
Esempio n. 9
0
    def __init__(self,
                 collection_name: str,
                 changeset_ind=0,
                 action='view',
                 override=False,
                 mongo_settings=None):
        """ Parse arguments and run changes interface.

        With ``changeset_ind`` equal to 0 a summary of all changesets is
        printed. Otherwise the chosen changeset is displayed and, if
        ``action`` is 'undo', the structures it lists are deleted from
        the collection after two confirmations.

        Parameters:
            collection_name (str): the base collection name to act upon

        Keyword arguments:
            changeset_ind (int): the number of the changeset to act upon (1 is oldest)
            action (str): either 'view' or 'undo'
            override (bool): override all options to positive answers for testing
            mongo_settings (dict): dictionary of already-sourced mongo settings

        Raises:
            SystemExit: if there are no changesets, the requested
                changeset does not exist, or the deletion is declined.

        """
        self.changelog_name = '__changelog_{}'.format(collection_name)
        _, _, self.collections = make_connection_to_collection(
            self.changelog_name,
            allow_changelog=True,
            override=override,
            mongo_settings=mongo_settings)
        self.repo = [self.collections[key] for key in self.collections][0]
        curs = list(self.repo.find())

        if not curs:
            raise SystemExit('No changesets found for {}'.format(collection_name))

        # if no changeset specified, print summary
        if changeset_ind == 0:
            self.print_change_summary(curs)
            return

        # changesets are numbered from 1; a negative number would otherwise
        # silently index from the end of the list
        if changeset_ind < 0 or changeset_ind > len(curs):
            raise SystemExit('No changeset {} found for collection called "{}".'.format(
                changeset_ind, collection_name))

        # otherwise, try to act on particular changeset
        self.change = curs[changeset_ind - 1]
        self.view_changeset(self.change, changeset_ind - 1)
        if action != 'undo':
            return

        print_warning(
            'An attempt will now be made to remove {} structures from {}.'
            .format(self.change['count'], collection_name))

        # first confirmation: anything but 'y' aborts silently
        print_notify('Are you sure you want to do that? (y/n)')
        response = 'y' if override else input()
        if response.lower() != 'y':
            raise SystemExit

        # second confirmation is phrased so that 'n' (no doubts) proceeds
        print_notify('You don\'t have any doubts at all? (y/n)')
        next_response = 'n' if override else input()
        if next_response.lower() != 'n':
            raise SystemExit('As I thought...')
        print('You\'re the boss, deleting structures now...')

        # proceed with deletion, using the same mongo settings as for the
        # changelog so that both connections target the same database
        _, _, collections = make_connection_to_collection(
            collection_name,
            allow_changelog=False,
            override=override,
            mongo_settings=mongo_settings)
        collection_to_delete_from = [
            collections[key] for key in collections
        ][0]
        result = collection_to_delete_from.delete_many(
            {'_id': {
                '$in': self.change['id_list']
            }})
        print('Deleted {}/{} successfully.'.format(
            result.deleted_count, self.change['count']))
        print('Tidying up changelog database...')
        self.repo.delete_one({'_id': self.change['_id']})
        # once the changelog is empty, drop both it and the collection
        if not self.repo.find_one():
            print('No structures left remaining, deleting database...')
            collection_to_delete_from.drop()
            self.repo.drop()
        print('Success!')
Esempio n. 10
0
    def __init__(self, cursor, required_inds=None, debug=False, **fprint_args):
        """ Compute PDFs over n processes, where n is set by either
        ``$SLURM_NTASKS``, ``$OMP_NUM_THREADS`` or physical core count.

        The fingerprint of every required structure is computed and
        stored in the structure under ``self.default_key``; all other
        structures get ``None`` under that key. Nothing is done if
        ``required_inds`` is an empty list.

        Parameters:
            cursor (list of dict): list of matador structures

        Keyword arguments:
            required_inds (list(int)): indices in cursor for which a
                fingerprint is computed; defaults to every index.
            debug (bool): print the compute time when finished.
            **fprint_args: arguments to pass to the fingerprint
                calculator (a ``lazy`` entry is discarded).

        Raises:
            NotImplementedError: if ``self.fingerprint`` or
                ``self.default_key`` is None.
            RuntimeError: if the pool did not return one result per
                required index.

        """
        if required_inds is None:
            required_inds = list(range(len(cursor)))
        elif len(required_inds) == 0:
            return
        else:
            print(
                "Skipping {} structures out of {} as no comparisons are required"
                .format(len(cursor) - len(required_inds), len(cursor)))

        if self.fingerprint is None or self.default_key is None:
            raise NotImplementedError(
                'Do not create FingerprintFactory directly, '
                'use the appropriate sub-class!')

        # laziness is controlled here, not by the caller
        fprint_args.pop('lazy', None)

        # set for constant-time membership tests in the loops below
        required = set(required_inds)

        # create list of empty (lazy) PDF objects
        for ind, doc in enumerate(cursor):
            if isinstance(doc, Crystal):
                doc._data.pop(self.default_key, None)
            if ind in required:
                doc[self.default_key] = self.fingerprint(doc,
                                                         lazy=True,
                                                         **fprint_args)
            else:
                doc[self.default_key] = None

        compute_list = [
            doc for ind, doc in enumerate(cursor) if ind in required
        ]

        # nothing to compute (e.g. empty cursor): a pool cannot be created
        # with zero processes, so stop here
        if not compute_list:
            return

        # how many processes to use? either SLURM_NTASKS, OMP_NUM_THREADS or total num CPUs
        if os.environ.get('SLURM_NTASKS') is not None:
            self.nprocs = int(os.environ.get('SLURM_NTASKS'))
            env = '$SLURM_NTASKS'
        elif os.environ.get('OMP_NUM_THREADS') is not None:
            self.nprocs = int(os.environ.get('OMP_NUM_THREADS'))
            env = '$OMP_NUM_THREADS'
        else:
            # psutil may return None if the count cannot be determined
            self.nprocs = psutil.cpu_count(logical=False) or 1
            env = 'core count'
        print_notify(
            'Running {} jobs on at most {} processes, set by {}.'.format(
                len(required_inds), self.nprocs, env))
        self.nprocs = min(len(compute_list), self.nprocs)

        start = time.time()
        if self.nprocs == 1:
            # serial path: evaluate the lazy fingerprints in place
            import tqdm
            for _, doc in tqdm.tqdm(enumerate(cursor)):
                if doc[self.default_key] is not None:
                    doc[self.default_key].calculate()
        else:
            pool = mp.Pool(processes=self.nprocs)
            fprint_cursor = []
            # for large cursors, set chunk to at most 16
            # for smaller cursors, tend to use chunksize 1 for improved load balancing
            chunksize = min(
                max(1, int(0.25 * len(compute_list) / self.nprocs)), 16)
            results = pool.map_async(functools.partial(
                _calc_fprint_pool_wrapper, key=self.default_key),
                                     compute_list,
                                     callback=fprint_cursor.extend,
                                     error_callback=print,
                                     chunksize=chunksize)
            # no further tasks will be submitted
            pool.close()
            width = len(str(len(required_inds)))
            total = len(required_inds)
            # poll once a second and print a progress line
            while not results.ready():
                done = total - results._number_left
                sys.stdout.write(
                    '{done:{width}d} / {total:{width}d}  {percentage:3d}%\r'.
                    format(width=width,
                           done=done,
                           total=total,
                           percentage=int(100 * done / total)))
                sys.stdout.flush()
                time.sleep(1)
            # all tasks have finished, so wait for the workers to exit
            pool.join()

            # on a worker error the callback never runs and the list stays empty
            if len(fprint_cursor) != len(required_inds):
                raise RuntimeError(
                    'There was an error calculating the desired Fingerprint')

            # results come back in the order of compute_list, i.e. in
            # increasing cursor index
            fprint_ind = 0
            for ind, doc in enumerate(cursor):
                if ind in required:
                    if isinstance(doc, Crystal):
                        doc._data.pop(self.default_key, None)
                    doc[self.default_key] = fprint_cursor[fprint_ind][
                        self.default_key]
                    fprint_ind += 1

        elapsed = time.time() - start
        if debug:
            # the pool (if any) is already closed above; there is no pool
            # at all on the serial path
            print('Compute time: {:.4f} s'.format(elapsed))
            print('Work complete!')
Esempio n. 11
0
    def __init__(
        self,
        client=False,
        collections=False,
        subcmd='query',
        debug=False,
        quiet=False,
        mongo_settings=None,
        **kwargs
    ):
        """ Parse arguments from matador or API call, then construct
        and (unless the `testing` keyword is set) perform the query.

        Keyword arguments:
            client (pm.MongoClient): the MongoClient to connect to.
            collections (dict of pm.collections.Collection): dictionary of pymongo Collections.
            subcmd (str): either 'query' or 'hull', 'voltage', 'hulldiff'.
                These will decide whether calculation accuracies are matched
                in the final results.
            debug (bool): if True, print the parsed arguments.
            quiet (bool): if True, redirect stdout to the null device for
                the duration of the constructor; the stream that was
                installed on entry is restored on exit, even on error.
            mongo_settings: settings forwarded to `make_connection_to_collection`
                in place of those loaded from the config file.

        Raises:
            RuntimeError: if a hull-type subcmd is requested without a
                `composition`, or if `uniq` filtering is requested while
                the cursor is still a raw pymongo cursor.
            NotImplementedError: if more than one collection is to be queried.

        """
        # read args and set housekeeping
        self.args = kwargs
        self.debug = debug
        for key, default in (('subcmd', subcmd), ('testing', False), ('as_crystal', False)):
            if self.args.get(key) is None:
                self.args[key] = default

        if subcmd in ['hull', 'hulldiff', 'voltage'] and self.args.get('composition') is None:
            raise RuntimeError('{} requires composition query'.format(subcmd))

        self._create_hull = (self.args.get('subcmd') in ['hull', 'hulldiff', 'voltage'] or
                             self.args.get('hull_cutoff') is not None)

        # public attributes
        self.cursor = EmptyCursor()
        self.query_dict = None
        self.calc_dict = None
        self.repo = None

        # private attributes to be set later
        self._empty_query = None
        self._gs_enthalpy = None
        self._non_elemental = None
        self._chempots = None
        self._num_to_display = None

        if debug:
            print(self.args)

        # remember whichever stream is currently installed so that it can be
        # put back afterwards, rather than clobbering a caller's redirection
        previous_stdout = sys.stdout
        sink = None
        if quiet:
            sink = open(devnull, 'w')
            sys.stdout = sink

        # True only when this constructor opened the connection itself and
        # is therefore responsible for closing it again
        owns_client = False

        try:
            # if testing keyword is used, all database operations are ignored
            if not self.args.get('testing'):

                # connect to db or use passed client
                if client:
                    self._client = client
                    self._db = client.crystals
                if collections is not False:
                    _collections = collections

                if not collections or not client:
                    # use passed settings or load from config file
                    if mongo_settings:
                        self.mongo_settings = mongo_settings
                    else:
                        self.mongo_settings = load_custom_settings(
                            config_fname=self.args.get('config'), debug=self.args.get('debug')
                        )

                    result = make_connection_to_collection(
                        self.args.get('db'), mongo_settings=self.mongo_settings
                    )
                    self._client, self._db, _collections = result
                    # NOTE(review): if a client was passed without collections,
                    # the connection made here replaces it and is left open, as
                    # before -- confirm whether it should also be closed.
                    owns_client = not client

                if len(_collections) > 1:
                    raise NotImplementedError("Querying multiple collections is no longer supported.")

                # pick out the single collection (if any) to be queried
                for collection in _collections:
                    self._collection = _collections[collection]
                    break

            # define some periodic table macros
            self._periodic_table = get_periodic_table()

            # set default top value to 10
            if self.args.get('summary') or self.args.get('subcmd') in ['swaps', 'polish']:
                self.top = None
            else:
                self.top = self.args.get('top') if self.args.get('top') is not None else 10

            # create the dictionary to pass to MongoDB
            self._construct_query()

            if not self.args.get('testing'):

                if self.args.get('id') is not None and (self._create_hull or self.args.get('calc_match')):
                    # if we've requested an ID and hull/calc_match, do the ID query
                    self.perform_id_query()

                self.perform_query()

                if self._create_hull and self.args.get('id') is None:
                    # if we're making a normal hull, find the sets of calculations to use
                    self.perform_hull_query()

                # only filter for uniqueness if not eventually making a hull
                if not self._create_hull and self.args.get('uniq'):
                    from matador.utils.cursor_utils import filter_unique_structures
                    print_notify('Filtering for unique structures...')

                    if isinstance(self.cursor, pm.cursor.Cursor):
                        raise RuntimeError("Unable to filter pymongo cursor for uniqueness directly.")

                    if self.args.get('top') is not None:
                        top = self.args['top']
                    else:
                        top = len(self.cursor)

                    self.cursor = filter_unique_structures(
                        self.cursor[:top],
                        debug=self.args.get('debug'),
                        sim_tol=self.args.get('uniq'),
                        energy_tol=1e20
                    )

                if self.args.get('available_values') is not None:
                    print('Querying available values...')
                    self._query_available_values(self.args.get('available_values'), self.cursor)

        finally:
            try:
                # if no client was passed, then we need to close the one we
                # made, whether or not the query succeeded
                if owns_client:
                    self._client.close()
            finally:
                # always un-silence stdout, even if the query raised
                if sink is not None:
                    sys.stdout = previous_stdout
                    sink.close()