示例#1
0
文件: qmcpack.py 项目: jyamu/qmc
    def incorporate_result(self,result_name,result,sim):
        """Fold the result of a dependency simulation into this QMCPACK input.

        Supported combinations (anything else is reported via self.error):
          'orbitals'  from Pw2qmcpack/Wfconvert: point the bspline/einspline
                      orbital element at result.h5file and set twist defaults.
          'orbitals'  from Sqd: replace the qmcsystem with the one written by
                      Sqd, keeping existing jastrows, and set the atomic basis
                      set href.
          'jastrow'   from Qmcpack: merge the optimized jastrows found in
                      result.opt_file with the current ones.
          'jastrow'   from Sqd: replace J1/J2 with freshly generated jastrows.
          'structure' from any simulation: copy relaxed positions and atoms.

        Parameters
        ----------
        result_name : str
            Which kind of result to incorporate (see above).
        result :
            Result container produced by `sim`; the attributes read depend on
            result_name/sim (h5file, dir, qmcfile, opt_file, rcut, B, structure).
        sim :
            The dependency simulation that produced `result`.
        """
        input  = self.input
        system = self.system
        if result_name=='orbitals':
            if isinstance(sim,Pw2qmcpack) or isinstance(sim,Wfconvert):

                h5file = result.h5file

                wavefunction = input.get('wavefunction')
                if isinstance(wavefunction,collection):
                    wavefunction = wavefunction.get_single('psi0')
                #end if
                wf = wavefunction
                # locate the element that carries the orbital file reference
                if 'sposet_builder' in wf and wf.sposet_builder.type=='bspline':
                    orb_elem = wf.sposet_builder
                elif 'sposet_builders' in wf and 'bspline' in wf.sposet_builders:
                    orb_elem = wf.sposet_builders.bspline
                elif 'determinantset' in wf and wf.determinantset.type in ('bspline','einspline'):
                    orb_elem = wf.determinantset
                else:
                    self.error('could not incorporate pw2qmcpack/wfconvert orbitals\n  bspline sposet_builder and determinantset are both missing')
                #end if
                # the h5 reference is stored relative to this run's directory
                orb_elem.href = os.path.relpath(h5file,self.locdir)
                if system.structure.folded_structure is not None:
                    # a folded structure exists: pass its tiling matrix along
                    orb_elem.tilematrix = array(system.structure.tmatrix)
                #end if
                # fill in defaults only where the input does not set a value
                defs = obj(
                    twistnum   = 0,
                    meshfactor = 1.0
                    )
                for var,val in defs.iteritems():
                    if var not in orb_elem:
                        orb_elem[var] = val
                    #end if
                #end for

                structure = system.structure
                nkpoints = len(structure.kpoints)
                if nkpoints==0:
                    self.error('system must have kpoints to assign twistnums')
                #end if

                if not os.path.exists(h5file):
                    self.error('wavefunction file not found:  \n'+h5file)
                #end if

                twistnums = range(nkpoints)
                if self.should_twist_average:
                    self.twist_average(twistnums)
                elif orb_elem.twistnum is None:
                    orb_elem.twistnum = twistnums[0]
                #end if

            elif isinstance(sim,Sqd):

                h5file  = os.path.join(result.dir,result.h5file)
                h5file  = os.path.relpath(h5file,self.locdir)

                sqdxml_loc = os.path.join(result.dir,result.qmcfile)
                sqdxml = QmcpackInput(sqdxml_loc)

                #sqd sometimes puts the wrong ionic charge
                #  rather than setting Z to the number of electrons
                #  set it to the actual atomic number
                g = sqdxml.qmcsystem.particlesets.atom.group
                elem = g.name
                if elem not in periodic_table.elements:
                    self.error(elem+' is not an element in the periodic table')
                #end if
                g.charge = periodic_table.elements[elem].atomic_number

                # swap in the qmcsystem written by sqd, carrying over any
                # existing jastrows (one-body jastrows are re-targeted at 'atom')
                s = input.simulation
                qsys_old = s.qmcsystem
                del s.qmcsystem
                s.qmcsystem = sqdxml.qmcsystem
                if 'jastrows' in qsys_old.wavefunction:
                    s.qmcsystem.wavefunction.jastrows = qsys_old.wavefunction.jastrows
                    for jastrow in s.qmcsystem.wavefunction.jastrows:
                        if 'type' in jastrow:
                            jtype = jastrow.type.lower().replace('-','_')
                            if jtype=='one_body':
                                jastrow.source = 'atom'
                            #end if
                        #end if
                    #end for
                #end if
                s.qmcsystem.hamiltonian = hamiltonian(
                    name='h0',type='generic',target='e',
                    pairpots = [
                        pairpot(name='ElecElec',type='coulomb',source='e',target='e'),
                        pairpot(name='Coulomb' ,type='coulomb',source='atom',target='e'),
                        ]
                    )
                s.init = init(source='atom',target='e')

                abset = input.get('atomicbasisset')
                abset.href = h5file

            else:
                self.error('incorporating orbitals from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='jastrow':
            if isinstance(sim,Qmcpack):
                opt_file = result.opt_file
                opt = QmcpackInput(opt_file)
                wavefunction = input.get('wavefunction')
                optwf = opt.qmcsystem.wavefunction
                def process_jastrow(wf):
                    # map normalized jastrow type -> jastrow element, whether
                    # wf holds a single 'jastrow' or a 'jastrows' collection
                    if 'jastrow' in wf:
                        js = [wf.jastrow]
                    elif 'jastrows' in wf:
                        js = wf.jastrows.values()
                    else:
                        js = []
                    #end if
                    jd = dict()
                    for j in js:
                        jtype = j.type.lower().replace('-','_').replace(' ','_')
                        jd[jtype] = j
                    #end for
                    return jd
                #end def process_jastrow
                if wavefunction is None:
                    qs = input.get('qmcsystem')
                    qs.wavefunction = optwf.copy()
                else:
                    jold = process_jastrow(wavefunction)
                    jopt = process_jastrow(optwf)
                    # optimized jastrows take precedence; keep old jastrows
                    # only for types the optimization did not produce
                    jnew = list(jopt.values())
                    for jtype in jold.keys():
                        if jtype not in jopt:
                            jnew.append(jold[jtype])
                        #end if
                    #end for
                    # remove the previous jastrow element(s) first: otherwise
                    # switching between the single 'jastrow' form and the
                    # 'jastrows' collection form leaves a stale duplicate
                    if 'jastrow' in wavefunction:
                        del wavefunction.jastrow
                    #end if
                    if 'jastrows' in wavefunction:
                        del wavefunction.jastrows
                    #end if
                    if len(jnew)==1:
                        wavefunction.jastrow = jnew[0].copy()
                    else:
                        wavefunction.jastrows = collection(jnew)
                    #end if
                #end if
                del optwf
            elif isinstance(sim,Sqd):
                wavefunction = input.get('wavefunction')
                # keep every existing jastrow except J1/J2, which are regenerated
                jastrows = []
                if 'jastrows' in wavefunction:
                    for jastrow in wavefunction.jastrows:
                        jname = jastrow.name
                        if jname!='J1' and jname!='J2':
                            jastrows.append(jastrow)
                        #end if
                    #end for
                    del wavefunction.jastrows
                #end if

                ionps = input.get_ion_particlesets()
                if ionps is None or len(ionps)==0:
                    self.error('ion particleset does not seem to exist')
                elif len(ionps)==1:
                    ionps_name = list(ionps.keys())[0]
                else:
                    self.error('multiple ion species not supported for atomic calculations')
                #end if

                jastrows.extend([
                        generate_jastrow('J1','bspline',8,result.rcut,iname=ionps_name,system=self.system),
                        generate_jastrow('J2','pade',result.B)
                        ])

                wavefunction.jastrows = collection(jastrows)

            else:
                self.error('incorporating jastrow from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='structure':
            structure = self.system.structure
            relstruct = result.structure
            structure.set(
                pos   = relstruct.positions,
                atoms = relstruct.atoms
                )
            self.input.incorporate_system(self.system)
        else:
            self.error('ability to incorporate result '+result_name+' has not been implemented')
示例#2
0
    def incorporate_result(self,result_name,result,sim):
        """Fold the result of a dependency simulation into this QMCPACK input.

        Supported combinations (anything else is reported via self.error):
          'orbitals'     from Pw2qmcpack/Wfconvert, Sqd or Convert4qmc
          'jastrow'      from Qmcpack (merge optimized jastrows) or Sqd
          'particles'    from Convert4qmc (replace the particlesets)
          'structure'    from any simulation (replace the structure)
          'cuspcorr'     cusp-correction files for the determinantset
          'wavefunction' from Qmcpack (take the optimized wavefunction)

        Parameters
        ----------
        result_name : str
            Which kind of result to incorporate (see above).
        result :
            Result container produced by `sim`; the attributes read depend on
            result_name/sim (h5file, location, opt_file, structure, ...).
        sim :
            The dependency simulation that produced `result`.
        """
        input  = self.input
        system = self.system
        if result_name=='orbitals':
            if isinstance(sim,Pw2qmcpack) or isinstance(sim,Wfconvert):

                h5file = result.h5file

                wavefunction = input.get('wavefunction')
                if isinstance(wavefunction,collection):
                    wavefunction = wavefunction.get_single('psi0')
                #end if
                wf = wavefunction
                # locate the element that carries the orbital file reference
                if 'sposet_builder' in wf and wf.sposet_builder.type=='bspline':
                    orb_elem = wf.sposet_builder
                elif 'sposet_builders' in wf and 'bspline' in wf.sposet_builders:
                    orb_elem = wf.sposet_builders.bspline
                elif 'sposet_builders' in wf and 'einspline' in wf.sposet_builders:
                    orb_elem = wf.sposet_builders.einspline
                elif 'determinantset' in wf and wf.determinantset.type in ('bspline','einspline'):
                    orb_elem = wf.determinantset
                else:
                    self.error('could not incorporate pw2qmcpack/wfconvert orbitals\nbspline sposet_builder and determinantset are both missing')
                #end if
                if 'href' in orb_elem and isinstance(orb_elem.href,str) and os.path.exists(orb_elem.href):
                    # user specified h5 file for orbitals, bypass orbital dependency
                    orb_elem.href = os.path.relpath(orb_elem.href,self.locdir)
                else:
                    orb_elem.href = os.path.relpath(h5file,self.locdir)
                    if system.structure.folded_structure is not None:
                        # a folded structure exists: pass its tiling matrix along
                        orb_elem.tilematrix = array(system.structure.tmatrix)
                    #end if
                #end if
                # fill in defaults only where the input does not set a value;
                # twist selection is handled separately below
                defs = obj(
                    meshfactor = 1.0
                    )
                for var,val in defs.iteritems():
                    if var not in orb_elem:
                        orb_elem[var] = val
                    #end if
                #end for
                # respect whichever of 'twist'/'twistnum' is already set;
                # only when neither is present default to twistnum 0
                has_twist    = 'twist' in orb_elem
                has_twistnum = 'twistnum' in orb_elem
                if not has_twist and not has_twistnum:
                    orb_elem.twistnum = 0
                #end if

                structure = system.structure
                nkpoints = len(structure.kpoints)
                if nkpoints==0:
                    self.error('system must have kpoints to assign twistnums')
                #end if

                if not os.path.exists(h5file):
                    self.error('wavefunction file not found:  \n'+h5file)
                #end if

                twistnums = range(nkpoints)
                if self.should_twist_average:
                    self.twist_average(twistnums)
                elif not has_twist and orb_elem.twistnum is None:
                    orb_elem.twistnum = twistnums[0]
                #end if

            elif isinstance(sim,Sqd):

                h5file  = os.path.join(result.dir,result.h5file)
                h5file  = os.path.relpath(h5file,self.locdir)

                sqdxml_loc = os.path.join(result.dir,result.qmcfile)
                sqdxml = QmcpackInput(sqdxml_loc)

                #sqd sometimes puts the wrong ionic charge
                #  rather than setting Z to the number of electrons
                #  set it to the actual atomic number
                g = sqdxml.qmcsystem.particlesets.atom.group
                elem = g.name
                if elem not in periodic_table.elements:
                    self.error(elem+' is not an element in the periodic table')
                #end if
                g.charge = periodic_table.elements[elem].atomic_number

                # swap in the qmcsystem written by sqd, carrying over any
                # existing jastrows (one-body jastrows are re-targeted at 'atom')
                s = input.simulation
                qsys_old = s.qmcsystem
                del s.qmcsystem
                s.qmcsystem = sqdxml.qmcsystem
                if 'jastrows' in qsys_old.wavefunction:
                    s.qmcsystem.wavefunction.jastrows = qsys_old.wavefunction.jastrows
                    for jastrow in s.qmcsystem.wavefunction.jastrows:
                        if 'type' in jastrow:
                            jtype = jastrow.type.lower().replace('-','_')
                            if jtype=='one_body':
                                jastrow.source = 'atom'
                            #end if
                        #end if
                    #end for
                #end if
                s.qmcsystem.hamiltonian = hamiltonian(
                    name='h0',type='generic',target='e',
                    pairpots = [
                        pairpot(name='ElecElec',type='coulomb',source='e',target='e'),
                        pairpot(name='Coulomb' ,type='coulomb',source='atom',target='e'),
                        ]
                    )
                s.init = init(source='atom',target='e')

                abset = input.get('atomicbasisset')
                abset.href = h5file

            elif isinstance(sim,Convert4qmc):

                # take the wavefunction written by convert4qmc, but keep the
                # jastrows of the current input in place of convert4qmc's
                res = QmcpackInput(result.location)
                qs  = input.simulation.qmcsystem
                oldwfn = qs.wavefunction
                newwfn = res.qmcsystem.wavefunction
                dset = newwfn.determinantset
                if 'jastrows' in newwfn:
                    del newwfn.jastrows
                #end if
                if 'jastrows' in oldwfn:
                    newwfn.jastrows = oldwfn.jastrows
                #end if
                if input.cusp_correction():
                    dset.cuspcorrection = True
                #end if
                if 'orbfile' in result:
                    orb_h5file = result.orbfile
                    # fall back to the href recorded by convert4qmc, resolved
                    # against convert4qmc's own run directory
                    if not os.path.exists(orb_h5file) and 'href' in dset:
                        orb_h5file = os.path.join(sim.locdir,dset.href)
                    #end if
                    if not os.path.exists(orb_h5file):
                        self.error('orbital h5 file from convert4qmc does not exist\nlocation checked: {}'.format(orb_h5file))
                    #end if
                    orb_path = os.path.relpath(orb_h5file,self.locdir)
                    dset.href = orb_path
                    detlist = dset.get('detlist')
                    if detlist is not None and 'href' in detlist:
                        detlist.href = orb_path
                    #end if
                #end if
                qs.wavefunction = newwfn

            else:
                self.error('incorporating orbitals from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='jastrow':
            if isinstance(sim,Qmcpack):
                opt_file = result.opt_file
                opt = QmcpackInput(opt_file)
                wavefunction = input.get('wavefunction')
                optwf = opt.qmcsystem.wavefunction
                def process_jastrow(wf):
                    # map normalized jastrow type -> jastrow element, whether
                    # wf holds a single 'jastrow' or a 'jastrows' collection
                    if 'jastrow' in wf:
                        js = [wf.jastrow]
                    elif 'jastrows' in wf:
                        js = wf.jastrows.values()
                    else:
                        js = []
                    #end if
                    jd = dict()
                    for j in js:
                        jtype = j.type.lower().replace('-','_').replace(' ','_')
                        jd[jtype] = j
                    #end for
                    return jd
                #end def process_jastrow
                if wavefunction is None:
                    qs = input.get('qmcsystem')
                    qs.wavefunction = optwf.copy()
                else:
                    jold = process_jastrow(wavefunction)
                    jopt = process_jastrow(optwf)
                    # optimized jastrows take precedence; keep old jastrows
                    # only for types the optimization did not produce
                    jnew = list(jopt.values())
                    for jtype in jold.keys():
                        if jtype not in jopt:
                            jnew.append(jold[jtype])
                        #end if
                    #end for
                    # remove the previous jastrow element(s) first: otherwise
                    # switching between the single 'jastrow' form and the
                    # 'jastrows' collection form leaves a stale duplicate
                    if 'jastrow' in wavefunction:
                        del wavefunction.jastrow
                    #end if
                    if 'jastrows' in wavefunction:
                        del wavefunction.jastrows
                    #end if
                    if len(jnew)==1:
                        wavefunction.jastrow = jnew[0].copy()
                    else:
                        wavefunction.jastrows = collection(jnew)
                    #end if
                #end if
                del optwf
            elif isinstance(sim,Sqd):
                wavefunction = input.get('wavefunction')
                # keep every existing jastrow except J1/J2, which are regenerated
                jastrows = []
                if 'jastrows' in wavefunction:
                    for jastrow in wavefunction.jastrows:
                        jname = jastrow.name
                        if jname!='J1' and jname!='J2':
                            jastrows.append(jastrow)
                        #end if
                    #end for
                    del wavefunction.jastrows
                #end if

                ionps = input.get_ion_particlesets()
                if ionps is None or len(ionps)==0:
                    self.error('ion particleset does not seem to exist')
                elif len(ionps)==1:
                    ionps_name = list(ionps.keys())[0]
                else:
                    self.error('multiple ion species not supported for atomic calculations')
                #end if

                jastrows.extend([
                        generate_jastrow('J1','bspline',8,result.rcut,iname=ionps_name,system=self.system),
                        generate_jastrow('J2','pade',result.B)
                        ])

                wavefunction.jastrows = collection(jastrows)

            else:
                self.error('incorporating jastrow from '+sim.__class__.__name__+' has not been implemented')
            #end if
        elif result_name=='particles':
            if isinstance(sim,Convert4qmc):
                ptcl_file = result.location
                qi = QmcpackInput(ptcl_file)
                self.input.simulation.qmcsystem.particlesets = qi.qmcsystem.particlesets
            else:
                self.error('incorporating particles from '+sim.__class__.__name__+' has not been implemented')
            # end if
        elif result_name=='structure':
            # work on a copy so the producing simulation's structure is untouched
            relstruct = result.structure.copy()
            relstruct.change_units('B')
            self.system.structure = relstruct
            self.system.remove_folded()
            self.input.incorporate_system(self.system)

        elif result_name=='cuspcorr':

            ds = self.input.get('determinantset')
            ds.cuspcorrection = True
            # Try the per-spin sposet layout first and fall back to the single
            # slater determinant layout. All lookups happen before any
            # assignment, so a failed first attempt never leaves 'spo-up'
            # updated while 'spo-dn' is not.
            try: # multideterminant
                spo_up   = ds.sposets['spo-up']
                spo_dn   = ds.sposets['spo-dn']
                up_cusps = os.path.relpath(result.spo_up_cusps,self.locdir)
                dn_cusps = os.path.relpath(result.spo_dn_cusps,self.locdir)
            except Exception: # single determinant
                sd = ds.slaterdeterminant
                sd.determinants['updet'].cuspinfo = os.path.relpath(result.updet_cusps,self.locdir)
                sd.determinants['downdet'].cuspinfo = os.path.relpath(result.dndet_cusps,self.locdir)
            else:
                spo_up.cuspinfo = up_cusps
                spo_dn.cuspinfo = dn_cusps
            # end try

        elif result_name=='wavefunction':
            if not isinstance(sim,Qmcpack):
                self.error('incorporating wavefunction from '+sim.__class__.__name__+' has not been implemented')
            #end if
            # single-argument print() behaves the same on Python 2 and 3
            print('        getting optimal wavefunction from: '+result.opt_file)
            opt = QmcpackInput(result.opt_file)
            qs = input.get('qmcsystem')
            qs.wavefunction = opt.qmcsystem.wavefunction.copy()
        else:
            self.error('ability to incorporate result '+result_name+' has not been implemented')
示例#3
0
文件: qmcpack.py 项目: jyamu/qmc
    def incorporate_result(self, result_name, result, sim):
        """Fold the result of a dependency simulation into this QMCPACK input.

        Supported combinations (anything else is reported via self.error):
          "orbitals"  from Pw2qmcpack/Wfconvert: move the determinantset under
                      the wavefunction as an einspline set pointing at
                      result.h5file, and assign twist(s).
          "orbitals"  from Sqd: replace the qmcsystem with the one written by
                      Sqd, keeping existing jastrows, and set the atomic basis
                      set href.
          "jastrow"   from Qmcpack: merge the optimized jastrows found in
                      result.opt_file with the current ones.
          "jastrow"   from Sqd: replace J1/J2 with freshly generated jastrows.
          "structure" from any simulation: copy relaxed positions and atoms.

        Parameters
        ----------
        result_name : str
            Which kind of result to incorporate (see above).
        result :
            Result container produced by `sim`; the attributes read depend on
            result_name/sim (h5file, dir, qmcfile, opt_file, rcut, B, structure).
        sim :
            The dependency simulation that produced `result`.
        """
        input = self.input
        system = self.system
        if result_name == "orbitals":
            if isinstance(sim, Pw2qmcpack) or isinstance(sim, Wfconvert):

                h5file = result.h5file

                dsold, wavefunction = input.get("determinantset", "wavefunction")
                if isinstance(wavefunction, collection):
                    # prefer the wavefunction named psi0, else take the first one
                    if "psi0" in wavefunction:
                        wavefunction = wavefunction.psi0
                    else:
                        wavefunction = wavefunction.list()[0]
                    # end if
                # end if
                dsnew = dsold
                # the h5 reference is stored relative to this run's directory
                dsnew.set(type="einspline", href=os.path.relpath(h5file, self.locdir))
                if system.structure.folded_structure is not None:
                    # a folded structure exists: pass its tiling matrix along
                    dsnew.tilematrix = array(system.structure.tmatrix)
                # end if
                # fill in defaults only where the input does not set a value
                defs = obj(twistnum=0, meshfactor=1.0, gpu=False)
                for var, val in defs.iteritems():
                    if var not in dsnew:
                        dsnew[var] = val
                    # end if
                # end for
                # re-home the determinantset under the chosen wavefunction
                input.remove("determinantset")
                wavefunction.determinantset = dsnew

                structure = system.structure
                nkpoints = len(structure.kpoints)
                if nkpoints == 0:
                    self.error("system must have kpoints to assign twistnums")
                # end if

                if not os.path.exists(h5file):
                    self.error("wavefunction file not found:  \n" + h5file)
                # end if

                if "tilematrix" in system:
                    dsnew.tilematrix = array(system.tilematrix)
                # end if
                twistnums = range(nkpoints)
                if len(twistnums) > 1:
                    self.twist_average(twistnums)
                else:
                    dsnew.twistnum = twistnums[0]
                # end if

            elif isinstance(sim, Sqd):

                h5file = os.path.join(result.dir, result.h5file)
                h5file = os.path.relpath(h5file, self.locdir)

                sqdxml_loc = os.path.join(result.dir, result.qmcfile)
                sqdxml = QmcpackInput(sqdxml_loc)

                # sqd sometimes puts the wrong ionic charge
                #  rather than setting Z to the number of electrons
                #  set it to the actual atomic number
                g = sqdxml.qmcsystem.particlesets.atom.group
                elem = g.name
                if elem not in periodic_table.elements:
                    self.error(elem + " is not an element in the periodic table")
                # end if
                g.charge = periodic_table.elements[elem].atomic_number

                # swap in the qmcsystem written by sqd, carrying over any
                # existing jastrows (one-body jastrows are re-targeted at "atom")
                s = input.simulation
                qsys_old = s.qmcsystem
                del s.qmcsystem
                s.qmcsystem = sqdxml.qmcsystem
                if "jastrows" in qsys_old.wavefunction:
                    s.qmcsystem.wavefunction.jastrows = qsys_old.wavefunction.jastrows
                    for jastrow in s.qmcsystem.wavefunction.jastrows:
                        if "type" in jastrow:
                            jtype = jastrow.type.lower().replace("-", "_")
                            if jtype == "one_body":
                                jastrow.source = "atom"
                            # end if
                        # end if
                    # end for
                # end if
                s.qmcsystem.hamiltonian = hamiltonian(
                    name="h0",
                    type="generic",
                    target="e",
                    pairpots=[
                        pairpot(name="ElecElec", type="coulomb", source="e", target="e"),
                        pairpot(name="Coulomb", type="coulomb", source="atom", target="e"),
                    ],
                )
                s.init = init(source="atom", target="e")

                abset = input.get("atomicbasisset")
                abset.href = h5file

            else:
                self.error("incorporating orbitals from " + sim.__class__.__name__ + " has not been implemented")
            # end if
        elif result_name == "jastrow":
            if isinstance(sim, Qmcpack):
                opt_file = result.opt_file
                opt = QmcpackInput(opt_file)
                wavefunction = input.get("wavefunction")
                optwf = opt.qmcsystem.wavefunction

                def process_jastrow(wf):
                    # map normalized jastrow type -> jastrow element, whether
                    # wf holds a single "jastrow" or a "jastrows" collection
                    if "jastrow" in wf:
                        js = [wf.jastrow]
                    elif "jastrows" in wf:
                        js = wf.jastrows.values()
                    else:
                        js = []
                    # end if
                    jd = dict()
                    for j in js:
                        jtype = j.type.lower().replace("-", "_").replace(" ", "_")
                        jd[jtype] = j
                    # end for
                    return jd

                # end def process_jastrow
                if wavefunction is None:
                    qs = input.get("qmcsystem")
                    qs.wavefunction = optwf.copy()
                else:
                    jold = process_jastrow(wavefunction)
                    jopt = process_jastrow(optwf)
                    # optimized jastrows take precedence; keep old jastrows
                    # only for types the optimization did not produce
                    jnew = list(jopt.values())
                    for jtype in jold.keys():
                        if jtype not in jopt:
                            jnew.append(jold[jtype])
                        # end if
                    # end for
                    # remove the previous jastrow element(s) first: otherwise
                    # switching between the single "jastrow" form and the
                    # "jastrows" collection form leaves a stale duplicate
                    if "jastrow" in wavefunction:
                        del wavefunction.jastrow
                    # end if
                    if "jastrows" in wavefunction:
                        del wavefunction.jastrows
                    # end if
                    if len(jnew) == 1:
                        wavefunction.jastrow = jnew[0].copy()
                    else:
                        wavefunction.jastrows = collection(jnew)
                    # end if
                # end if
                del optwf
            elif isinstance(sim, Sqd):
                wavefunction = input.get("wavefunction")
                # keep every existing jastrow except J1/J2, which are regenerated
                jastrows = []
                if "jastrows" in wavefunction:
                    for jastrow in wavefunction.jastrows:
                        jname = jastrow.name
                        if jname != "J1" and jname != "J2":
                            jastrows.append(jastrow)
                        # end if
                    # end for
                    del wavefunction.jastrows
                # end if

                ionps = input.get_ion_particlesets()
                if ionps is None or len(ionps) == 0:
                    self.error("ion particleset does not seem to exist")
                elif len(ionps) == 1:
                    ionps_name = list(ionps.keys())[0]
                else:
                    self.error("multiple ion species not supported for atomic calculations")
                # end if

                jastrows.extend(
                    [
                        generate_jastrow("J1", "bspline", 8, result.rcut, iname=ionps_name, system=self.system),
                        generate_jastrow("J2", "pade", result.B),
                    ]
                )

                wavefunction.jastrows = collection(jastrows)

            else:
                self.error("incorporating jastrow from " + sim.__class__.__name__ + " has not been implemented")
            # end if
        elif result_name == "structure":
            structure = self.system.structure
            relstruct = result.structure
            structure.set(pos=relstruct.positions, atoms=relstruct.atoms)
            self.input.incorporate_system(self.system)
        else:
            self.error("ability to incorporate result " + result_name + " has not been implemented")