def __init__(self, images, unique_atoms, descriptor, Gs, fprange, label="example", cores=1):
    """Load the images, hash them and compute fingerprints with derivatives.

    Parameters
    ----------
    images : list, str or single image
        Images to process. Anything that is not a ``list`` is wrapped in a
        one-element list for ``self.images``. A ``str`` is treated as a
        file path (see the path-handling note in the body).
    unique_atoms
        Elements passed to ``make_symmetry_functions`` and to
        ``make_amp_descriptors_simple_nn``; stored as
        ``self.training_unique_atoms``.
    descriptor
        Callable invoked as ``descriptor(Gs=..., cutoff=...)``; the result
        replaces ``self.descriptor`` and must provide
        ``calculate_fingerprints``.
    Gs : mapping
        Must contain the keys ``"G2_etas"``, ``"G2_rs_s"``, ``"G4_etas"``,
        ``"G4_zetas"``, ``"G4_gammas"`` and ``"cutoff"``.
    fprange
        Stored unchanged as ``self.fprange``.
    label : str
        Forwarded to ``make_amp_descriptors_simple_nn``.
    cores : int
        Forwarded to ``make_amp_descriptors_simple_nn``.

    Raises
    ------
    KeyError
        If ``Gs`` lacks one of the required keys.
    """
    self.images = images if isinstance(images, list) else [images]
    self.descriptor = descriptor
    self.atom_images = self.images
    if isinstance(images, str):
        extension = os.path.splitext(images)[1]
        # NOTE(review): this used to read `extension != (".traj" or ".db")`.
        # That expression always evaluates to `extension != ".traj"`, so the
        # ".db" operand never had any effect and ".db" paths have always
        # been loaded here. The effective condition is spelled out to keep
        # behaviour unchanged; confirm whether ".db" paths were meant to be
        # passed through untouched as well.
        if extension != ".traj":
            self.atom_images = ase.io.read(images, ":")
    self.fprange = fprange
    self.training_unique_atoms = unique_atoms

    # Read every required key up front so a malformed Gs fails before any
    # hashing or descriptor generation starts.
    G2_etas = Gs["G2_etas"]
    G2_rs_s = Gs["G2_rs_s"]
    G4_etas = Gs["G4_etas"]
    G4_zetas = Gs["G4_zetas"]
    G4_gammas = Gs["G4_gammas"]
    cutoff = Gs["cutoff"]

    # str(descriptor)[8:16] is the text right after "<class '" when
    # descriptor is a class -- presumably a check that it lives in the
    # amptorch package; verify against callers.
    if str(descriptor)[8:16] == "amptorch":
        self.hashed_images = hash_images(self.atom_images, Gs)
        make_amp_descriptors_simple_nn(
            self.atom_images,
            Gs,
            self.training_unique_atoms,
            cores=cores,
            label=label,
        )
    else:
        # Only hash with amp_hash when its result is actually kept; it used
        # to run unconditionally and was then overwritten on the amptorch
        # path.
        self.hashed_images = amp_hash(self.atom_images)

    G = make_symmetry_functions(
        elements=self.training_unique_atoms, type="G2", etas=G2_etas
    )
    G += make_symmetry_functions(
        elements=self.training_unique_atoms,
        type="G4",
        etas=G4_etas,
        zetas=G4_zetas,
        gammas=G4_gammas,
    )
    # Every symmetry function (G2 and G4 alike) receives the same "Rs" entry.
    for g in G:
        g["Rs"] = G2_rs_s

    self.descriptor = self.descriptor(Gs=G, cutoff=cutoff)
    self.descriptor.calculate_fingerprints(
        self.hashed_images, calculate_derivatives=True
    )
    self.unique_atoms = self.unique()
def __init__(self, images, descriptor, Gs, forcetraining, label, cores, delta_data=None, store_primes=False, specific_atoms=False):
    """Load the images, compute their fingerprints and preprocess the data.

    Parameters
    ----------
    images : sequence of images or str
        Images to process. A ``str`` is treated as a file path (see the
        path-handling note in the body).
    descriptor
        Callable invoked as ``descriptor(Gs=..., cutoff=...)``. The
        original object is kept as ``self.base_descriptor``; the call
        result replaces ``self.descriptor`` and must provide
        ``calculate_fingerprints``.
    Gs : mapping
        Must contain the keys ``"G2_etas"``, ``"G2_rs_s"``, ``"G4_etas"``,
        ``"G4_zetas"``, ``"G4_gammas"`` and ``"cutoff"``.
    forcetraining : bool
        Passed as ``calculate_derivatives`` to ``calculate_fingerprints``.
    label : str
        Forwarded to ``make_amp_descriptors_simple_nn``.
    cores : int
        Stored as ``self.cores`` and forwarded to
        ``make_amp_descriptors_simple_nn``.
    delta_data : sequence, optional
        When given, ``delta_data[0]`` becomes ``self.delta_energies`` (as a
        NumPy array), ``delta_data[1]`` becomes ``self.delta_forces``,
        ``delta_data[2]`` becomes ``self.num_atoms`` (as a NumPy array) and
        ``self.delta`` is set to ``True``.
    store_primes : bool, optional
        When true, a ``stored-primes`` directory is created in the current
        working directory if it does not exist yet.
    specific_atoms : optional
        Stored as ``self.specific_atoms`` and forwarded to
        ``make_amp_descriptors_simple_nn``.

    Raises
    ------
    KeyError
        If ``Gs`` lacks one of the required keys.
    """
    self.images = images
    self.base_descriptor = descriptor
    self.descriptor = descriptor
    self.Gs = Gs
    self.atom_images = self.images
    self.forcetraining = forcetraining
    self.store_primes = store_primes
    self.cores = cores
    self.delta = False
    self.specific_atoms = specific_atoms

    if delta_data is not None:
        self.delta_data = delta_data
        self.delta_energies = np.array(delta_data[0])
        self.delta_forces = delta_data[1]
        self.num_atoms = np.array(delta_data[2])
        self.delta = True

    if self.store_primes:
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs("stored-primes", exist_ok=True)

    if isinstance(images, str):
        extension = os.path.splitext(images)[1]
        # NOTE(review): this used to read `extension != (".traj" or ".db")`.
        # That expression always evaluates to `extension != ".traj"`, so the
        # ".db" operand never had any effect and ".db" paths have always
        # been loaded here. The effective condition is spelled out to keep
        # behaviour unchanged; confirm whether ".db" paths were meant to be
        # passed through untouched as well.
        if extension != ".traj":
            self.atom_images = ase.io.read(images, ":")
    self.elements = self.unique()

    # TODO: route these messages through a logger with verbosity control.
    print("Calculating fingerprints...")
    G2_etas = Gs["G2_etas"]
    G2_rs_s = Gs["G2_rs_s"]
    G4_etas = Gs["G4_etas"]
    G4_zetas = Gs["G4_zetas"]
    G4_gammas = Gs["G4_gammas"]
    cutoff = Gs["cutoff"]

    # create simple_nn fingerprints
    # str(descriptor)[8:16] is the text right after "<class '" when
    # descriptor is a class -- presumably a check that it lives in the
    # amptorch package; verify against callers.
    if str(descriptor)[8:16] == "amptorch":
        self.hashed_images = hash_images(self.atom_images, Gs=Gs)
        make_amp_descriptors_simple_nn(
            self.atom_images,
            Gs,
            self.elements,
            cores=cores,
            label=label,
            specific_atoms=self.specific_atoms,
        )
        self.isamp_hash = False
    else:
        self.hashed_images = amp_hash(self.atom_images)
        self.isamp_hash = True

    G = make_symmetry_functions(elements=self.elements, type="G2", etas=G2_etas)
    G += make_symmetry_functions(
        elements=self.elements,
        type="G4",
        etas=G4_etas,
        zetas=G4_zetas,
        gammas=G4_gammas,
    )
    # Every symmetry function (G2 and G4 alike) receives the same "Rs" entry.
    # The loop mutates the dicts, not G itself, so no defensive copy is needed.
    for g in G:
        g["Rs"] = G2_rs_s

    self.descriptor = self.descriptor(Gs=G, cutoff=cutoff)
    self.descriptor.calculate_fingerprints(
        self.hashed_images, calculate_derivatives=forcetraining
    )
    print("Fingerprints Calculated!")
    self.fprange = calculate_fingerprints_range(self.descriptor, self.hashed_images)

    # perform preprocessing
    (
        self.fingerprint_dataset,
        self.energy_dataset,
        self.num_of_atoms,
        self.sparse_fprimes,
        self.forces_dataset,
        self.index_hashes,
        self.scalings,
        self.rearange_forces,
    ) = self.preprocess_data()