Ejemplo n.º 1
0
def test_gnina_example_provider():
    """Check label and coordinate extraction from a gnina-style ExampleProvider.

    Populates an ExampleProvider from ``small.types`` (structures resolved
    under ``datadir/structs``), takes a batch of 100 examples and verifies:

    * labels extracted into a CPU grid and a GPU grid are identical,
    * per-column ``extract_label`` agrees with the matching column of
      ``extract_labels`` for every example in the batch,
    * the first example's two coordinate sets have the expected sizes,
      coordinates, radii and type indices,
    * ``extract_label`` accepts a plain numpy array as its destination.
    """
    fname = datadir + "/small.types"
    e = molgrid.ExampleProvider(data_root=datadir + "/structs")
    e.populate(fname)

    batch_size = 100
    batch = e.next_batch(batch_size)

    # Extract the full (batch_size x nlabels) label matrix, once into a CPU
    # destination and once into a GPU destination; both must agree.
    nlabels = e.num_labels()
    assert nlabels == 3
    labels = molgrid.MGrid2f(batch_size, nlabels)
    gpulabels = molgrid.MGrid2f(batch_size, nlabels)

    batch.extract_labels(labels.cpu())
    batch.extract_labels(gpulabels.gpu())
    assert np.array_equal(labels.tonumpy(), gpulabels.tonumpy())

    # Extract one label column at a time; label2 deliberately goes through
    # the GPU code path.
    label0 = molgrid.MGrid1f(batch_size)
    label1 = molgrid.MGrid1f(batch_size)
    label2 = molgrid.MGrid1f(batch_size)
    batch.extract_label(0, label0.cpu())
    batch.extract_label(1, label1.cpu())
    batch.extract_label(2, label2.gpu())

    assert label0[0] == 1
    assert label1[0] == approx(6.05)
    assert label2[0] == approx(0.162643)
    assert labels[0, 0] == 1
    assert labels[0][1] == approx(6.05)
    assert labels[0][2] == approx(0.162643)

    # The first index is the example (row) index, so iterate over every
    # example in the batch -- not over the label columns, which would only
    # compare the first `nlabels` examples.
    for i in range(batch_size):
        assert label0[i] == labels[i][0]
        assert label1[i] == labels[i][1]
        assert label2[i] == labels[i][2]

    ex = batch[0]
    # coord_sets[0] -- presumably the receptor (TODO confirm against the
    # .types column order)
    crec = ex.coord_sets[0]
    assert crec.size() == 1781
    assert list(crec.coords[0]) == approx([45.042, 12.872, 13.001])
    assert crec.radii[0] == approx(1.8)
    assert list(crec.type_index)[:10] == [
        6.0, 1.0, 1.0, 7.0, 0.0, 6.0, 1.0, 1.0, 7.0, 1.0
    ]

    # coord_sets[1] -- presumably the ligand (TODO confirm)
    clig = ex.coord_sets[1]
    assert clig.size() == 10
    assert list(clig.coords[9]) == approx([27.0536, 3.2453, 32.4511])
    assert list(clig.type_index) == [
        8.0, 1.0, 1.0, 9.0, 10.0, 0.0, 0.0, 1.0, 9.0, 8.0
    ]

    # extract_label must also accept a plain numpy array as destination;
    # this only checks that the call succeeds.
    batch = e.next_batch(1)
    a = np.array([0], dtype=np.float32)
    batch.extract_label(1, a)
Ejemplo n.º 2
0
def test_coords2grid():
    """Check that Coords2Grid builds a grid and back-propagates to coordinates.

    Two atoms at (1, 0, 0) and (-1, 0, 0) are both assigned only type
    channel 10. The generated grid must have the expected shape and put all
    of its density into channel 10. An MSE loss against a target with a
    spike at voxel (24, 24, 24) must then produce a finite gradient with
    respect to the atom coordinates.
    """
    gmaker = molgrid.GridMaker(resolution=0.5,
                               dimension=23.5,
                               radius_scale=1,
                               radius_type_indexed=True)
    n_types = molgrid.defaultGninaLigandTyper.num_types()
    radii = np.array(list(molgrid.defaultGninaLigandTyper.get_type_radii()),
                     np.float32)
    dims = gmaker.grid_dimensions(n_types)

    c2grid = molgrid.Coords2Grid(gmaker, center=(0, 0, 0))
    n_atoms = 2
    batch_size = 1
    coords = nn.Parameter(torch.randn(n_atoms, 3, device='cuda'))
    # One extra type column that is sliced off before calling c2grid.
    # NOTE(review): presumably this feeds Coords2Grid a non-contiguous view
    # on purpose -- confirm intent.
    types = nn.Parameter(torch.randn(n_atoms, n_types + 1, device='cuda'))

    coords.data[0, :] = torch.tensor([1, 0, 0])
    coords.data[1, :] = torch.tensor([-1, 0, 0])
    types.data[...] = 0
    types.data[:, 10] = 1

    batch_radii = torch.tensor(np.tile(radii, (batch_size, 1)),
                               dtype=torch.float32,
                               device='cuda')

    grid_gen = c2grid(coords.unsqueeze(0),
                      types.unsqueeze(0)[:, :, :-1], batch_radii)

    # Output must be (batch, channels, x, y, z) as reported by the GridMaker;
    # this also guarantees the voxel index used for `target` below is valid.
    assert list(grid_gen.shape) == [batch_size, dims[0], dims[1], dims[2],
                                    dims[3]]
    # Channel 10 is the only type set above, so it holds all the density.
    assert float(grid_gen[0][10].sum()) == approx(float(grid_gen.sum()))
    assert grid_gen.sum() > 0

    target = torch.zeros_like(grid_gen)
    target[0, :, 24, 24, 24] = 1000.0

    # MGrid1f.copyFrom must copy the numpy radii verbatim.
    r = molgrid.MGrid1f(len(radii))
    r.copyFrom(radii)
    assert [r[i] for i in range(len(radii))] == approx(list(radii))

    grid_loss = F.mse_loss(target, grid_gen)
    grid_loss.backward()

    # The loss must back-propagate to the atom coordinates.
    assert coords.grad is not None
    assert coords.grad.shape == (n_atoms, 3)
    assert bool(torch.isfinite(coords.grad).all())