def test_example_merge():
    """Check Example.merge_coordinates with type offsets, without them, and sliced."""
    mol = pybel.readstring('smi', 'c1ccccc1CO')
    mol.addh()
    mol.make3D()

    element_set = molgrid.CoordinateSet(mol, molgrid.ElementIndexTyper())
    default_set = molgrid.CoordinateSet(mol)
    # converting the second set to vector types should not disturb its index types
    default_set.make_vector_types()

    example = molgrid.Example()
    for coord_set in (element_set, default_set):
        example.coord_sets.append(coord_set)

    assert example.type_size() == (element_set.max_type + default_set.max_type)
    assert example.coordinate_size() == (
        element_set.coord.dimension(0) + default_set.type_index.size()
    )

    def index_types(coord_set):
        # numpy view of a coordinate set's per-atom type indices
        return coord_set.type_index.tonumpy()

    # default merge: the second set's type indices are offset by the first set's max_type
    merged = example.merge_coordinates()
    assert merged.coord.tonumpy().shape == (24, 3)
    expected = np.concatenate(
        [index_types(element_set), index_types(default_set) + element_set.max_type]
    )
    assert np.array_equal(expected, index_types(merged))

    # merge without unique types (the type indices of the two sets overlap,
    # which makes no sense physically, but is allowed)
    merged_shared = example.merge_coordinates(0, False)
    assert merged_shared.coord.tonumpy().shape == (24, 3)
    expected = np.concatenate([index_types(element_set), index_types(default_set)])
    assert np.array_equal(expected, index_types(merged_shared))

    # sliced merge starting at coordinate set 1
    merged_tail = example.merge_coordinates(1, False)
    assert merged_tail.coord.tonumpy().shape == (8, 3)  # no hydrogens in this slice
def test_examplevec():
    """Build an ExampleVec from two labelled Examples and check its contents.

    The first example carries an element-typed coordinate set and label 0. The
    second carries a default-typed set (converted to vector types) and label 1.
    """
    m = pybel.readstring('smi', 'c1ccccc1CO')
    m.addh()
    m.make3D()

    c = molgrid.CoordinateSet(m, molgrid.ElementIndexTyper())
    c2 = molgrid.CoordinateSet(m)

    c2.make_vector_types()  # this should not screw up index types

    ex = molgrid.Example()
    ex.coord_sets.append(c)
    ex.labels.append(0)

    ex2 = molgrid.Example()
    ex2.coord_sets.append(c2)
    ex2.labels.append(1)

    evec = molgrid.ExampleVec([ex, ex2])

    # Without these checks the test could only fail by raising. Verify that the
    # vector holds both examples, in order, with their labels intact.
    assert len(evec) == 2
    assert evec[0].labels[0] == 0
    assert evec[1].labels[0] == 1
def load_examples(T):
    """Wrap each ``(coord, types, energy, diff)`` record of *T* in a molgrid Example.

    Per-atom radii are looked up in the module-level ``typeradii`` table by
    integer type index. ``energy`` is ignored, and ``diff`` becomes the
    example's only label.

    Returns:
        A list of ``molgrid.Example``, one per record, in input order.
    """
    def _record_to_example(coord, types, diff):
        radii = np.array(
            [typeradii[int(type_index)] for type_index in types],
            dtype=np.float32,
        )
        # NOTE(review): the trailing 4 is presumably the number of atom types
        # (max type) for this CoordinateSet -- confirm it matches typeradii.
        coord_set = molgrid.CoordinateSet(coord, types, radii, 4)
        example = molgrid.Example()
        example.coord_sets.append(coord_set)
        example.labels.append(diff)
        return example

    return [
        _record_to_example(coord, types, diff)
        for coord, types, _energy, diff in T
    ]
# Grid dimensions (including types) gdims = gm.grid_dimensions(t.num_types()) # Pre-allocate grid # Only one example (batch size is 1) grid = torch.zeros(1, *gdims, dtype=torch.float32, device="cuda:0") obmol = next(pybel.readfile("sdf", args.sdf)) obmol.addh() print(obmol, end="") # Use OpenBabel molecule object (obmol.OBmol) instead of PyBel molecule (obmol) cs = molgrid.CoordinateSet(obmol.OBMol, t) ex = molgrid.Example() ex.coord_sets.append(cs) c = ex.coord_sets[0].center() # Only one coordinate set print("center:", tuple(c)) # https://gnina.github.io/libmolgrid/python/index.html#the-transform-class transform = molgrid.Transform( c, random_translate=0.0, random_rotation=False, # float # bool ) transform.forward(ex, ex) # Compute grid gm.forward(ex, grid[0])
def forward(self, interpolate=False, spherical=False):
    """Draw the next batch and build input and conditional density grids for it.

    Each example in the batch has one of two layouts:

    * Two coordinate sets (receptor, ligand), which serve as both the input
      and the condition.
    * Four coordinate sets, when ``self.diff_cond_structs`` is set: input
      receptor, input ligand, conditional receptor, conditional ligand.

    For each example:

    * The receptor and ligand sets are converted to ``AtomStruct`` objects.
      The conditional structs reuse the input ones unless they differ.
    * A ``molgrid.Transform`` is built, centered on the ligand coordinate set.

    Args:
        interpolate: if true, the conditional transforms are replaced by the
            output of ``self.cond_interp``. The interpolator is initialized on
            first use from the first conditional example. The conditional
            grids are then always recomputed.
        spherical: forwarded to ``self.cond_interp``. Only used when
            ``interpolate`` is true.

    Returns:
        ``(input_grids, cond_grids, input_structs, cond_structs, transforms,
        labels)``, where:

        * The grids are float32 tensors of shape
          ``(batch_size, n_channels, *grid_maker.spatial_grid_dimensions())``
          on ``self.device``.
        * Each ``*_structs`` entry is a ``(receptor_list, ligand_list)`` pair.
        * ``transforms`` is ``(input_transforms, cond_transforms)``.
        * ``labels`` holds label 0 of every example.

    Raises:
        AssertionError: if the data is empty (only when assertions are enabled).
    """
    assert len(self) > 0, 'data is empty'

    batch_size = self.batch_size

    # next batch of structures; label 0 of each example is written into `labels`
    examples = self.ex_provider.next_batch(batch_size)
    labels = torch.zeros(batch_size, device=self.device)
    examples.extract_label(0, labels)

    # per-example slots of length batch_size, filled by index below
    (input_examples, input_rec_structs, input_lig_structs, input_transforms,
     cond_examples, cond_rec_structs, cond_lig_structs, cond_transforms) = (
        [None] * batch_size for _ in range(8)
    )

    def make_grids():
        # zero-initialised density grid tensor covering the whole batch
        return torch.zeros(
            batch_size,
            self.n_channels,
            *self.grid_maker.spatial_grid_dimensions(),
            dtype=torch.float32,
            device=self.device,
        )

    def to_struct(coord_set, typer):
        # convert one molgrid coordinate set into an AtomStruct
        return atom_structs.AtomStruct.from_coord_set(
            coord_set,
            typer=typer,
            data_root=self.root_dir,
            device=self.device,
        )

    def make_transform(coord_set):
        # transform centered on coord_set; random settings come from self
        return molgrid.Transform(
            center=coord_set.center(),
            random_translate=self.random_translation,
            random_rotation=self.random_rotation,
        )

    input_grids = make_grids()
    cond_grids = make_grids()

    for i, example in enumerate(examples):
        if self.diff_cond_structs:
            # separate input and conditional molecules: split into two Examples
            in_rec, in_lig, cond_rec, cond_lig = example.coord_sets
            in_example = molgrid.Example()
            for coord_set in (in_rec, in_lig):
                in_example.coord_sets.append(coord_set)
            cond_example = molgrid.Example()
            for coord_set in (cond_rec, cond_lig):
                cond_example.coord_sets.append(coord_set)
        else:
            # the conditional molecules are the input molecules
            in_rec, in_lig = example.coord_sets
            cond_rec, cond_lig = example.coord_sets
            in_example = cond_example = example

        input_examples[i] = in_example
        cond_examples[i] = cond_example

        input_rec_structs[i] = to_struct(in_rec, self.rec_typer)
        input_lig_structs[i] = to_struct(in_lig, self.lig_typer)
        if not self.diff_cond_structs:
            cond_rec_structs[i] = input_rec_structs[i]
            cond_lig_structs[i] = input_lig_structs[i]
        else:
            cond_rec_structs[i] = to_struct(cond_rec, self.rec_typer)
            cond_lig_structs[i] = to_struct(cond_lig, self.lig_typer)

        input_transforms[i] = make_transform(in_lig)
        cond_transforms[i] = (
            make_transform(cond_lig) if self.diff_cond_transform
            else input_transforms[i]
        )

    if interpolate:
        # interpolate the location/orientation of the conditional grids
        if not self.cond_interp.is_initialized:
            self.cond_interp.initialize(cond_examples[0])
        cond_transforms = self.cond_interp(
            transforms=cond_transforms,
            spherical=spherical,
        )

    for i in range(batch_size):
        self.grid_maker.forward(input_examples[i], input_transforms[i], input_grids[i])
        regrid_cond = (
            self.diff_cond_transform or self.diff_cond_structs or interpolate
        )
        if not regrid_cond:
            # same example and same transform as the input: copy the input grid
            cond_grids[i] = input_grids[i]
        else:
            self.grid_maker.forward(cond_examples[i], cond_transforms[i], cond_grids[i])

    return (
        input_grids,
        cond_grids,
        (input_rec_structs, input_lig_structs),
        (cond_rec_structs, cond_lig_structs),
        (input_transforms, cond_transforms),
        labels,
    )