コード例 #1
0
 def fundamental_gaps(self, query=None, sort='system'):
     """
     Print the gap info for all the systems in the database.

     For every document returned by ``self.col.find(query)`` the stored
     metadata is printed. When the item id contains ``'mp-'`` the band
     structure is fetched through ``MPRester`` and the labels and fractional
     coordinates of the CBM and VBM k-points are printed too. Finally the gap
     info of the stored results file is printed, provided the data read from
     it is longer than 1000.

     :param query: filter handed to ``self.col.find``; ``None`` means ``{}``
     :param sort: key handed to the cursor's ``sort``; default sort per system
     :return: None, all output goes to stdout
     :raises KeyError: if the ``MP_KEY`` environment variable is not set
     """
     mp_key = os.environ['MP_KEY']
     if query is None:
         query = {}
     for item in self.col.find(query).sort(sort):
         print('')
         print('System    : ', item['system'].split('_')[0])
         print('Ps        : ', item['ps'])
         print('extra     : ', item['extra_vars'])
         print('gwresults : ', item['gw_results'])
         print('item      : ', item['item'])
         # gap stays None unless the remote band structure lookup succeeds
         gap = None
         if 'mp-' in item['item']:
             try:
                 with MPRester(mp_key) as mp_database:
                     bandstructure = mp_database.get_bandstructure_by_material_id(
                         item['item'])
                     # query the band edges once and reuse the result
                     vbm = bandstructure.get_vbm()
                     cbm = bandstructure.get_cbm()
                     vbm_kpoint = bandstructure.kpoints[vbm['kpoint_index'][0]]
                     cbm_kpoint = bandstructure.kpoints[cbm['kpoint_index'][0]]
                     gap = {'vbm_l': vbm_kpoint.label,
                            'cbm_l': cbm_kpoint.label,
                            'vbm_e': vbm['energy'],
                            'cbm_e': cbm['energy'],
                            'cbm': tuple(cbm_kpoint.frac_coords),
                            'vbm': tuple(vbm_kpoint.frac_coords)}
             except (MPRestError, IndexError, KeyError) as err:
                 # exceptions have no ``message`` attribute on Python 3;
                 # printing the exception itself works on both 2 and 3
                 print(err)
                 gap = None
         if gap:
             print(gap['cbm_l'], gap['vbm_l'])
             print(gap['cbm'], gap['vbm'])
         try:
             data = self.gfs.get(item['results_file']).read()
             # very short payloads are skipped
             if len(data) > 1000:
                 srf = MySigResFile(data)
                 srf.print_gap_info()
         # NOTE(review): a missing GridFS file may surface as a gridfs NoFile
         # error rather than IOError - confirm which one self.gfs raises
         except IOError:
             print('No Sigres file in DataBase')
コード例 #2
0
    def sigres_plots(self, query=None, sort='system'):
        """
        Plot the scissor operators for all the systems in the DB.

        For every document returned by ``self.col.find(query)`` whose results
        file data is longer than 1000, a scissor plot is made and the gap info
        is printed. If the document has a falsy ``'tgwgap'`` entry, the
        Kohn-Sham and scissor-corrected band edges are computed from the
        stored KS bands file and saved back into the document.

        :param query: filter handed to ``self.col.find``; ``None`` means ``{}``
        :param sort: key handed to the cursor's ``sort``
        :return: the figure of the last system that was plotted, or None if
                 nothing was plotted
        """
        if query is None:
            query = {}
        # stays None when no system yields a plot
        fig = None
        for item in self.col.find(query).sort(sort):
            print('System    : ', item['system'].split('_')[0])
            print('Ps        : ', item['ps'])
            print('extra     : ', item['extra_vars'])
            print('gwresults : ', item['gw_results'])
            try:
                print('data  : ', item['results_file'])
                data = self.gfs.get(item['results_file']).read()
                if len(data) > 1000:
                    srf = MySigResFile(data)
                    title = "QPState corrections of " + str(item['system']) + "\nusing "\
                            + str(item['ps'].split('/')[-2])
                    if item['extra_vars'] is not None:
                        title += ' and ' + str(
                            self.fix_extra(item['extra_vars']))
                    fig = srf.plot_scissor(title=title)
                    srf.print_gap_info()
                    sc = srf.get_scissor()
                    try:
                        # a missing 'tgwgap' or 'ksbands_file' key raises
                        # KeyError, which skips the update for this document
                        if not item['tgwgap']:
                            bands_data = self.gfs.get(item['ksbands_file'])
                            ks_bands = MyBandsFile(bands_data)
                            ks_homo = ks_bands.homo
                            ks_lumo = ks_bands.lumo
                            item['tkshomo'] = ks_homo
                            item['tkslumo'] = ks_lumo
                            item['tksgap'] = ks_lumo - ks_homo
                            # scissor-corrected edge = KS edge + correction
                            item['tgwhomo'] = ks_homo + sc.apply(ks_homo)
                            item['tgwlumo'] = ks_lumo + sc.apply(ks_lumo)
                            item['tgwgap'] = item['tgwlumo'] - item['tgwhomo']
                            self.col.save(item)
                    except KeyError:
                        pass
            except IOError:
                print('No Sigres file in DataBase')

        return fig
コード例 #3
0
ファイル: test_cycle.py プロジェクト: setten/HTGW
    def test_SiC_conv(self):
        """
        Testing a full convergence calculation cycle on SiC using precomputed data.

        The flows are generated with gwsetup, their directories are replaced by reference results, and the
        processed output is compared against fixed reference numbers. The ABINIT_PS and ABINIT_PS_EXT environment
        variables are overridden for the duration of the test and restored afterwards, also when the test fails.
        """

        # the current version uses reference data from a run using the production version on zenobe
        # once all checks out the run should be done using the input generated using this version to replace the
        # reference

        wdir = tempfile.mkdtemp()
        os.chdir(wdir)

        # remember the pseudopotential settings of the caller; each variable is restored on its own and a variable
        # that was not set before the test is removed again
        saved_env = dict((name, os.environ.get(name, None)) for name in ('ABINIT_PS_EXT', 'ABINIT_PS'))

        def restore_env():
            for name, value in saved_env.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value

        # registered before the variables are changed so the restore runs even if an assertion below fails
        self.addCleanup(restore_env)

        os.environ['ABINIT_PS_EXT'] = '.psp8'
        os.environ['ABINIT_PS'] = wdir

        reference_dir = os.path.join(__reference_dir__, 'SiC_test_case')
        if not os.path.isdir(reference_dir):
            # report the directory that was actually checked
            raise RuntimeError('py.test needs to be started in the HTGW root, '
                               '%s does not exist' % reference_dir)

        # copy input
        print(wdir)
        self.assertTrue(os.path.isdir(reference_dir))
        src_files = os.listdir(reference_dir)
        for file_name in src_files:
            full_file_name = os.path.join(reference_dir, file_name)
            if os.path.isfile(full_file_name):
                shutil.copy(full_file_name, wdir)
        self.assertEqual(len(os.listdir(wdir)), 6)

        print(os.listdir(wdir))
        structure = Structure.from_file('SiC.cif')
        structure.item = 'SiC.cif'

        print(' ==== generate flow ===  ')
        gwsetup(update=False)
        self.assertTrue(os.path.isdir(os.path.join(wdir, 'SiC_SiC.cif')))
        print(os.listdir(os.path.join(wdir)))
        print(os.listdir(os.path.join(wdir, 'SiC_SiC.cif')))
        self.assertTrue(os.path.isfile(os.path.join(wdir, 'SiC_SiC.cif', '__AbinitFlow__.pickle')))
        self.assertEqual(len(os.listdir(os.path.join(wdir, 'SiC_SiC.cif', 'w0'))), 48)

        print(' ==== copy reference results from first calculation ===  ')
        shutil.rmtree(os.path.join(wdir, 'SiC_SiC.cif'))
        shutil.copytree(os.path.join(reference_dir, 'ref_res', 'SiC_SiC.cif'), os.path.join(wdir, 'SiC_SiC.cif'))
        self.assertTrue(os.path.isdir(os.path.join(wdir, 'SiC_SiC.cif')))
        self.assertEqual(len(os.listdir(os.path.join(wdir, 'SiC_SiC.cif', 'w0'))), 68)

        print(' ==== process output ===  ')
        gwoutput()
        print(os.listdir('.'))
        self.assertTrue(os.path.isfile('plot-fits'))
        self.assertTrue(os.path.isfile('plots'))
        self.assertEqual(is_converged(hartree_parameters=True, structure=structure, return_values=True),
                         {u'ecut': 44.0, u'ecuteps': 4.0, u'gap': 6.816130591466406, u'nbands': 60})
        self.assertTrue(os.path.isfile('SiC_SiC.cif.full_res'))

        print(' ==== generate next flow ===  ')
        print('      version with bandstructure and dos  ')
        gwsetup(update=False)
        self.assertTrue(os.path.isdir('SiC_SiC.cif.conv'))
        print(os.listdir(os.path.join(wdir, 'SiC_SiC.cif.conv', 'w0')))
        self.assertEqual(len(os.listdir(os.path.join(wdir, 'SiC_SiC.cif.conv', 'w0'))), 15)

        print(' ==== copy reference from second flow ===  ')
        time.sleep(1)  # the .conv directory should be older than the first one
        shutil.rmtree(os.path.join(wdir, 'SiC_SiC.cif.conv'))
        shutil.copytree(os.path.join(reference_dir, 'ref_res', 'SiC_SiC.cif.conv'),
                        os.path.join(wdir, 'SiC_SiC.cif.conv'))
        self.assertTrue(os.path.isdir(os.path.join(wdir, 'SiC_SiC.cif.conv')))
        self.assertEqual(len(os.listdir(os.path.join(wdir, 'SiC_SiC.cif.conv', 'w0'))), 13)

        print(' ==== process output ===  ')
        backup = sys.stdout
        captured = StringIO()
        sys.stdout = captured  # capture output
        try:
            gwoutput()
            out = captured.getvalue()
        finally:
            # always give the real stdout back, even if gwoutput raises
            sys.stdout = backup
            captured.close()

        print('=== *** ====\n'+out+'=== *** ====\n')
        gap = 0
        for line in out.split('\n'):
            # the gap is read from the seventh space separated field of the line containing 'values'
            if 'values' in line:
                gap = float(line.split(' ')[6])
        self.assertEqual(gap, 7.114950664158926)

        print(os.listdir('.'))
        print('processed')
        self.assertTrue(os.path.isfile('SiC_SiC.cif.full_res'))
        full_res = read_grid_from_file(s_name(structure)+'.full_res')
        self.assertEqual(full_res, {u'all_done': True, u'grid': 0})
        self.assertTrue(os.path.isdir(os.path.join(wdir, 'SiC_SiC.cif.res')))
        self.assertEqual(len(os.listdir(os.path.join(wdir, 'SiC_SiC.cif.res'))), 5)
        print(os.listdir(os.path.join(wdir, 'SiC_SiC.cif.res')))

        msrf = MySigResFile(os.path.join(wdir, 'SiC_SiC.cif.res', 'out_SIGRES.nc'))
        self.assertEqual(msrf.homo, 6.6843830378711786)
        self.assertEqual(msrf.lumo, 8.0650328308487982)
        self.assertEqual(msrf.homo_gw, 6.2325949743130034)
        self.assertEqual(msrf.lumo_gw, 8.2504215095164763)
        self.assertEqual(msrf.fundamental_gap('ks'), msrf.lumo - msrf.homo)
        self.assertEqual(msrf.fundamental_gap('gw'), msrf.lumo_gw - msrf.homo_gw)
        self.assertAlmostEqual(msrf.fundamental_gap('gamma'), gap, places=3)

        # since we now have a mysigresfile object we test the functionality

        msrf.get_scissor()

        res = msrf.get_scissor_residues()
        self.assertEqual(res, [0.05322754684319431, 0.34320373172956475])

        # to be continued: plot_scissor and plot_qpe are not exercised yet
コード例 #4
0
    def band_plots(self, query=None, sort='system'):
        """
        Plot the scissored bands for all the systems in the DB.

        For every document returned by ``self.col.find(query)`` that is not in
        the hard-coded exclusion list and whose results file data is longer
        than 1000, the scissor is plotted, the gap info is printed, and the
        Kohn-Sham bands are plotted together with the bands obtained by
        applying the scissor. If the document has a falsy ``'tgwgap'`` entry,
        the Kohn-Sham and scissor-corrected band edges are saved back into the
        document.

        :param query: filter handed to ``self.col.find``; ``None`` means ``{}``
        :param sort: key handed to the cursor's ``sort``
        :return: the band structure figure of the last system that was
                 plotted, or None if nothing was plotted
        """
        if query is None:
            query = {}
        # systems that are skipped; built once instead of once per document
        exclude = ('ZnO_mp-2133', 'K3Sb_mp-14017', 'NiP2_mp-486', 'OsS2_mp-20905', 'PbO_mp-1336', 'Rb3Sb_mp-16319',
                   'As2Os_mp-2455')
        # stays None when no system yields a plot
        fig = None
        for item in self.col.find(query).sort(sort):
            if item['system'] in exclude:
                continue
            try:
                data = self.gfs.get(item['results_file']).read()
                if len(data) > 1000:
                    srf = MySigResFile(data)
                    title = "QPState corrections of " + str(item['system']) + "\nusing " \
                            + str(item['ps'].split('/')[-2])
                    if item['extra_vars'] is not None:
                        title += ' and ' + str(self.fix_extra(item['extra_vars']))
                    srf.plot_scissor(title=title)
                    srf.print_gap_info()
                    sc = srf.get_scissor()

                    with self.gfs.get(item['ksbands_file']) as f:
                        ksb = MyBandsFile(f.read()).ebands
                    qpb = ksb.apply_scissors(sc)

                    # Plot the KS and the scissored band structure with matplotlib.
                    plotter = ElectronBandsPlotter()
                    plotter.add_ebands("KS", ksb)
                    plotter.add_ebands("KS+scissors(e)", qpb)

                    fig = plotter.plot(align='cbm', ylim=(-5, 10), title="%s Bandstructure" %
                                                                         item['system'].split('_mp-')[0])

                    try:
                        # a missing 'tgwgap' key raises KeyError, which skips
                        # the update for this document
                        if not item['tgwgap']:
                            bands_data = self.gfs.get(item['ksbands_file'])
                            ks_bands = MyBandsFile(bands_data)
                            ks_homo = ks_bands.homo
                            ks_lumo = ks_bands.lumo
                            item['tkshomo'] = ks_homo
                            item['tkslumo'] = ks_lumo
                            item['tksgap'] = ks_lumo - ks_homo
                            # scissor-corrected edge = KS edge + correction
                            item['tgwhomo'] = ks_homo + sc.apply(ks_homo)
                            item['tgwlumo'] = ks_lumo + sc.apply(ks_lumo)
                            item['tgwgap'] = item['tgwlumo'] - item['tgwhomo']
                            self.col.save(item)
                    except KeyError:
                        pass
            except (KeyError, IOError, NoFile):
                print('No Sigres file in DataBase')

        return fig
コード例 #5
0
    def get_data_set(self, query=None, sigresdata=False, update=False):
        """
        Retrieve a data set from the database and store it in ``self.data_set``.

        One entry is built per document returned by ``self.col.find(query)``.
        Each entry holds the system name, the ``mp-`` id, the pseudopotential,
        the functional, ``kp_in``, the extra variables and the GW results. The
        experimental gap and the PBE gap are looked up in ``self.col_external``
        and stored under ``'expgap'`` and ``'pbegap'`` (None when no record is
        found).

        With ``sigresdata`` the entry data is extended with the Sigres derived
        values and the scissored band edges. Both are read from the values
        cached in the document when available; otherwise they are computed
        from the stored files and the document is saved back with the new
        cache.

        :param query: filter handed to ``self.col.find``; ``None`` means ``{}``
        :param sigresdata: also collect the Sigres and band structure data
        :param update: ignore the cached ``'srf_data'`` and parse the Sigres
                       file again
        :return: None, the result is available as ``self.data_set``
        """
        self.data_set = []
        if sigresdata:
            print('may need to parse nc files, this may take some time')
        if query is None:
            query = {}
        for item in self.col.find(query):
            try:
                mpid = item['system'].split('_mp-')[1]
            except IndexError:
                # system name without an '_mp-' part
                mpid = 0
            try:
                entry = {'system': item['system'].split('_mp-')[0],
                         'id': "mp-%s" % mpid,
                         'ps': item['ps'].split('/')[-2],
                         'xc': item['spec']['functional'],
                         'kp_in': item['spec'].get('kp_in', None),
                         'extra': self.fix_extra(item['extra_vars']),
                         'data': item['gw_results']}
                print("id: %s" % entry['id'])
                mp_id = entry['id'].split('mp-')[1]
                # reset for every system so that values found for a previous
                # system can never end up in this entry
                expgap = None
                icsd = None
                pbegap = None
                try:
                    # when several records match, the last one wins
                    for exp_result in self.col_external.exp.find({'MP_id': str(mp_id)}):
                        expgap = exp_result['band_gap']
                        icsd = exp_result['icsd_data']['icsd_id']
                    entry['expgap'] = expgap
                    # a query on None would also match documents in which the
                    # field is null or missing, so only query a known ICSD id
                    if icsd is not None:
                        for pbe_result in self.col_external.GGA_BS.find({'transformations.history.0.id': icsd}):
                            pbegap = pbe_result['analysis']['bandgap']
                    entry['pbegap'] = pbegap
                except Exception:
                    # best effort lookup: report the id and carry on, but let
                    # KeyboardInterrupt and SystemExit through
                    print(str(mp_id))
                if sigresdata:
                    try:
                        # KeyError is used to fall through to parsing the file
                        if update:
                            raise KeyError
                        if 'max_en' not in item['srf_data']:
                            raise KeyError
                        entry['data'].update(item['srf_data'])
                        sys.stdout.write(".")
                        sys.stdout.flush()
                    except KeyError:
                        sys.stdout.write(":")
                        sys.stdout.flush()
                        try:
                            print('getting srfdata')
                            srf = MySigResFile(self.gfs.get(item['results_file']).read())
                            srf_data = {'gwfgap': srf.fundamental_gap('gw'),
                                        'ksfgap': srf.fundamental_gap('ks'),
                                        'gwggap': srf.fundamental_gap('gamma'),
                                        'gwhomo': srf.homo_gw,
                                        'kshomo': srf.homo,
                                        'gwlumo': srf.lumo_gw,
                                        'kslumo': srf.lumo,
                                        'max_en': srf.en_max_band,
                                        'scissor_residues': srf.scissor_residues}
                            entry['data'].update(srf_data)
                        except Exception:  # (IOError, OSError, NetcdfReaderError)
                            # any failure to read or parse the file stores
                            # zeros; an interrupt is no longer turned into
                            # zeroed data in the database
                            # NOTE(review): 'gwggap' is absent from this
                            # fallback although the parsed data has it
                            srf_data = {'gwfgap': 0,
                                        'ksfgap': 0,
                                        'gwhomo': 0,
                                        'kshomo': 0,
                                        'gwlumo': 0,
                                        'kslumo': 0,
                                        'max_en': 0,
                                        'scissor_residues': [0, 0]}
                            entry['data'].update(srf_data)
                        item['srf_data'] = srf_data
                        self.col.save(item)
                    # the exclusion list is empty, so this holds for every system
                    if item['system'] not in []:
                        print('bands for', item['system'])
                        try:
                            # a cached None gap is treated like a missing cache
                            if item['bs_data']['gwbgap'] is None:
                                raise KeyError
                            entry['data'].update(item['bs_data'])
                            sys.stdout.write(".")
                            sys.stdout.flush()
                        except KeyError:
                            try:
                                t0 = time.time()
                                # these systems are never processed; the KeyError
                                # sends them to the failure branch below
                                if item['system'] in ['As2Os_mp-2455', 'AgI_mp-22894', 'PbO_mp-20878', 'Cu2O_mp-361']:
                                    raise KeyError
                                with self.gfs.get(item['ksbands_file']) as f:
                                    ksb = MyBandsFile(f.read())
                                t = time.time() - t0
                                print('reading bands : %s' % t)
                                t0 = time.time()
                                with self.gfs.get(item['results_file']) as g:
                                    srf = MySigResFile(g.read())
                                t = time.time() - t0
                                print('reading srf : %s' % t)
                                t0 = time.time()
                                scissor = srf.get_scissor()
                                ks_homo = ksb.homo
                                ks_lumo = ksb.lumo
                                # scissor-corrected edge = KS edge + correction
                                gwbhomo = ks_homo + scissor.apply(ks_homo)
                                gwblumo = ks_lumo + scissor.apply(ks_lumo)
                                bs_data = {'ksbgap': ks_lumo - ks_homo,
                                           'ksbhomo': ks_homo,
                                           'gwbhomo': gwbhomo,
                                           'ksblumo': ks_lumo,
                                           'gwblumo': gwblumo,
                                           'gwbgap':  gwblumo - gwbhomo}
                                entry['data'].update(bs_data)
                                t = time.time() - t0
                                print('making updating : %s' % t)
                                print('got ksbandsdata')
                            # KeyboardInterrupt is caught here, so an interrupt
                            # skips only this system
                            except (NoFile, KeyError, ValueError, KeyboardInterrupt):
                                print('failed getting ksbandsdata')
                                bs_data = {'ksbgap': None,
                                           'ksbhomo': None,
                                           'gwbhomo': None,
                                           'gwblumo': None,
                                           'ksblumo': None,
                                           'gwbgap': None}
                            item['bs_data'] = bs_data
                            self.col.save(item)

                self.data_set.append(entry)
            except CursorNotFound:
                print('cursonotfound')
        print('\n')