import torch
from typing import Optional, Tuple
from torch import Tensor
from torchani.nn import SpeciesEnergies


def _forward(self, nn, mod_species, mod_coordinates):
    # only the last stage is needed here; the first stage is cached
    # as self.ANIFirstPart
    _, ANILastPart = self.break_into_two_stages(nn, split_at=self.split_at)

    species_coordinates = (mod_species, mod_coordinates)

    # memoize the expensive first stage, keyed on the coordinates
    coordinate_hash = hash(tuple(mod_coordinates[0].flatten().tolist()))
    if coordinate_hash in self.precalculation:
        species, y = self.precalculation[coordinate_hash]
    else:
        species, y = self.ANIFirstPart.forward(species_coordinates)
        self.precalculation[coordinate_hash] = (species, y)

    if self.training:
        # detach so we don't compute expensive gradients w.r.t. y
        species_y = SpeciesEnergies(species, y.detach())
    else:
        species_y = SpeciesEnergies(species, y)
    return ANILastPart.forward(species_y)
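# A minimal, self-contained sketch of the coordinate-hash memoization used
# above (the names `cached_first_stage`, `first_stage`, and `cache` are
# hypothetical, not part of the original class). The key point: hashing the
# flattened coordinate list makes the cache exact-match only, so any change
# in coordinates triggers a fresh first-stage evaluation.
import torch

cache = {}

def cached_first_stage(first_stage, species, coordinates):
    key = hash(tuple(coordinates[0].flatten().tolist()))
    if key not in cache:
        cache[key] = first_stage((species, coordinates))
    return cache[key]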
def forward(
    self,
    species_aev: Tuple[Tensor, Tensor],  # type: ignore
    cell: Optional[Tensor] = None,
    pbc: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    species, aev = species_aev
    assert species.shape == aev.shape[:-1]

    # in our case, species will be the same for all snapshots
    atom_species = species[0]
    assert (atom_species == species).all()

    # NOTE: depending on the element, outputs will have different dimensions:
    # output.shape = (n_snapshots, n_atoms, n_dims), where n_dims is
    # either 160, 128, or 96.
    # Ugly hard-coding approach: allocate max_dim columns (here 200,
    # comfortably larger than the widest per-element output) and only write
    # into the first 96, 128, or 160 of them, NaN-poisoning the rest.
    # TODO: make this less hard-code-y
    n_snapshots, n_atoms = species.shape
    max_dim = 200
    # intentional NaN-poisoning here -- not sure if there's a better way
    # to emulate a jagged array
    output = torch.full((n_snapshots, n_atoms, max_dim), float("nan"))

    # loop through the per-element atom nets
    for i, (_, module) in enumerate(self.items()):
        # look only at the atoms of element i that are present in species
        mask = atom_species == i
        if mask.any():
            # get the output for these atoms given their AEVs; current_out
            # has shape (n_snapshots, n_atoms_with_element_i, out_dim)
            current_out = module(aev[:, mask, :])
            out_dim = current_out.shape[-1]
            # write into the first out_dim columns of the jagged array
            output[:, mask, :out_dim] = current_out

    # final dimensions are (n_snapshots, n_atoms, max_dim)
    return SpeciesEnergies(species, output)
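# A quick demo of why NaN-poisoning works as a jagged-array stand-in
# (shapes here are made up for illustration): the padding columns stay NaN,
# so an accidental read of an unwritten slot propagates NaN loudly instead
# of silently contributing a zero.
import torch

padded = torch.full((2, 3, 5), float("nan"))  # (snapshots, atoms, max_dim)
padded[:, :, :3] = torch.randn(2, 3, 3)       # element with out_dim == 3

valid = ~torch.isnan(padded)                  # mask of written entries
assert valid[:, :, :3].all() and not valid[:, :, 3:].any()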
def forward(
    self,
    species_aev: Tuple[Tensor, Tensor],
    cell: Optional[Tensor] = None,
    pbc: Optional[Tensor] = None,
) -> SpeciesEnergies:
    species, aev = species_aev
    species_ = species.flatten()
    aev = aev.flatten(0, 1)

    output = aev.new_zeros(species_.shape)
    for i, (_, m) in enumerate(self.items()):
        mask = species_ == i
        midx = mask.nonzero().flatten()
        if midx.shape[0] > 0:
            input_ = aev.index_select(0, midx)
            # keep only the activations expected by the retained last layer(s)
            input_ = input_[:, :self.last_two_layer_nr_of_feature[
                self.index_of_last_layer][i]]
            # scatter the per-atom outputs back into the flat buffer
            output.masked_scatter_(mask, m(input_).flatten())
    output = output.view_as(species)
    return SpeciesEnergies(species, torch.sum(output, dim=1))
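# masked_scatter_ in isolation (toy tensors): it writes the source values
# into the destination at the positions where the mask is True, in order,
# which is exactly how the per-element outputs land back in the flat buffer.
import torch

out = torch.zeros(5)
mask = torch.tensor([True, False, True, False, True])
out.masked_scatter_(mask, torch.tensor([1.0, 2.0, 3.0]))
# out is now [1., 0., 2., 0., 3.]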
def forward(self, species_aev):
    species, _ = species_aev
    output = torch.stack(
        [m(species_aev).energies for m in self.ani_models], dim=2)
    return SpeciesEnergies(species, output)
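# Shape check for the stacking above, assuming each ensemble member returns
# per-atom activations of shape (n_snapshots, n_atoms, max_dim) as in the
# forward before it (dummy tensors stand in for the model outputs):
# stacking along dim=2 inserts a new model dimension.
import torch

n_snapshots, n_atoms, max_dim, n_models = 4, 10, 200, 8
per_model = [torch.randn(n_snapshots, n_atoms, max_dim) for _ in range(n_models)]
stacked = torch.stack(per_model, dim=2)
assert stacked.shape == (n_snapshots, n_atoms, n_models, max_dim)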
import torch.autograd.profiler as profiler

with profiler.profile(record_shapes=True) as prof_f:
    with profiler.record_function("model_inference"):
        species_y = f.forward(species_coordinates)
s = prof_f.self_cpu_time_total / 1_000_000  # microseconds -> seconds
print(f"time to precompute up until last layer(s) (f): {s:.3f} s")

with profiler.profile(record_shapes=True) as prof_g:
    with profiler.record_function("model_inference"):
        species_e = g.forward(species_y)
s = prof_g.self_cpu_time_total / 1_000_000
print(f"time to compute last layer(s) (g): {s:.3f} s")

# finally, compute gradients w.r.t. the last layer(s) only
g.train()

# note that species_y requires grad
species, y = species_y
# print(y.requires_grad)  # > True

# detach so we don't compute expensive gradients w.r.t. y
species_y_ = SpeciesEnergies(species, y.detach())
# print(species_y_[1].requires_grad)  # > False

with profiler.profile(record_shapes=True) as prof_backprop:
    with profiler.record_function("model_inference"):
        L = g.forward(species_y_).energies.sum()
        L.backward()
s = prof_backprop.self_cpu_time_total / 1_000_000
print(f"time to compute derivatives of E w.r.t. last layer: {s:.3f} s")
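# A tiny illustration of why the detach() above keeps the backward pass
# cheap (a toy linear layer stands in for g): gradients flow into the
# layer's parameters, but not back through the detached input.
import torch

y = torch.randn(4, 8, requires_grad=True)
g_last = torch.nn.Linear(8, 1)

loss = g_last(y.detach()).sum()
loss.backward()

assert y.grad is None                    # no gradient flowed into y
assert g_last.weight.grad is not None    # g's parameters got gradients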