Пример #1
0
    def __init__(self, atoms, **kwargs):
        """Set up the search state from a starting structure.

        Parameters
        ----------
        atoms : ase.Atoms
            Structure the search starts from; kept as ``self._atoms``.
        **kwargs
            Overrides for entries of ``self._default_settings``. Every
            setting (overridden or default) is stored as ``_<name>``.

        Raises
        ------
        RuntimeError
            If a keyword is not one of the default settings.
        """
        self._atoms = atoms

        # Reject the first keyword that is not a known setting.
        unknown = [name for name in kwargs
                   if name not in self._default_settings]
        if unknown:
            raise RuntimeError('Unknown keyword: %s' % unknown[0])

        for name, default in self._default_settings.items():
            setattr(self, '_%s' % name, kwargs.pop(name, default))

        # Detector for when an MD run has passed a local minimum.
        self._passedminimum = PassedMinimum()

        # Last accepted structure/energy plus the adaptive parameters.
        self._previous_optimum = self._previous_energy = None
        self._Ediff, self._temperature = self._Ediff0, self._T0

        # Oganov fingerprint comparator used to detect identical structures.
        # NOTE(review): this variant hard-codes cos_dist_max instead of using
        # the minima_threshold setting — confirm that is intended.
        ofp_options = dict(dE=1.0, cos_dist_max=5e-3, rcut=20., binwidth=0.05,
                           pbc=[True] * 3, sigma=0.05, nsigma=4,
                           recalculate=False)
        self._comp = OFPComparator(**ofp_options)
def run_ga(n_to_test):
    """Run the genetic algorithm on the candidates stored in ``godb.db``.

    The initial random structures must already have been stored in the
    database. They are relaxed first; then ``n_to_test`` offspring are
    created by the genetic operators, relaxed and fed back into the
    population.

    Parameters
    ----------
    n_to_test : int
        Number of new candidates to generate and relax. With 0 only the
        initial candidates are relaxed.

    Side effects: updates ``godb.db``, lets the population log to
    ``log.txt`` and writes ``all_candidates.traj`` and
    ``current_population.traj``.
    """
    # Various initializations:
    population_size = 10  # maximal size of the population
    da = DataConnection('godb.db')
    atom_numbers_to_optimize = da.get_atom_numbers_to_optimize()  # e.g. [14] * 7
    n_to_optimize = len(atom_numbers_to_optimize)
    # This defines how close the Si atoms are allowed to get
    # in candidate structures generated by the genetic operators:
    blmin = closest_distances_generator(atom_numbers_to_optimize,
                                        ratio_of_covalent_radii=0.4)
    # This is our OFPComparator instance which will be
    # used to judge whether or not two structures are identical:
    comparator = OFPComparator(n_top=None, dE=1.0, cos_dist_max=1e-3,
                               rcut=10., binwidth=0.05, pbc=[False] * 3,
                               sigma=0.1, nsigma=4, recalculate=False)

    # Defining a typical combination of genetic operators
    # (pairing is chosen twice as often as the rattle mutation):
    pairing = CutAndSplicePairing(da.get_slab(), n_to_optimize, blmin)
    rattlemut = RattleMutation(blmin, n_to_optimize, rattle_prop=0.8,
                               rattle_strength=1.5)
    operators = OperationSelector([2., 1.], [pairing, rattlemut])

    # Relax the randomly generated initial candidates:
    while da.get_number_of_unrelaxed_candidates() > 0:
        a = da.get_an_unrelaxed_candidate()
        a = relax_one(a)
        da.add_relaxed_step(a)

    # Create the population
    population = Population(data_connection=da,
                            population_size=population_size,
                            comparator=comparator,
                            logfile='log.txt')
    # NOTE(review): the returned population is unused; the call is kept in
    # case it refreshes the population as a side effect — confirm.
    population.get_current_population()

    # Test n_to_test new candidates. `step` starts as None so the final
    # message works even when the loop body never runs.
    step = None
    for step in range(n_to_test):
        print('Starting configuration number %d' % step, flush=True)

        # Operators may fail to produce an offspring; retry until one exists.
        a3 = None
        while a3 is None:
            a1, a2 = population.get_two_candidates()
            a3, description = operators.get_new_individual([a1, a2])

        da.add_unrelaxed_candidate(a3, description=description)
        a3 = relax_one(a3)
        da.add_relaxed_step(a3)

        population.update()
        best = population.get_current_population()[0]
        print('Highest raw score at this point: %.3f' % get_raw_score(best))

    if step is None:
        print('GA finished without testing any new candidates')
    else:
        print('GA finished after step %d' % step)
    write('all_candidates.traj', da.get_all_relaxed_candidates())
    write('current_population.traj', population.get_current_population())
Пример #3
0
    def __init__(self, atoms, **kwargs):
        """Set up the search state and Pb/Br bookkeeping for a structure.

        Parameters
        ----------
        atoms : ase.Atoms
            Structure the search starts from; kept as ``self._atoms``.
        **kwargs
            Overrides for entries of ``self._default_settings``; every
            setting is stored as the attribute ``_<name>``.

        Raises
        ------
        RuntimeError
            If a keyword is not one of the default settings.
        """
        self._atoms = atoms
        unknown = [name for name in kwargs
                   if name not in self._default_settings]
        if unknown:
            raise RuntimeError('Unknown keyword: %s' % unknown[0])
        for name, default in self._default_settings.items():
            setattr(self, '_%s' % name, kwargs.pop(name, default))

        # Detector for when an MD run has passed a local minimum.
        self._passedminimum = PassedMinimum()

        # Last accepted structure/energy plus the adaptive parameters.
        self._previous_optimum = self._previous_energy = None
        self._Ediff, self._temperature = self._Ediff0, self._T0

        # Oganov fingerprint comparator; minima_threshold is the cosine
        # distance below which two structures count as identical.
        self._comp = OFPComparator(recalculate=False,
                                   dE=10.0,
                                   cos_dist_max=self._minima_threshold,
                                   rcut=20., binwidth=0.05,
                                   sigma=0.05, nsigma=4,
                                   pbc=[True, True, True])

        # Inorganic (Pb, Br) indices and their starting positions, used for
        # Hookean constraints.
        symbols = self._atoms.symbols
        positions = self._atoms.positions
        self._Pb_indices = np.flatnonzero(symbols == 'Pb')
        self._Br_indices = np.flatnonzero(symbols == 'Br')
        self._Pb_positions = positions[self._Pb_indices]
        self._Br_positions = positions[self._Br_indices]
        self._inorganic_indices = np.concatenate(
            (self._Pb_indices, self._Br_indices))
        self._inorganic_positions = positions[self._inorganic_indices]

        # Row r holds the indices of the six Br atoms nearest (minimum-image
        # distance) to the r-th Pb atom, nearest first.
        self._indices_list = np.empty((len(self._Pb_indices), 6), dtype='int')
        for row, pb_index in enumerate(self._Pb_indices):
            pb_br_dist = self._atoms.get_distances(pb_index, self._Br_indices,
                                                   mic=True)
            nearest_six = np.argsort(pb_br_dist)[:6]
            self._indices_list[row] = self._Br_indices[nearest_six]
class MinimaHopping:
    """Minima hopping global optimization with a variable cell.

    Implements the minima hopping method of S. Goedecker, J. Chem. Phys.
    120: 9911 (2004), adapted here to NPT molecular dynamics and cell
    relaxation. Each hop

    1. runs NPT molecular dynamics until ``mdmin`` minima of the potential
       energy have been passed,
    2. relaxes the result: an unconstrained FIRE relaxation to ``fmax2``,
       then a cell + position relaxation (``optimizer`` on an
       ``ExpCellFilter``) to ``fmax`` with Hookean restraints on every Pb
       and Br atom,
    3. accepts the new minimum if its energy is below the previous energy
       plus ``Ediff`` and its Oganov fingerprint matches no minimum stored
       in ``minima_traj``; otherwise the last minimum is restored.

    Initialize with an ASE atoms object; optional parameters are passed as
    keywords (see ``_default_settings``). To run multiple searches in
    parallel, give every run the same ``minima_traj`` path.
    """

    _default_settings = {
        'T0': 1000.,  # K, initial MD 'temperature'
        'beta1': 1.1,  # temperature factor applied after a rejected minimum
        'beta2': 1.1,  # not used by the current acceptance logic
        'beta3': 1. / 1.1,  # not used by the current acceptance logic
        'Ediff0': 0.5,  # eV, initial energy acceptance threshold
        'alpha1': 0.98,  # Ediff factor applied after an accepted minimum
        'alpha2': 1. / 0.98,  # Ediff factor applied after a rejected minimum
        'mdmin': 2,  # number of minima the MD must pass before stopping
        'logfile': 'hop.log',  # text log
        # Cosine distance (dimensionless) below which two configurations
        # are considered identical:
        'minima_threshold': 0.5e-3,
        'timestep': 1.0,  # fs, timestep for MD simulations
        'optimizer': QuasiNewton,  # optimizer for the cell + position relax
        'minima_traj': 'minima.traj',  # storage file for minima list
        'fmax': 0.05,  # eV/A, max force for the constrained cell relaxation
        'fmax2': 0.1,  # eV/A, max force for the preliminary FIRE relaxation
        # eV/A^3, external stress tensor or scalar pressure passed to NPT:
        'externalstress': 1e-1,
        'ttime': 25.,  # fs, time constant for temperature coupling
        # Constant in the barostat differential equation (multiplied by
        # units.fs**2 before being passed to NPT):
        'pfactor': 0.6 * 75.**2,
    }

    def __init__(self, atoms, **kwargs):
        """Initialize with an ASE atoms object and keyword arguments.

        Parameters
        ----------
        atoms : ase.Atoms
            Starting structure; it may be modified in place by the search.
        **kwargs
            Overrides for ``_default_settings``; each setting is stored as
            the attribute ``_<name>``.

        Raises
        ------
        RuntimeError
            If a keyword is not a known setting.
        """
        self._atoms = atoms
        for key in kwargs:
            if key not in self._default_settings:
                raise RuntimeError('Unknown keyword: %s' % key)
        for k, v in self._default_settings.items():
            setattr(self, '_%s' % k, kwargs.pop(k, v))

        # Detects when an MD run has passed a local minimum.
        self._passedminimum = PassedMinimum()

        # Last accepted structure/energy and the adaptive parameters.
        self._previous_optimum = None
        self._previous_energy = None
        self._temperature = self._T0
        self._Ediff = self._Ediff0

        # Oganov fingerprints for structure comparison.
        self._comp = OFPComparator(dE=10.0,
                                   cos_dist_max=self._minima_threshold,
                                   rcut=20.,
                                   binwidth=0.05,
                                   pbc=[True, True, True],
                                   sigma=0.05,
                                   nsigma=4,
                                   recalculate=False)

        # Pb and Br indices and positions for the Hookean restraints.
        self._Pb_indices = np.where(self._atoms.symbols == 'Pb')[0]
        self._Pb_positions = self._atoms.positions[self._Pb_indices]
        self._Br_indices = np.where(self._atoms.symbols == 'Br')[0]
        self._Br_positions = self._atoms.positions[self._Br_indices]
        self._inorganic_indices = np.concatenate(
            (self._Pb_indices, self._Br_indices))
        self._inorganic_positions = self._atoms.positions[
            self._inorganic_indices]

    def __call__(self, totalsteps=None, maxtemp=None):
        """Run the minima hopping algorithm.

        Parameters
        ----------
        totalsteps : int, optional
            Stop once the step counter reaches this value.
        maxtemp : float, optional
            Stop once the MD temperature (K) reaches this value.

        If neither is given (or both are falsy), runs indefinitely (or
        until stopped by batching software).
        """
        self._startup()
        while True:
            if (totalsteps and self._counter >= totalsteps):
                self._log(
                    'msg', 'Run terminated. Step #%i reached of '
                    '%i allowed. Increase totalsteps if resuming.' %
                    (self._counter, totalsteps))
                return
            if (maxtemp and self._temperature >= maxtemp):
                self._log(
                    'msg', 'Run terminated. Temperature is %.2f K;'
                    ' max temperature allowed %.2f K.' %
                    (self._temperature, maxtemp))
                return

            self._previous_optimum = self._atoms.copy()
            self._previous_energy = self._atoms.get_potential_energy()
            self._molecular_dynamics()
            self._optimize()
            self._counter += 1
            self._check_results()

    def _startup(self):
        """Start a fresh run or resume a previous one.

        Rank 0 classifies the run and broadcasts the status:

        * 0 - no minima file: fresh run;
        * 1 - minima file exists but no log file: fresh run that shares
          an existing minima file;
        * 2 - both exist: resume from the log in the working directory.

        Fresh runs reset the counter, create the log, relax the starting
        structure and check it against the known minima.
        """
        status = np.array(-1.)
        exists = self._read_minima()
        if world.rank == 0:
            if not exists:
                # Fresh run with new minima file.
                status = np.array(0.)
            elif not os.path.exists(self._logfile):
                # Fresh run with existing or shared minima file.
                status = np.array(1.)
            else:
                # Must be resuming from within a working directory.
                status = np.array(2.)
        world.barrier()
        world.broadcast(status, 0)

        if status == 2.:
            self._resume()
        else:
            self._counter = 0
            self._log('init')
            self._log('msg', 'Performing initial optimization.')
            if status == 1.:
                self._log(
                    'msg', 'Using existing minima file with %i prior '
                    'minima: %s' % (len(self._minima), self._minima_traj))
            self._optimize()
            self._check_results()
            self._counter += 1

    def _resume(self):
        """Attempt to resume a run, based on information in the log file.

        The last ``par:`` line restores temperature and Ediff; the last
        ``Optimization: qnNNNNN`` and ``Molecular dynamics: mdNNNNN``
        messages tell which stage was interrupted. A run will almost always
        be interrupted in the middle of either a qn or md run or when
        exceeding totalsteps, so only those cases have been tested.
        """
        with paropen(self._logfile, 'r') as f:
            lines = f.read().splitlines()
        self._log('msg', 'Attempting to resume stopped run.')
        self._log(
            'msg', 'Using existing minima file with %i prior '
            'minima: %s' % (len(self._minima), self._minima_traj))
        mdcount, qncount = 0, 0
        for line in lines:
            if line.startswith('par:') and ('Ediff' not in line):
                self._temperature = float(line.split()[1])
                self._Ediff = float(line.split()[2])
            elif line.startswith('msg: Optimization:'):
                qncount = int(line[19:].split('qn')[1])
            elif line.startswith('msg: Molecular dynamics:'):
                mdcount = int(line[25:].split('md')[1])
        self._counter = max((mdcount, qncount))
        if qncount == mdcount:
            # Either stopped during local optimization or terminated due to
            # max steps.
            self._log('msg', 'Attempting to resume at qn%05i' % qncount)
            if qncount > 0:
                atoms = io.read('qn%05i.traj' % (qncount - 1), index=-1)
                self._previous_optimum = atoms.copy()
                self._previous_energy = atoms.get_potential_energy()
            if os.path.getsize('qn%05i.traj' % qncount) > 0:
                atoms = io.read('qn%05i.traj' % qncount, index=-1)
            else:
                atoms = io.read('md%05i.traj' % qncount, index=-3)
            self._atoms = atoms.copy()
            fmax = np.sqrt((atoms.get_forces()**2).sum(axis=1).max())
            if fmax < self._fmax:
                # Stopped after a qn finished.
                self._log(
                    'msg', 'qn%05i fmax already less than fmax=%.3f' %
                    (qncount, self._fmax))
                self._counter += 1
                return
            self._optimize()
            self._counter += 1
            if qncount > 0:
                self._check_results()
            else:
                self._record_minimum()
                self._log('msg', 'Found a new minimum.')
                self._log('msg', 'Accepted new minimum.')
                self._log('par')
        elif qncount < mdcount:
            # Probably stopped during molecular dynamics.
            self._log('msg', 'Attempting to resume at md%05i.' % mdcount)
            atoms = io.read('qn%05i.traj' % qncount, index=-1)
            self._previous_optimum = atoms.copy()
            self._previous_energy = atoms.get_potential_energy()
            self._molecular_dynamics(resume=mdcount)
            self._optimize()
            self._counter += 1
            self._check_results()

    def _check_results(self):
        """Accept or reject the current structure and adapt parameters.

        With no stored minima the structure is always accepted. Otherwise
        it is accepted only if its energy is below
        ``previous_energy + Ediff`` (or no previous energy is known) and it
        does not look like any stored minimum. Acceptance scales Ediff by
        ``alpha1``, resets the temperature to ``T0`` and records the
        minimum; rejection restores the previous optimum (when there is
        one), scales Ediff by ``alpha2`` and the temperature by ``beta1``.
        """
        self._read_minima()
        if len(self._minima) == 0:
            # No prior minima found.
            self._log('msg', 'Found a new minimum.')
            self._log('msg', 'Accepted new minimum.')
            self._record_minimum()
            self._log('par')
            return

        if (self._previous_energy is None
                or (self._atoms.get_potential_energy() <
                    self._previous_energy + self._Ediff)):
            unique = self._is_unique()
            # Drop fingerprints cached on the atoms so they are not reused
            # for the next hop. They may be absent (e.g. if the comparator
            # never fingerprinted this structure), so don't assume they
            # exist.
            self._atoms.info.pop('fingerprints', None)
            if unique:
                self._log('msg', 'Accepted new minimum.')
                self._Ediff *= self._alpha1
                self._temperature = self._T0
                self._previous_optimum = self._atoms.copy()
                self._log('par')
                self._record_minimum()
            else:
                self._log(
                    'msg',
                    'Rejected minimum because a similar fingerprint was found.'
                )
                self._reject_minimum()
        else:
            self._log(
                'msg', 'Rejected new minimum due to energy. '
                'Restoring last minimum.')
            self._reject_minimum()

    def _reject_minimum(self):
        """Restore the last accepted minimum and loosen search parameters."""
        if self._previous_optimum is None:
            # Possible on the very first check when a shared minima file
            # already contains this structure.
            self._log('msg', 'No previous minimum to restore; continuing '
                      'from the current structure.')
        else:
            self._atoms.positions = self._previous_optimum.positions
            self._atoms.cell = self._previous_optimum.cell
        self._Ediff *= self._alpha2
        self._temperature *= self._beta1
        self._log('par')

    def _is_unique(self):
        """Return True if the current atoms look like none of the minima."""
        return not any(self._comp.looks_like(self._atoms, minimum)
                       for minimum in self._minima)

    def _log(self, cat='msg', message=None):
        """Record one line in the log file, or create it for ``'init'``.

        Parameters
        ----------
        cat : {'init', 'msg', 'par', 'ene'}
            ``'init'`` creates the log with its two header lines;
            ``'msg'`` writes ``message``; ``'par'`` writes temperature,
            Ediff and mdmin; ``'ene'`` writes the current energy and, when
            a previous energy is known, the previous energy and the
            difference.
        message : str, optional
            Text for ``'msg'`` lines.

        Raises
        ------
        RuntimeError
            For ``'init'`` on rank 0 when the log file already exists.
        ValueError
            For an unknown ``cat``.
        """
        if cat == 'init':
            if world.rank == 0:
                if os.path.exists(self._logfile):
                    raise RuntimeError('File exists: %s' % self._logfile)
            with paropen(self._logfile, 'w') as f:
                f.write('par: %12s %12s %12s\n' %
                        ('T (K)', 'Ediff (eV)', 'mdmin'))
                f.write('ene: %12s %12s %12s\n' %
                        ('E_current', 'E_previous', 'Difference'))
            return
        if cat == 'msg':
            line = 'msg: %s' % message
        elif cat == 'par':
            line = ('par: %12.4f %12.4f %12i' %
                    (self._temperature, self._Ediff, self._mdmin))
        elif cat == 'ene':
            current = self._atoms.get_potential_energy()
            previous = self._previous_energy
            if previous is not None:
                line = ('ene: %12.5f %12.5f %12.5f' %
                        (current, previous, current - previous))
            else:
                line = ('ene: %12.5f' % current)
        else:
            raise ValueError('Unknown log category: %s' % cat)
        with paropen(self._logfile, 'a') as f:
            f.write(line + '\n')

    def _optimize(self):
        """Relax the current structure in two stages.

        1. Remove constraints, zero the momenta and run FIRE to ``fmax2``.
        2. Restrain the Pb/Br atoms and relax cell and positions with
           ``optimizer`` on an ``ExpCellFilter`` to ``fmax``, writing
           ``qn%05i.traj``/``qn%05i.log``.

        Afterwards the constraints are removed and positions and cell are
        transformed with ``convert_cell_4NPT`` (presumably into the cell
        form ASE's NPT integrator requires — confirm).
        """
        del self._atoms.constraints
        self._atoms.set_momenta(np.zeros(self._atoms.get_momenta().shape))
        geo_opt = FIRE(self._atoms)
        geo_opt.run(fmax=self._fmax2)
        self._constrain()
        ecf = ExpCellFilter(self._atoms)
        opt = self._optimizer(ecf,
                              trajectory='qn%05i.traj' % self._counter,
                              logfile='qn%05i.log' % self._counter)
        self._log('msg', 'Optimization: qn%05i' % self._counter)
        opt.run(fmax=self._fmax)
        self._log('ene')
        del self._atoms.constraints
        tri_mat, coord_transform = convert_cell_4NPT(self._atoms.get_cell())
        self._atoms.set_positions([
            np.matmul(coord_transform, position)
            for position in self._atoms.get_positions()
        ])
        self._atoms.set_cell(tri_mat.transpose())

    def _record_minimum(self):
        """Append the current configuration to the minima file and reload it."""
        # Close the writer before reading the file back.
        with io.Trajectory(self._minima_traj, 'a') as traj:
            traj.write(self._atoms)
        self._read_minima()
        self._log('msg', 'Recorded minima #%i.' % (len(self._minima) - 1))

    def _read_minima(self):
        """Load ``self._minima`` from the minima file.

        Returns
        -------
        bool
            True if the minima file exists (even if empty), else False.
        """
        if not os.path.exists(self._minima_traj):
            self._minima = []
            return False
        if os.path.getsize(self._minima_traj) == 0:
            self._minima = []
        else:
            with io.Trajectory(self._minima_traj, 'r') as traj:
                self._minima = [atoms for atoms in traj]
        return True

    def _constrain(self):
        """Restrain every Pb and Br atom to its current position.

        Each inorganic atom gets a Hookean spring (rt=0.1, k=15.) anchored
        at the position it occupies when this method is called.
        """
        self._inorganic_positions = self._atoms.positions[
            self._inorganic_indices]
        constraints = [
            Hookean(a1=index, a2=position, rt=0.1, k=15.)
            for index, position in zip(self._inorganic_indices,
                                       self._inorganic_positions)
        ]
        self._atoms.set_constraint(constraints)

    def _molecular_dynamics(self, resume=None):
        """Run NPT molecular dynamics until ``mdmin`` minima have passed.

        Frames are appended to ``md%05i.traj`` and logged to ``md%05i.log``
        (numbered by the step counter), with the Pb/Br restraints from
        ``_constrain`` active. Afterwards the positions are reset to the
        frame selected by the last minimum detection (``passedmin[0]`` is
        used as an index into the recorded positions).

        Parameters
        ----------
        resume : int, optional
            Number of an interrupted ``md%05i.traj`` run to continue. Its
            energies and minimum count are replayed and its last momenta
            reused; if that file is empty the run restarts from the last
            frame of the preceding ``qn%05i.traj``.
        """
        self._log('msg', 'Molecular dynamics: md%05i' % self._counter)
        mincount = 0
        energies, oldpositions = [], []
        thermalized = False
        if resume:
            self._log('msg', 'Resuming MD from md%05i.traj' % resume)
            if os.path.getsize('md%05i.traj' % resume) == 0:
                self._log(
                    'msg', 'md%05i.traj is empty. Resuming from '
                    'qn%05i.traj.' % (resume, resume - 1))
                atoms = io.read('qn%05i.traj' % (resume - 1), index=-1)
            else:
                # Close the reader before the same file is reopened for
                # appending below (resume equals the step counter when
                # called from _resume).
                with io.Trajectory('md%05i.traj' % resume, 'r') as images:
                    for atoms in images:
                        energies.append(atoms.get_potential_energy())
                        oldpositions.append(atoms.positions.copy())
                        passedmin = self._passedminimum(energies)
                        if passedmin:
                            mincount += 1
                self._atoms.set_momenta(atoms.get_momenta())
                thermalized = True
            self._atoms.positions = atoms.get_positions()
            self._log('msg',
                      'Starting MD with %i existing energies.' % len(energies))
        if not thermalized:
            MaxwellBoltzmannDistribution(self._atoms,
                                         temp=self._temperature * units.kB,
                                         force_temp=True)
        traj = io.Trajectory('md%05i.traj' % self._counter, 'a', self._atoms)
        self._constrain()
        dyn = NPT(self._atoms,
                  timestep=self._timestep * units.fs,
                  temperature=self._temperature * units.kB,
                  externalstress=self._externalstress,
                  ttime=self._ttime * units.fs,
                  pfactor=self._pfactor * units.fs**2)
        log = MDLogger(dyn,
                       self._atoms,
                       'md%05i.log' % self._counter,
                       header=True,
                       stress=False,
                       peratom=False)
        dyn.attach(log, interval=1)
        dyn.attach(traj, interval=1)
        while mincount < self._mdmin:
            dyn.run(1)
            energies.append(self._atoms.get_potential_energy())
            passedmin = self._passedminimum(energies)
            if passedmin:
                mincount += 1
            oldpositions.append(self._atoms.positions.copy())
        traj.close()
        # Reset atoms to minimum point.
        self._atoms.positions = oldpositions[passedmin[0]]
# Minimum interatomic distances for the soft mutation, generated with a
# covalent-radii ratio of 0.8.
blmin_soft = closest_distances_generator(atom_numbers_to_optimize, 0.8)
softmut = SoftMutation(blmin_soft, bounds=[2., 5.], use_tags=True)

# Offspring operators with relative selection weights: the pairing operator
# has weight 5 and each of the five mutations has weight 1.
operators = OperationSelector([5, 1, 1, 1, 1, 1], [pairing, rattlemut,
                                                   strainmut, rotmut, rattlerotmut, softmut])

# Relaxing the initial candidates.
# NOTE(review): relax(a)'s return value is unused, so it is assumed to
# relax `a` in place — confirm.
while da.get_number_of_unrelaxed_candidates() > 0:
    a = da.get_an_unrelaxed_candidate()
    relax(a)
    da.add_relaxed_step(a)

# The structure comparator for the population (periodic Oganov
# fingerprints; energy tolerance dE and cosine-distance threshold
# cos_dist_max decide whether two structures are duplicates).
comp = OFPComparator(n_top=n_top, dE=1.0, cos_dist_max=5e-3, rcut=10.,
                     binwidth=0.05, pbc=[True, True, True], sigma=0.05,
                     nsigma=4, recalculate=False)

# The population
population = Population(data_connection=da,
                        population_size=10,
                        comparator=comp,
                        logfile='log.txt')

# Update the volume scaling used by the strain mutation and the pairing
# operator from the current population.
current_pop = population.get_current_population()
strainmut.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
pairing.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)

# Test a few new candidates (the loop using n_to_test is not part of this
# fragment).
n_to_test = 10
Пример #6
0
    cell = a3.get_cell()
    assert cellbounds.is_within_bounds(cell)
    assert np.all(a3.numbers == a.numbers)
    assert not atoms_too_close(a3, blmin, use_tags=True)

# SoftMutation with a used-modes file: after `no_muts` mutations of the
# same parent, the file should list that many used modes.
modes_file = 'modes.txt'
softmut_with = SoftMutation(blmin,
                            bounds=[2., 5.],
                            use_tags=True,
                            used_modes_file=modes_file)
no_muts = 3
for _ in range(no_muts):
    softmut_with.get_new_individual([a1])
softmut_with.read_used_modes(modes_file)
assert len(list(softmut_with.used_modes.values())[0]) == no_muts
# Remove the modes file read above.
# NOTE(review): skipped if the assertion above fails, which leaves
# modes.txt in the working directory.
os.remove(modes_file)

# Comparator with default thresholds and recalculate=True.
comparator = OFPComparator(recalculate=True)
gold = bulk('Au') * (2, 2, 2)
assert comparator.looks_like(gold, gold)

# This move should not exceed the default threshold
gc = gold.copy()
gc[0].x += .1
assert comparator.looks_like(gold, gc)

# An additional step will exceed the threshold
gc[0].x += .2
assert not comparator.looks_like(gold, gc)
Пример #7
0
                              atoms_too_close)
from ase.ga.offspring_creator import OperationSelector
from ase.ga.ofp_comparator import OFPComparator
from ase.ga.bulk_utilities import CellBounds
from ase.ga.bulk_startgenerator import StartGenerator
from ase.ga.bulk_crossovers import CutAndSplicePairing
from ase.ga.bulk_mutations import *
from ase.ga.standardmutations import RattleMutation
from tango.relax_utils import push_apart, relax_precon, finalize
from tango.calculators import DftbPlusCalculator

# Module-level structure comparator (periodic Oganov fingerprints) used to
# decide whether two candidates are duplicates.
comparator = OFPComparator(recalculate=False, n_top=None, pbc=[True, True, True],
                           rcut=10., binwidth=0.05, sigma=0.1, nsigma=4,
                           dE=1.0, cos_dist_max=5e-2)


def penalize(t, max_volume_per_atom=50., penalty=1e9):
    """Penalize candidates whose cell has blown up ("explosion").

    Reads the raw score of ``t`` and, if its volume per atom is at or above
    ``max_volume_per_atom``, subtracts ``penalty`` before writing the score
    back with ``set_raw_score``.

    Parameters
    ----------
    t : ase.Atoms
        Candidate whose raw score has already been set.
    max_volume_per_atom : float, optional
        Volume-per-atom threshold (cell volume / number of atoms) at or
        above which the candidate is penalized. Defaults to 50.
    penalty : float, optional
        Amount subtracted from the raw score. Defaults to 1e9.
    """
    raw_score = get_raw_score(t)
    if t.get_volume() / len(t) >= max_volume_per_atom:
        raw_score -= penalty
    set_raw_score(t, raw_score)

    def __init__(self, atoms, **kwargs):
        """Initialize with an ASE atoms object and keyword arguments.

        Besides the search state, this records the Pb/Br atoms used for
        Hookean constraints, the six nearest Br atoms of every Pb atom, and
        four Br indices (``_Br_index1`` .. ``_Br_index4``) forming two
        z-stacked Br pairs.

        Raises RuntimeError for a keyword that is not a default setting.
        """
        self._atoms = atoms
        # Reject unknown keywords before applying any setting; every
        # default setting then becomes an attribute named ``_<key>``.
        for key in kwargs:
            if key not in self._default_settings:
                raise RuntimeError('Unknown keyword: %s' % key)
        for k, v in self._default_settings.items():
            setattr(self, '_%s' % k, kwargs.pop(k, v))

        # Detects when an MD run has passed a local minimum.
        self._passedminimum = PassedMinimum()

        # Last accepted structure/energy and the adaptive parameters.
        self._previous_optimum = None
        self._previous_energy = None
        self._temperature = self._T0
        self._Ediff = self._Ediff0

        # Oganov fingerprints for structure comparison.
        self._comp = OFPComparator(dE=10.0,
                                   cos_dist_max=self._minima_threshold,
                                   rcut=20.,
                                   binwidth=0.05,
                                   pbc=[True, True, True],
                                   sigma=0.05,
                                   nsigma=4,
                                   recalculate=False)

        # Inorganic (Pb, Br) indices and positions for Hookean constraints.
        self._Pb_indices = np.where(self._atoms.symbols == 'Pb')[0]
        self._Pb_positions = self._atoms.positions[self._Pb_indices]
        self._Br_indices = np.where(self._atoms.symbols == 'Br')[0]
        self._Br_positions = self._atoms.positions[self._Br_indices]
        self._inorganic_indices = np.concatenate(
            (self._Pb_indices, self._Br_indices))
        self._inorganic_positions = self._atoms.positions[
            self._inorganic_indices]
        # Row i holds the indices of the 6 Br atoms nearest (minimum-image
        # distance) to the i-th Pb atom, nearest first.
        # NOTE(review): fails if there are fewer than 6 Br atoms.
        self._indices_list = np.empty((len(self._Pb_indices), 6), dtype='int')

        for i in range(len(self._Pb_indices)):
            distances = self._atoms.get_distances(self._Pb_indices[i],
                                                  self._Br_indices,
                                                  mic=True)
            self._indices_list[i] = self._Br_indices[np.argsort(distances)[:6]]

        # Pick two pairs of Br atoms stacked along z: an outer pair (from the
        # lowest and highest groups) and an inner pair (the two middle groups).
        # Split Br atoms by their z relative to the mean Br z; atoms exactly
        # at the mean fall in neither half.
        average_Br_z = np.average(self._Br_positions[:, 2])
        top_Br = self._Br_indices[np.where(
            self._Br_positions[:, 2] > average_Br_z)]
        bottom_Br = self._Br_indices[np.where(
            self._Br_positions[:, 2] < average_Br_z)]
        # Br groups are labelled 1,2,3,4 from bottom to top so 1-4 and 2-3
        # should be paired: group1/group2 are the 4 lowest/highest Br of the
        # bottom half, group3/group4 the 4 lowest/highest of the top half.
        # NOTE(review): assumes each half holds at least 8 Br, otherwise the
        # groups overlap — confirm for the target structures.
        Br_group1 = bottom_Br[np.argsort(
            self._atoms.positions[bottom_Br][:, 2])[:4]]
        Br_group2 = bottom_Br[np.argsort(
            self._atoms.positions[bottom_Br][:, 2])[-4:]]
        Br_group3 = top_Br[np.argsort(self._atoms.positions[top_Br][:, 2])[:4]]
        Br_group4 = top_Br[np.argsort(self._atoms.positions[top_Br][:,
                                                                    2])[-4:]]

        # Index 1: the lowest-z Br of group 1. Index 4: the group-4 Br
        # closest to it in the xy plane.
        # NOTE(review): xy distances here are plain Cartesian differences
        # (no minimum-image convention) — confirm this is intended.
        self._Br_index1 = int(Br_group1[0])
        Br_relative1 = self._atoms.positions[
            self._Br_index1] - self._atoms.positions[Br_group4]
        self._Br_index4 = int(Br_group4[np.argmin(
            np.linalg.norm(Br_relative1[:, :2], axis=1))])

        # Index 2: the group-2 Br closest in xy to index 1.
        Br_relative2 = self._atoms.positions[
            self._Br_index1] - self._atoms.positions[Br_group2]
        self._Br_index2 = int(Br_group2[np.argmin(
            np.linalg.norm(Br_relative2[:, :2], axis=1))])

        # Index 3: the group-3 Br closest in xy to index 2.
        Br_relative3 = self._atoms.positions[
            self._Br_index2] - self._atoms.positions[Br_group3]
        self._Br_index3 = int(Br_group3[np.argmin(
            np.linalg.norm(Br_relative3[:, :2], axis=1))])
def test_bulk_operators():
    """Smoke test for the bulk GA start generator, operators and comparator.

    Generates two random candidates from H / H2O / H2 building blocks and
    checks that pairing and every mutation keep the cell within
    ``cellbounds`` and respect ``blmin`` (mutations must also preserve the
    parent's atomic numbers). It then checks that SoftMutation records one
    used mode per mutation in its modes file, and that OFPComparator treats
    a 0.1 A displacement as the same structure but a total 0.3 A
    displacement as different.
    """
    import tempfile

    h2 = Atoms('H2', positions=[[0, 0, 0], [0, 0, 0.75]])
    blocks = [('H', 4), ('H2O', 3), (h2, 2)]  # the building blocks
    volume = 40. * sum(count for _, count in blocks)  # cell volume in angstrom^3
    splits = {(2,): 1, (1,): 1}  # cell splitting scheme

    stoichiometry = []
    for block, count in blocks:
        # A block is either a chemical formula or a ready-made Atoms object.
        template = Atoms(block) if isinstance(block, str) else block
        stoichiometry += list(template.numbers) * count

    atom_numbers = list(set(stoichiometry))
    blmin = closest_distances_generator(atom_numbers=atom_numbers,
                                        ratio_of_covalent_radii=1.3)

    cellbounds = CellBounds(bounds={'phi': [30, 150], 'chi': [30, 150],
                                    'psi': [30, 150], 'a': [3, 50],
                                    'b': [3, 50], 'c': [3, 50]})

    sg = StartGenerator(blocks, blmin, volume, cellbounds=cellbounds,
                        splits=splits)

    # Generate 2 candidates
    a1 = sg.get_new_candidate()
    a1.info['confid'] = 1
    a2 = sg.get_new_candidate()
    a2.info['confid'] = 2

    # Define and test genetic operators
    pairing = CutAndSplicePairing(blmin, p1=1., p2=0., minfrac=0.15,
                                  cellbounds=cellbounds, use_tags=True)

    a3, desc = pairing.get_new_individual([a1, a2])
    cell = a3.get_cell()
    assert cellbounds.is_within_bounds(cell)
    assert not atoms_too_close(a3, blmin, use_tags=True)

    n_top = len(a1)
    strainmut = StrainMutation(blmin, stddev=0.7, cellbounds=cellbounds,
                               use_tags=True)
    softmut = SoftMutation(blmin, bounds=[2., 5.], used_modes_file=None,
                           use_tags=True)
    rotmut = RotationalMutation(blmin, fraction=0.3, min_angle=0.5 * np.pi)
    rattlemut = RattleMutation(blmin, n_top, rattle_prop=0.3, rattle_strength=0.5,
                               use_tags=True, test_dist_to_slab=False)
    rattlerotmut = RattleRotationalMutation(rattlemut, rotmut)
    permut = PermutationMutation(n_top, probability=0.33, test_dist_to_slab=False,
                                 use_tags=True, blmin=blmin)
    combmut = CombinationMutation(rattlemut, rotmut, verbose=True)
    mutations = [strainmut, softmut, rotmut,
                 rattlemut, rattlerotmut, permut, combmut]

    for i, mut in enumerate(mutations):
        # Alternate the parent between the two candidates.
        a = [a1, a2][i % 2]
        # A mutation may fail to produce an offspring; retry until it does.
        a3 = None
        while a3 is None:
            a3, desc = mut.get_new_individual([a])

        cell = a3.get_cell()
        assert cellbounds.is_within_bounds(cell)
        assert np.all(a3.numbers == a.numbers)
        assert not atoms_too_close(a3, blmin, use_tags=True)

    # Keep the modes file in a private directory so the test neither picks
    # up nor leaves behind a modes.txt in the working directory, even when
    # an assertion fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        modes_file = os.path.join(tmpdir, 'modes.txt')
        softmut_with = SoftMutation(blmin, bounds=[2., 5.], use_tags=True,
                                    used_modes_file=modes_file)
        no_muts = 3
        for _ in range(no_muts):
            softmut_with.get_new_individual([a1])
        softmut_with.read_used_modes(modes_file)
        assert len(list(softmut_with.used_modes.values())[0]) == no_muts

    comparator = OFPComparator(recalculate=True)
    gold = bulk('Au') * (2, 2, 2)
    assert comparator.looks_like(gold, gold)

    # This move should not exceed the default threshold
    gc = gold.copy()
    gc[0].x += .1
    assert comparator.looks_like(gold, gc)

    # An additional step will exceed the threshold
    gc[0].x += .2
    assert not comparator.looks_like(gold, gc)