def test_gnina_example_provider():
    """Exercise label and coordinate extraction from an ExampleProvider batch.

    Populates a provider from ``small.types`` (with ``datadir + "/structs"``
    as the data root), draws one batch of 100 examples and checks that:

    * extracting all labels through the ``cpu()`` and ``gpu()`` grid views
      yields identical values;
    * extracting each label column on its own agrees with the full label
      grid for every example in the batch;
    * the two coordinate sets of the first example have the expected sizes,
      coordinates, radius and type indices.

    Finally a label is extracted directly into a NumPy array; that call is
    only required not to raise.
    """
    fname = datadir + "/small.types"
    e = molgrid.ExampleProvider(data_root=datadir + "/structs")
    e.populate(fname)
    batch_size = 100
    batch = e.next_batch(batch_size)

    # Extract every label of the batch at once, via both grid views.
    nlabels = e.num_labels()
    assert nlabels == 3
    labels = molgrid.MGrid2f(batch_size, nlabels)
    gpulabels = molgrid.MGrid2f(batch_size, nlabels)

    batch.extract_labels(labels.cpu())
    batch.extract_labels(gpulabels.gpu())
    assert np.array_equal(labels.tonumpy(), gpulabels.tonumpy())

    # Extract each label column separately (the last one through gpu()).
    label0 = molgrid.MGrid1f(batch_size)
    label1 = molgrid.MGrid1f(batch_size)
    label2 = molgrid.MGrid1f(batch_size)
    batch.extract_label(0, label0.cpu())
    batch.extract_label(1, label1.cpu())
    batch.extract_label(2, label2.gpu())

    # Known label values of the first example.
    assert label0[0] == 1
    assert label1[0] == approx(6.05)
    assert label2[0] == approx(0.162643)
    assert labels[0, 0] == 1
    assert labels[0][1] == approx(6.05)
    assert labels[0][2] == approx(0.162643)

    # `i` indexes examples (first axis of `labels`, only axis of `labelN`),
    # so the loop bound is the batch size rather than the number of labels.
    for i in range(batch_size):
        assert label0[i] == labels[i][0]
        assert label1[i] == labels[i][1]
        assert label2[i] == labels[i][2]

    ex = batch[0]

    # First coordinate set of the first example.
    crec = ex.coord_sets[0]
    assert crec.size() == 1781
    assert list(crec.coords[0]) == approx([45.042, 12.872, 13.001])
    assert crec.radii[0] == approx(1.8)
    assert list(crec.type_index)[:10] == [
        6.0, 1.0, 1.0, 7.0, 0.0, 6.0, 1.0, 1.0, 7.0, 1.0
    ]

    # Second coordinate set of the first example.
    clig = ex.coord_sets[1]
    assert clig.size() == 10
    assert list(clig.coords[9]) == approx([27.0536, 3.2453, 32.4511])
    assert list(clig.type_index) == [
        8.0, 1.0, 1.0, 9.0, 10.0, 0.0, 0.0, 1.0, 9.0, 8.0
    ]

    # Extracting a label into a plain float32 NumPy array must be accepted.
    batch = e.next_batch(1)
    a = np.array([0], dtype=np.float32)
    batch.extract_label(1, a)
def test_coords2grid():
    """Run Coords2Grid forward and backward on a two-atom CUDA example.

    Two atoms are placed at (1, 0, 0) and (-1, 0, 0) with type column 10 set
    to 1 and every other type column set to 0. The generated grid must be
    non-empty and have all of its density in channel 10. An MSE loss against
    a target that is 1000.0 at voxel [24, 24, 24] of every channel is then
    back-propagated, and the loss and coordinate gradients are printed.
    """
    gmaker = molgrid.GridMaker(
        resolution=0.5,
        dimension=23.5,
        radius_scale=1,
        radius_type_indexed=True,
    )
    typer = molgrid.defaultGninaLigandTyper
    n_types = typer.num_types()
    radii = np.array(list(typer.get_type_radii()), np.float32)

    dims = gmaker.grid_dimensions(n_types)
    grid_size = dims[0] * dims[1] * dims[2] * dims[3]

    c2grid = molgrid.Coords2Grid(gmaker, center=(0, 0, 0))

    n_atoms = 2
    batch_size = 1
    coords = nn.Parameter(torch.randn(n_atoms, 3, device='cuda'))
    # One extra type column is allocated here and sliced off before the call.
    types = nn.Parameter(torch.randn(n_atoms, n_types + 1, device='cuda'))

    # Overwrite the random initial values with a fixed configuration.
    coords.data[0, :] = torch.tensor([1, 0, 0])
    coords.data[1, :] = torch.tensor([-1, 0, 0])
    types.data[...] = 0
    types.data[:, 10] = 1

    tiled_radii = np.tile(radii, (batch_size, 1))
    batch_radii = torch.tensor(tiled_radii, dtype=torch.float32, device='cuda')

    batched_coords = coords.unsqueeze(0)
    batched_types = types.unsqueeze(0)[:, :, :-1]
    grid_gen = c2grid(batched_coords, batched_types, batch_radii)

    # Only type 10 is switched on, so channel 10 carries the whole density.
    channel_total = float(grid_gen[0][10].sum())
    overall_total = float(grid_gen.sum())
    assert channel_total == approx(overall_total)
    assert grid_gen.sum() > 0

    target = torch.zeros_like(grid_gen)
    target[0, :, 24, 24, 24] = 1000.0

    # These grids are allocated and filled but not read again in this test.
    grad_coords = molgrid.MGrid2f(n_atoms, 3)
    grad_types = molgrid.MGrid2f(n_atoms, n_types)
    r = molgrid.MGrid1f(len(radii))
    r.copyFrom(radii)

    grid_loss = F.mse_loss(target, grid_gen)
    grid_loss.backward()

    print(grid_loss)
    print(coords.grad.detach().cpu().numpy())