Example #1
    def _forward(self, nn, mod_species, mod_coordinates):
        _, ANILastPart = self.break_into_two_stages(
            nn, split_at=self.split_at)  # only keep the last stage

        species_coordinates = (mod_species, mod_coordinates)

        coordinate_hash = hash(tuple(mod_coordinates[0].flatten().tolist()))

        if coordinate_hash in self.precalculation:
            species, y = self.precalculation[coordinate_hash]
        else:
            species, y = self.ANIFirstPart.forward(species_coordinates)
            self.precalculation[coordinate_hash] = (species, y)

        if self.training:
            # detach so we don't compute expensive gradients w.r.t. y
            species_y = SpeciesEnergies(species, y.detach())
        else:
            species_y = SpeciesEnergies(species, y)

        return ANILastPart.forward(species_y)
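
The split into the two stages is not shown above; a minimal sketch of what a break_into_two_stages helper could look like for a torch.nn.Sequential-style model (the function name and the split_at argument are taken from the call above, everything else is an assumption):

    import torch

    def break_into_two_stages(model: torch.nn.Sequential, split_at: int):
        # Hypothetical helper: the first stage is evaluated once and cached,
        # the second stage is re-evaluated (and trained) on every call.
        children = list(model.children())
        first_stage = torch.nn.Sequential(*children[:split_at])
        last_stage = torch.nn.Sequential(*children[split_at:])
        return first_stage, last_stage
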
Example #2
    def forward(
        self,
        species_aev: Tuple[Tensor, Tensor],  # type: ignore
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:

        species, aev = species_aev
        assert species.shape == aev.shape[:-1]

        # in our case, species will be the same for all snapshots
        atom_species = species[0]
        assert (atom_species == species).all()

        # NOTE: depending on the element, outputs will have different dimensions...
        # something like output.shape = n_snapshots, n_atoms, n_dims
        # where n_dims is either 160, 128, or 96...

        # Ugly hard-coding approach: allocate max_dim columns and only write
        # into the first 96, 128, or 160 of them, NaN-poisoning the rest
        # TODO: make this less hard-code-y
        n_snapshots, n_atoms = species.shape
        max_dim = 200
        output = torch.full((n_snapshots, n_atoms, max_dim), float("nan"))
        # NOTE: intentional NaN-poisoning here -- not sure if there's a
        #   better way to emulate a jagged array

        # loop through atom nets
        for i, (_, module) in enumerate(self.items()):
            mask = atom_species == i
            # look only at the elements that are present in species
            if mask.any():
                # get output for these atoms given the aev for these atoms
                current_out = module(aev[:, mask, :])
                # current_out has shape [n_snapshots, n_atoms_with_element_i, out_dim]
                out_dim = current_out.shape[-1]
                # jagged array
                output[:, mask, :out_dim] = current_out
                # final dimensions are [n_snapshots, n_atoms, max_dim]

        return SpeciesEnergies(species, output)
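
Since the padding is NaN rather than zero, a downstream consumer can recover each atom's true feature width from the output itself; a tiny sketch (not part of the original code):

    # number of valid (non-NaN) features per atom and snapshot
    valid_width = (~torch.isnan(output)).sum(dim=-1)  # shape [n_snapshots, n_atoms]
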
Example #3
    def forward(
        self,
        species_aev: Tuple[Tensor, Tensor],
        cell: Optional[Tensor] = None,
        pbc: Optional[Tensor] = None,
    ) -> SpeciesEnergies:
        species, aev = species_aev
        species_ = species.flatten()
        aev = aev.flatten(0, 1)

        output = aev.new_zeros(species_.shape)

        for i, (_, m) in enumerate(self.items()):
            mask = species_ == i
            midx = mask.nonzero().flatten()
            if midx.shape[0] > 0:
                input_ = aev.index_select(0, midx)
                # keep only the activations expected by this element's last layer(s)
                n_features = self.last_two_layer_nr_of_feature[
                    self.index_of_last_layer][i]
                input_ = input_[:, :n_features]
                output.masked_scatter_(mask, m(input_).flatten())
        output = output.view_as(species)
        return SpeciesEnergies(species, torch.sum(output, dim=1))
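
masked_scatter_ writes each element's per-atom results back into the flat output vector, in the order in which the masked atoms appear; a toy illustration of the idiom (the values are made up):

    import torch

    species_ = torch.tensor([0, 1, 0, 2])      # flattened species indices
    output = torch.zeros(4)
    mask = species_ == 0                        # atoms of element 0
    per_atom = torch.tensor([1.5, 2.5])         # one value per masked atom
    output.masked_scatter_(mask, per_atom)      # output is now [1.5, 0.0, 2.5, 0.0]
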
Example #4
    def forward(self, species_aev):
        species, _ = species_aev
        # evaluate each model in the ensemble and stack their outputs along a new dim=2
        output = torch.stack(
            [m(species_aev).energies for m in self.ani_models], dim=2)

        return SpeciesEnergies(species, output)
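
Because the individual model outputs are stacked along dim=2 rather than averaged inside forward, the caller can still reduce over the ensemble afterwards; a small sketch, assuming dim=2 indexes the ensemble members as above (the variable names are illustrative):

    species, energies = ensemble(species_aev)   # hypothetical call to the module above
    ensemble_mean = energies.mean(dim=2)        # average over ensemble members
    ensemble_std = energies.std(dim=2)          # spread across ensemble members
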
Example #5
    with profiler.profile(record_shapes=True) as prof_f:
        with profiler.record_function("model_inference"):
            species_y = f.forward(species_coordinates)
    s = prof_f.self_cpu_time_total / 1000000  # self_cpu_time_total is in microseconds
    print(f"time to precompute up until last layer(s) (f): {s:.3f} s")

    with profiler.profile(record_shapes=True) as prof_g:
        with profiler.record_function("model_inference"):
            species_e = g.forward(species_y)
    s = prof_g.self_cpu_time_total / 1000000
    print(f"time to compute last layer(s) (g): {s:.3f} s")

    # finally, compute gradients w.r.t. last layer only
    g.train()

    # note that species_y requires grad
    species, y = species_y
    # print(y.requires_grad) # >True

    # detach so we don't compute expensive gradients w.r.t. y
    species_y_ = SpeciesEnergies(species, y.detach())
    # print(species_y_[1].requires_grad) # >False

    with profiler.profile(record_shapes=True) as prof_backprop:
        with profiler.record_function("model_inference"):
            L = g.forward(species_y_).energies.sum()
            L.backward()

    s = prof_backprop.self_cpu_time_total / 1000000
    print(f"time to compute derivatives of E w.r.t. last layer: {s:.3f} s")