コード例 #1
0
 def __init__(self, images, unique_atoms, descriptor, Gs, fprange, label="example", cores=1):
     """Hash ``images`` and compute their fingerprints with ``descriptor``.

     Args:
         images: an image, a list of images, or a path to a file readable
             by ``ase.io.read``. Non-list inputs are wrapped in a list and
             stored on ``self.images``.
         unique_atoms: elements passed to ``make_symmetry_functions`` (and,
             for amptorch descriptors, to ``make_amp_descriptors_simple_nn``);
             stored on ``self.training_unique_atoms``.
         descriptor: descriptor class, instantiated here as
             ``descriptor(Gs=..., cutoff=...)``.
         Gs: mapping with keys ``"G2_etas"``, ``"G2_rs_s"``, ``"G4_etas"``,
             ``"G4_zetas"``, ``"G4_gammas"`` and ``"cutoff"``.
         fprange: stored unchanged on ``self.fprange``.
         label: forwarded to ``make_amp_descriptors_simple_nn``.
         cores: forwarded to ``make_amp_descriptors_simple_nn``.
     """
     self.images = images if isinstance(images, list) else [images]
     self.descriptor = descriptor
     if isinstance(images, str):
         # Always parse file paths. The old test ``extension != (".traj" or ".db")``
         # collapsed to ``extension != ".traj"`` (``or`` returns its first truthy
         # operand), so .traj paths were left as ``[path]``: a list holding a
         # string rather than images.
         self.atom_images = ase.io.read(images, ":")
     else:
         self.atom_images = self.images
     self.fprange = fprange
     self.training_unique_atoms = unique_atoms
     G2_etas = Gs["G2_etas"]
     G2_rs_s = Gs["G2_rs_s"]
     G4_etas = Gs["G4_etas"]
     G4_zetas = Gs["G4_zetas"]
     G4_gammas = Gs["G4_gammas"]
     cutoff = Gs["cutoff"]
     # Recognise descriptors from the amptorch package by their defining module
     # instead of slicing characters 8-16 out of ``str(descriptor)``.
     descriptor_module = getattr(descriptor, "__module__", None) or ""
     if descriptor_module.split(".")[0] == "amptorch":
         self.hashed_images = hash_images(self.atom_images, Gs)
         make_amp_descriptors_simple_nn(
             self.atom_images, Gs, self.training_unique_atoms, cores=cores, label=label
         )
     else:
         # Hash only on this branch; previously amp_hash ran unconditionally and
         # its result was overwritten for amptorch descriptors.
         self.hashed_images = amp_hash(self.atom_images)
     G = make_symmetry_functions(elements=self.training_unique_atoms, type="G2", etas=G2_etas)
     G += make_symmetry_functions(
         elements=self.training_unique_atoms,
         type="G4",
         etas=G4_etas,
         zetas=G4_zetas,
         gammas=G4_gammas,
     )
     # Every generated symmetry function, G4 included, receives the G2 radial
     # shifts; NOTE(review): confirm the descriptor expects "Rs" on G4 entries.
     for g in G:
         g["Rs"] = G2_rs_s
     self.descriptor = self.descriptor(Gs=G, cutoff=cutoff)
     self.descriptor.calculate_fingerprints(
         self.hashed_images, calculate_derivatives=True
     )
     self.unique_atoms = self.unique()
コード例 #2
0
 def __init__(self,
              images,
              descriptor,
              Gs,
              forcetraining,
              label,
              cores,
              delta_data=None,
              store_primes=False,
              specific_atoms=False):
     """Fingerprint ``images`` and preprocess them into training datasets.

     Args:
         images: a list of images or a path to a file readable by
             ``ase.io.read``; stored unchanged on ``self.images``.
         descriptor: descriptor class, kept on ``self.base_descriptor`` and
             instantiated here as ``descriptor(Gs=..., cutoff=...)``.
         Gs: mapping with keys ``"G2_etas"``, ``"G2_rs_s"``, ``"G4_etas"``,
             ``"G4_zetas"``, ``"G4_gammas"`` and ``"cutoff"``.
         forcetraining: passed as ``calculate_derivatives`` when computing
             fingerprints.
         label: forwarded to ``make_amp_descriptors_simple_nn``.
         cores: forwarded to ``make_amp_descriptors_simple_nn``.
         delta_data: optional sequence whose items 0, 1 and 2 are stored as
             ``delta_energies`` (as an array), ``delta_forces`` and
             ``num_atoms`` (as an array); sets ``self.delta`` to True.
         store_primes: when true, ensures a ``stored-primes`` directory exists
             in the working directory (presumably used by ``preprocess_data``).
         specific_atoms: forwarded to ``make_amp_descriptors_simple_nn``.
     """
     self.images = images
     self.base_descriptor = descriptor
     self.descriptor = descriptor
     self.Gs = Gs
     self.atom_images = self.images
     self.forcetraining = forcetraining
     self.store_primes = store_primes
     self.cores = cores
     self.delta = False
     self.specific_atoms = specific_atoms
     if delta_data is not None:
         self.delta_data = delta_data
         self.delta_energies = np.array(delta_data[0])
         self.delta_forces = delta_data[1]
         self.num_atoms = np.array(delta_data[2])
         self.delta = True
     if self.store_primes:
         # exist_ok replaces the racy isdir()-then-mkdir() sequence.
         os.makedirs("stored-primes", exist_ok=True)
     if isinstance(images, str):
         # Always parse file paths. The old test ``extension != (".traj" or ".db")``
         # collapsed to ``extension != ".traj"``, so .traj paths stayed raw
         # strings and reached self.unique() below without being read.
         self.atom_images = ase.io.read(images, ":")
     self.elements = self.unique()
     # TODO: print log - control verbose
     print("Calculating fingerprints...")
     G2_etas = Gs["G2_etas"]
     G2_rs_s = Gs["G2_rs_s"]
     G4_etas = Gs["G4_etas"]
     G4_zetas = Gs["G4_zetas"]
     G4_gammas = Gs["G4_gammas"]
     cutoff = Gs["cutoff"]
     # Create simple_nn fingerprints for descriptors from the amptorch package,
     # recognised by their defining module rather than by slicing str(descriptor).
     descriptor_module = getattr(descriptor, "__module__", None) or ""
     if descriptor_module.split(".")[0] == "amptorch":
         self.hashed_images = hash_images(self.atom_images, Gs=Gs)
         make_amp_descriptors_simple_nn(self.atom_images,
                                        Gs,
                                        self.elements,
                                        cores=cores,
                                        label=label,
                                        specific_atoms=self.specific_atoms)
         self.isamp_hash = False
     else:
         self.hashed_images = amp_hash(self.atom_images)
         self.isamp_hash = True
     G = make_symmetry_functions(elements=self.elements,
                                 type="G2",
                                 etas=G2_etas)
     G += make_symmetry_functions(
         elements=self.elements,
         type="G4",
         etas=G4_etas,
         zetas=G4_zetas,
         gammas=G4_gammas,
     )
     # Every generated symmetry function, G4 included, receives the G2 radial
     # shifts; NOTE(review): confirm the descriptor expects "Rs" on G4 entries.
     for g in G:
         g["Rs"] = G2_rs_s
     self.descriptor = self.descriptor(Gs=G, cutoff=cutoff)
     self.descriptor.calculate_fingerprints(
         self.hashed_images, calculate_derivatives=forcetraining)
     print("Fingerprints Calculated!")
     self.fprange = calculate_fingerprints_range(self.descriptor,
                                                 self.hashed_images)
     # Perform preprocessing into the per-quantity datasets.
     (self.fingerprint_dataset,
      self.energy_dataset,
      self.num_of_atoms,
      self.sparse_fprimes,
      self.forces_dataset,
      self.index_hashes,
      self.scalings,
      self.rearange_forces) = self.preprocess_data()