Пример #1
0
def test_example_merge():
    """Check Example.merge_coordinates with unique types, shared types and a slice.

    Two coordinate sets are built from the same benzyl alcohol molecule
    (hydrogens added, 3D embedded): one typed with ElementIndexTyper and one
    using CoordinateSet's default typing.
    """
    mol = pybel.readstring('smi', 'c1ccccc1CO')
    mol.addh()
    mol.make3D()

    elem_cs = molgrid.CoordinateSet(mol, molgrid.ElementIndexTyper())
    dflt_cs = molgrid.CoordinateSet(mol)

    # converting to vector types should not screw up the index types
    dflt_cs.make_vector_types()

    example = molgrid.Example()
    example.coord_sets.append(elem_cs)
    example.coord_sets.append(dflt_cs)

    n_types = elem_cs.max_type + dflt_cs.max_type
    n_coords = elem_cs.coord.dimension(0) + dflt_cs.type_index.size()
    assert example.type_size() == n_types
    assert example.coordinate_size() == n_coords

    # default merge: the second set's type indices are expected to be
    # offset by the first set's max_type
    merged = example.merge_coordinates()
    assert merged.coord.tonumpy().shape == (24, 3)
    offset_types = np.concatenate((
        elem_cs.type_index.tonumpy(),
        dflt_cs.type_index.tonumpy() + elem_cs.max_type,
    ))
    assert np.array_equal(offset_types, merged.type_index.tonumpy())

    # merging without unique types (which makes no sense) keeps raw indices
    shared = example.merge_coordinates(0, False)
    assert shared.coord.tonumpy().shape == (24, 3)
    raw_types = np.concatenate((
        elem_cs.type_index.tonumpy(),
        dflt_cs.type_index.tonumpy(),
    ))
    assert np.array_equal(raw_types, shared.type_index.tonumpy())

    # sliced merge (first argument 1 presumably skips the first coordinate
    # set); no hydrogens in this slice
    sliced = example.merge_coordinates(1, False)
    assert sliced.coord.tonumpy().shape == (8, 3)
Пример #2
0
def test_examplevec():
    """Build an ExampleVec from two labeled Examples and check what it holds.

    The first Example uses an ElementIndexTyper coordinate set with label 0;
    the second uses a default-typed set (converted to vector types) with
    label 1.
    """
    m = pybel.readstring('smi', 'c1ccccc1CO')
    m.addh()
    m.make3D()

    c = molgrid.CoordinateSet(m, molgrid.ElementIndexTyper())
    c2 = molgrid.CoordinateSet(m)

    c2.make_vector_types()  # this should not screw up index types

    ex = molgrid.Example()
    ex.coord_sets.append(c)
    ex.labels.append(0)

    ex2 = molgrid.Example()
    ex2.coord_sets.append(c2)
    ex2.labels.append(1)

    evec = molgrid.ExampleVec([ex, ex2])

    # Previously evec was built and never inspected, so the test only checked
    # that construction did not raise. Verify that both examples were stored
    # in order with their labels and coordinate sets intact.
    assert [e.labels[0] for e in evec] == [0, 1]
    assert [len(e.coord_sets) for e in evec] == [1, 1]
Пример #3
0
def load_examples(T):
    """Wrap each ``(coord, types, energy, diff)`` record of *T* in a molgrid.Example.

    Per-atom radii are looked up in the module-level ``typeradii`` table by
    integer type index. Each Example holds a single CoordinateSet (max type
    4) and one label, ``diff``. The ``energy`` field is ignored.

    Returns:
        A list of Examples, in the same order as *T*.
    """
    def to_example(coord, types, diff):
        radii = np.array([typeradii[int(t)] for t in types], dtype=np.float32)
        coord_set = molgrid.CoordinateSet(coord, types, radii, 4)
        example = molgrid.Example()
        example.coord_sets.append(coord_set)
        example.labels.append(diff)
        return example

    return [to_example(coord, types, diff)
            for coord, types, _energy, diff in T]
Пример #4
0
# Grid dimensions (including types)
gdims = gm.grid_dimensions(t.num_types())

# Pre-allocate grid
# Only one example (batch size is 1)
grid = torch.zeros(1, *gdims, dtype=torch.float32, device="cuda:0")

# Only the first molecule of the SDF file is gridded. Exit with a clear
# message instead of a bare StopIteration traceback when the file is empty.
obmol = next(pybel.readfile("sdf", args.sdf), None)
if obmol is None:
    raise SystemExit(f"no molecules found in {args.sdf}")
obmol.addh()
print(obmol, end="")

# Use OpenBabel molecule object (obmol.OBMol) instead of PyBel molecule (obmol)
cs = molgrid.CoordinateSet(obmol.OBMol, t)

ex = molgrid.Example()
ex.coord_sets.append(cs)

c = ex.coord_sets[0].center()  # Only one coordinate set
print("center:", tuple(c))

# https://gnina.github.io/libmolgrid/python/index.html#the-transform-class
transform = molgrid.Transform(
    c,
    random_translate=0.0,  # float
    random_rotation=False,  # bool
)
# Transform ex and write the result back into ex (same input and output)
transform.forward(ex, ex)

# Compute grid
gm.forward(ex, grid[0])
Пример #5
0
    def forward(self, interpolate=False, spherical=False):
        """Draw the next batch and build grids, atom structs and transforms.

        Each example is split into an input part and a conditional part. When
        ``self.diff_cond_structs`` is set, an example's ``coord_sets`` must
        unpack into four sets (input rec, input lig, cond rec, cond lig).
        Otherwise they unpack into two, and the conditional part reuses the
        input.

        Args:
            interpolate: when true, the conditional transforms are replaced by
                the output of ``self.cond_interp``, which is initialized from
                the first conditional example if needed. The conditional grids
                are then always computed separately.
            spherical: passed through to ``self.cond_interp``.

        Returns:
            ``(input_grids, cond_grids, input_structs, cond_structs,
            transforms, labels)``. The ``*_structs`` entries are
            ``(rec_structs, lig_structs)`` pairs of per-example lists,
            ``transforms`` is ``(input_transforms, cond_transforms)``, and
            ``labels`` is filled via ``extract_label(0, ...)``.
        """
        assert len(self) > 0, 'data is empty'
        size = self.batch_size

        # get next batch of structures and their labels
        examples = self.ex_provider.next_batch(size)
        labels = torch.zeros(size, device=self.device)
        examples.extract_label(0, labels)

        def to_struct(coord_set, typer):
            # convert one coordinate set into an AtomStruct
            return atom_structs.AtomStruct.from_coord_set(
                coord_set,
                typer=typer,
                data_root=self.root_dir,
                device=self.device,
            )

        def transform_about(coord_set):
            # transform centered on coord_set using the configured randomness
            return molgrid.Transform(
                center=coord_set.center(),
                random_translate=self.random_translation,
                random_rotation=self.random_rotation,
            )

        def empty_grids():
            return torch.zeros(
                size,
                self.n_channels,
                *self.grid_maker.spatial_grid_dimensions(),
                dtype=torch.float32,
                device=self.device,
            )

        input_examples = [None] * size
        input_rec_structs = [None] * size
        input_lig_structs = [None] * size
        input_transforms = [None] * size

        cond_examples = [None] * size
        cond_rec_structs = [None] * size
        cond_lig_structs = [None] * size
        cond_transforms = [None] * size

        # output tensors for the atomic density grids
        input_grids = empty_grids()
        cond_grids = empty_grids()

        for i, ex in enumerate(examples):
            if self.diff_cond_structs:
                # distinct conditional molecules: split into two Examples
                in_rec, in_lig, cond_rec, cond_lig = ex.coord_sets

                in_ex = molgrid.Example()
                in_ex.coord_sets.append(in_rec)
                in_ex.coord_sets.append(in_lig)

                c_ex = molgrid.Example()
                c_ex.coord_sets.append(cond_rec)
                c_ex.coord_sets.append(cond_lig)
            else:
                # conditional molecules are the input molecules
                in_rec, in_lig = ex.coord_sets
                cond_rec, cond_lig = ex.coord_sets
                in_ex = c_ex = ex

            input_examples[i] = in_ex
            cond_examples[i] = c_ex

            input_rec_structs[i] = to_struct(in_rec, self.rec_typer)
            input_lig_structs[i] = to_struct(in_lig, self.lig_typer)
            if self.diff_cond_structs:
                cond_rec_structs[i] = to_struct(cond_rec, self.rec_typer)
                cond_lig_structs[i] = to_struct(cond_lig, self.lig_typer)
            else:
                cond_rec_structs[i] = input_rec_structs[i]
                cond_lig_structs[i] = input_lig_structs[i]

            input_transforms[i] = transform_about(in_lig)
            cond_transforms[i] = (transform_about(cond_lig)
                                  if self.diff_cond_transform
                                  else input_transforms[i])

        if interpolate:
            # interpolate location and orientation of the conditional grids
            if not self.cond_interp.is_initialized:
                self.cond_interp.initialize(cond_examples[0])
            cond_transforms = self.cond_interp(
                transforms=cond_transforms,
                spherical=spherical,
            )

        separate_cond_grids = (self.diff_cond_transform
                               or self.diff_cond_structs or interpolate)
        for i in range(size):
            self.grid_maker.forward(input_examples[i], input_transforms[i],
                                    input_grids[i])
            if separate_cond_grids:
                self.grid_maker.forward(cond_examples[i], cond_transforms[i],
                                        cond_grids[i])
            else:
                # same molecules and transform: reuse the input density
                cond_grids[i] = input_grids[i]

        return (
            input_grids,
            cond_grids,
            (input_rec_structs, input_lig_structs),
            (cond_rec_structs, cond_lig_structs),
            (input_transforms, cond_transforms),
            labels,
        )