Ejemplo n.º 1
0
def test_coords2grid():
    """Smoke-test molgrid.Coords2Grid forward and backward passes on CUDA.

    Two atoms of a single type are gridded; the forward output is checked to
    fall entirely in that type's channel and to be non-empty.  An MSE loss
    against a target with one bright voxel per channel is then
    back-propagated, and the coordinate gradients are checked to exist and be
    finite.
    """
    gmaker = molgrid.GridMaker(resolution=0.5,
                               dimension=23.5,
                               radius_scale=1,
                               radius_type_indexed=True)
    n_types = molgrid.defaultGninaLigandTyper.num_types()
    radii = np.array(list(molgrid.defaultGninaLigandTyper.get_type_radii()),
                     np.float32)

    c2grid = molgrid.Coords2Grid(gmaker, center=(0, 0, 0))
    n_atoms = 2
    batch_size = 1
    type_idx = 10  # the single type channel both atoms are assigned to
    coords = nn.Parameter(torch.randn(n_atoms, 3, device='cuda'))
    # NOTE(review): one extra type column is allocated and sliced off when
    # calling c2grid below, presumably to exercise a non-contiguous types
    # input -- confirm.
    types = nn.Parameter(torch.randn(n_atoms, n_types + 1, device='cuda'))

    coords.data[0, :] = torch.tensor([1, 0, 0])
    coords.data[1, :] = torch.tensor([-1, 0, 0])
    types.data[...] = 0
    types.data[:, type_idx] = 1

    # per-type radii, repeated once per batch entry
    batch_radii = torch.tensor(np.tile(radii, (batch_size, 1)),
                               dtype=torch.float32,
                               device='cuda')

    grid_gen = c2grid(coords.unsqueeze(0),
                      types.unsqueeze(0)[:, :, :-1], batch_radii)

    # all density must land in the assigned type channel
    assert float(grid_gen[0][type_idx].sum()) == approx(float(grid_gen.sum()))
    assert grid_gen.sum() > 0

    target = torch.zeros_like(grid_gen)
    target[0, :, 24, 24, 24] = 1000.0

    grid_loss = F.mse_loss(target, grid_gen)
    grid_loss.backward()

    # the backward pass must propagate well-defined gradients to the coords
    assert coords.grad is not None
    assert bool(torch.isfinite(coords.grad).all())

    print(grid_loss)
    print(coords.grad.detach().cpu().numpy())
Ejemplo n.º 2
0
    def __init__(
        self,
        beam_size=1,
        multi_atom=False,
        n_atoms_detect=1,
        apply_conv=False,
        threshold=0.1,
        peak_value=1.5,
        min_dist=0.0,
        apply_prop_conv=False,
        constrain_types=False,
        constrain_frags=False,
        estimate_types=False,
        fit_L1_loss=False,
        interm_gd_iters=10,
        final_gd_iters=100,
        gd_kwargs=None,
        dkoes_make_mol=True,
        use_openbabel=False,
        output_kernel=False,
        device='cuda',
        verbose=0,
        debug=False,
    ):
        """Store search, detection and gradient-descent settings.

        Args:
            beam_size: number of best structures to store and expand during
                search.
            multi_atom: try placing all detected atoms at once, then try
                individually.
            n_atoms_detect: maximum number of atoms to detect in remaining
                density.
            apply_conv, threshold, peak_value, min_dist: settings for
                detecting atoms in element channels.
            apply_prop_conv: setting for detecting properties in property
                channels.
            constrain_types, constrain_frags: constrain the search to exact
                atom type counts or to a single fragment.
            estimate_types: stored as ``self.estimate_types``.
            fit_L1_loss, interm_gd_iters, final_gd_iters: settings for
                gradient descent at each step and/or at the final step.
            gd_kwargs: gradient-descent options, presumably forwarded to the
                optimizer (confirm in the fitting code).  ``None`` (default)
                means ``dict(lr=0.1, betas=(0.9, 0.999), weight_decay=0.0)``;
                a fresh dict is built per instance.
            dkoes_make_mol, use_openbabel, output_kernel, device, verbose,
                debug: stored as attributes of the same name.
        """
        # number of best structures to store and expand during search
        self.beam_size = beam_size

        # maximum number of atoms to detect in remaining density
        self.n_atoms_detect = n_atoms_detect

        # try placing all detected atoms at once, then try individually
        self.multi_atom = multi_atom

        # settings for detecting atoms in element channels
        self.apply_conv = apply_conv
        self.threshold = threshold
        self.peak_value = peak_value
        self.min_dist = min_dist

        # setting for detecting properties in property channels
        self.apply_prop_conv = apply_prop_conv

        # can constrain to find exact atom type counts or single fragment
        self.constrain_types = constrain_types
        self.constrain_frags = constrain_frags
        self.estimate_types = estimate_types

        # can perform gradient descent at each step and/or at final step
        self.fit_L1_loss = fit_L1_loss
        self.interm_gd_iters = interm_gd_iters
        self.final_gd_iters = final_gd_iters
        # Build the default per call: a dict literal in the signature is
        # evaluated once and would be shared (and mutable) across instances.
        if gd_kwargs is None:
            gd_kwargs = dict(
                lr=0.1,
                betas=(0.9, 0.999),
                weight_decay=0.0,
            )
        self.gd_kwargs = gd_kwargs

        # NOTE(review): these two options were previously accepted but
        # discarded; they are now kept so other methods can read them.
        self.dkoes_make_mol = dkoes_make_mol
        self.use_openbabel = use_openbabel

        self.output_kernel = output_kernel
        self.device = device
        self.verbose = verbose
        self.debug = debug

        self.grid_maker = molgrid.GridMaker(gaussian_radius_multiple=-1.5)
        self.c2grid = molgrid.Coords2Grid(self.grid_maker)

        # lazily initialize atom density kernel
        self.kernel = None
Ejemplo n.º 3
0
def simple_atom_fit(mgrid, types, iters=10, tol=0.01, device='cuda', grm=-1.5):
    '''Fit atoms to the density in an AtomGrid.

    `types` is ignored: the number of atoms of each type is always inferred
    from the density.

    Starting positions are chosen per channel with `select_atom_starts` and
    refined with L-BFGS on the squared error between the gridded atoms and
    `mgrid.values`.  Channels whose maximum squared voxel error exceeds `tol`
    are re-seeded for up to `iters` rounds; if that does not converge, atoms
    are added or moved one at a time and the fit re-optimized.

    Args:
        mgrid: input grid; reads .values, .channels, .center, .dimension and
            .resolution.
        types: ignored.
        iters: number of L-BFGS restart rounds; must be >= 1 whenever at
            least one atom start is found.
        tol: threshold on the maximum squared per-voxel error of a channel.
        device: torch device the fit runs on.
        grm: gaussian_radius_multiple passed to molgrid.GridMaker.

    Returns:
        (AtomStruct, torch.Tensor): the fitted atoms and the grid computed
        from those atoms' coordinates (all zeros with the shape of
        mgrid.values when no atom starts are found).

    Raises:
        ValueError: if iters < 1 and at least one atom start was found.
    '''

    t_start = time.time()
    # for every channel, select some coordinates and set up the type/radius vectors
    initcoords = []
    typevecs = []
    radii = []
    typeindices = []
    numatoms = 0
    tcnts = {}  # channel index -> number of atoms placed in that channel
    values = torch.tensor(mgrid.values, device=device)

    for (t, G) in enumerate(values):
        ch = mgrid.channels[t]
        coords = select_atom_starts(mgrid, G, ch.atomic_radius)
        if coords:
            tvec = np.zeros(len(mgrid.channels))
            tvec[t] = 1.0
            tcnt = len(coords)
            numatoms += tcnt

            r = mgrid.channels[t].atomic_radius
            initcoords += coords
            typevecs += [tvec] * tcnt
            typeindices += [t] * tcnt
            radii += [r] * tcnt
            tcnts[t] = tcnt

    typevecs = np.array(typevecs)
    initcoords = np.array(initcoords)
    typeindices = np.array(typeindices)
    # set up gridder
    center = tuple([float(c) for c in mgrid.center])
    gridder = molgrid.Coords2Grid(molgrid.GridMaker(
        dimension=mgrid.dimension,
        resolution=mgrid.resolution,
        gaussian_radius_multiple=grm),
                                  center=center)

    # having set up input coordinates, optimize with BFGS
    coords = torch.tensor(initcoords,
                          dtype=torch.float32,
                          requires_grad=True,
                          device=device)
    types = torch.tensor(typevecs, dtype=torch.float32, device=device)
    radii = torch.tensor(radii, dtype=torch.float32, device=device)
    best_loss = np.inf
    best_coords = None
    best_typeindices = typeindices  # save in case number of atoms changes
    goodcoords = False
    bestagrid = torch.zeros(values.shape, dtype=torch.float32, device=device)

    if len(initcoords) == 0:  # no atoms
        mol = AtomStruct(np.zeros((0, 3)),
                         np.zeros(0),
                         mgrid.channels,
                         L2_loss=float(values.square().sum() / values.numel()),
                         time=time.time() - t_start,
                         iterations=0,
                         numfixes=0,
                         type_diff=0,
                         est_type_diff=0,
                         visited_structs=[])
        return mol, bestagrid

    if iters < 1:
        # the restart loop must run at least once to produce best_coords
        raise ValueError('iters must be at least 1, got {}'.format(iters))

    for inum in range(iters):
        optimizer = torch.optim.LBFGS([coords],
                                      max_iter=20000,
                                      tolerance_grad=1e-9,
                                      line_search_fn='strong_wolfe')

        def closure():
            optimizer.zero_grad()
            agrid = gridder.forward(coords, types, radii)
            loss = torch.square(agrid - values).sum() / numatoms
            loss.backward()
            return loss

        optimizer.step(closure)
        final_loss = optimizer.state_dict()['state'][0][
            'prev_loss']  # TODO - check for convergence?

        if final_loss < best_loss:
            best_loss = final_loss
            # explicit copy: detach().cpu() shares storage with `coords` when
            # it already lives on the CPU, and `coords` is modified in place
            # by later restarts and optimizer steps
            best_coords = coords.detach().to('cpu', copy=True)

        if inum == iters - 1:  # stick with these coordinates
            break
        # otherwise, try different starting coordinates for only those
        # atom types that have errors
        goodcoords = True
        with torch.no_grad():
            offset = 0
            agrid = gridder.forward(coords, types, radii)
            t = 0
            while offset < len(typeindices):
                t = typeindices[offset]
                # eval max error - mse will downplay a single atom of many being off
                maxerr = float(torch.square(agrid[t] - values[t]).max())
                if maxerr > tol:
                    goodcoords = False
                    ch = mgrid.channels[t]
                    newcoords = select_atom_starts(mgrid, values[t],
                                                   ch.atomic_radius)
                    for (i, coord) in enumerate(newcoords):
                        coords[i + offset] = torch.tensor(coord,
                                                          dtype=torch.float)
                offset += tcnts[t]
        if goodcoords:
            break

    # Grid of the coordinates that are actually returned.  Recomputed rather
    # than reusing the loop's `agrid`, which is never assigned when iters == 1
    # and otherwise belongs to the current (possibly not best) coordinates.
    with torch.no_grad():
        bestagrid = gridder.forward(best_coords.to(device), types, radii)
    numfixes = 0
    if not goodcoords:
        # try to fix up an atom at a time
        offset = 0
        # reset coords to best found so far
        with torch.no_grad():
            coords[:] = best_coords
            agrid = gridder.forward(coords, types, radii)
        t = 0
        while offset < len(typeindices):
            t = typeindices[offset]
            maxerr = float(torch.square(agrid[t] - values[t]).max())
            per_atom_volume = float(radii[offset])**3 * ((2 * np.pi)**1.5)
            while maxerr > tol:
                # identify the atom of this type closest to the place with too much density
                # and move it to the location with too little density
                tcoords = coords[offset:offset + tcnts[t]].detach().cpu(
                ).numpy()  # coordinates for this type

                diff = agrid[t] - values[t]
                possum = float(diff[diff > 0].sum())
                negsum = float(diff[diff < 0].sum())
                maxdiff = float(diff.max())
                mindiff = float(diff.min())
                missing_density = -(negsum + possum)
                if missing_density > .25 * per_atom_volume:  # add atom  MAGIC NUMBER ALERT
                    # needs to be enough total missing density to be close to a whole atom,
                    # but the missing density also needs to be somewhat concentrated
                    numfixes += 1
                    minpos = int((agrid[t] - values[t]).argmin())
                    minpos = grid_to_xyz(
                        np.unravel_index(minpos, agrid[t].shape), mgrid)
                    # add atom: change coords, types, radii, typeindices and tcnts, numatoms
                    numatoms += 1
                    typeindices = np.insert(typeindices, offset, t)
                    tcnts[t] += 1
                    with torch.no_grad():
                        newcoord = torch.tensor([minpos],
                                                device=coords.device,
                                                dtype=coords.dtype,
                                                requires_grad=True)
                        coords = torch.cat(
                            (coords[:offset], newcoord, coords[offset:]))
                        radii = torch.cat(
                            (radii[:offset], radii[offset:offset + 1],
                             radii[offset:]))
                        types = torch.cat(
                            (types[:offset], types[offset:offset + 1],
                             types[offset:]))

                        coords.requires_grad_(True)
                        radii.requires_grad_(True)
                        types.requires_grad_(True)

                elif missing_density < -.75 * per_atom_volume:
                    print("Too many atoms?")
                    break
                    # TODO: remove atom
                else:  # move an atom
                    numfixes += 1
                    maxpos = int((agrid[t] - values[t]).argmax())
                    minpos = int((agrid[t] - values[t]).argmin())
                    maxpos = grid_to_xyz(
                        np.unravel_index(maxpos, agrid[t].shape), mgrid)
                    minpos = grid_to_xyz(
                        np.unravel_index(minpos, agrid[t].shape), mgrid)

                    dists = np.square(tcoords - maxpos).sum(axis=1)
                    closesti = np.argmin(dists)
                    with torch.no_grad():
                        coords[offset + closesti] = torch.tensor(minpos)

                # reoptimize (the closure reads the current coords/types/radii)
                optimizer = torch.optim.LBFGS([coords],
                                              max_iter=20000,
                                              tolerance_grad=1e-9,
                                              line_search_fn='strong_wolfe')
                # TODO: only optimize this grid
                optimizer.step(closure)
                final_loss = optimizer.state_dict()['state'][0][
                    'prev_loss']  # TODO - check for convergence?
                # recompute grid; no autograd graph is needed (and none should
                # be kept alive in the returned bestagrid)
                with torch.no_grad():
                    agrid = gridder.forward(coords, types, radii)

                # if maxerr hasn't improved, give up
                newerr = float(torch.square(agrid[t] - values[t]).max())
                if newerr >= maxerr:
                    # don't give up if there's still a lot left to fit
                    # and the missing density isn't all (very) shallow
                    if missing_density < per_atom_volume or mindiff > -0.1:  # magic number!
                        break
                else:
                    maxerr = newerr
                    best_loss = final_loss
                    best_coords = coords.detach().to('cpu', copy=True)
                    best_typeindices = typeindices.copy()
                    bestagrid = agrid.clone()

                # otherwise update coordinates and repeat

            offset += tcnts[t]

    # create struct from coordinates
    mol = AtomStruct(best_coords.numpy(),
                     best_typeindices,
                     mgrid.channels,
                     L2_loss=float(best_loss),
                     time=time.time() - t_start,
                     iterations=inum,
                     numfixes=numfixes,
                     type_diff=0,
                     est_type_diff=0,
                     visited_structs=[])
    return mol, bestagrid
Ejemplo n.º 4
0
def simple_atom_fit(mgrid, types, iters=10, tol=0.01, device=None):
    '''Fit atoms to the density in a MolGrid.

    `types` is only used to report `est_type_diff`; the number of atoms of
    each type placed is always inferred from the density.

    Starting positions are chosen per channel with `select_atom_starts` and
    refined with L-BFGS on the squared error between the gridded atoms and
    the input values.  Channels whose maximum squared voxel error exceeds
    `tol` are re-seeded for up to `iters` rounds; if that does not converge,
    atoms are added or moved one at a time and the fit re-optimized.

    Args:
        mgrid: input grid; reads .values, .channels, .center and .resolution.
        types: per-channel atom counts, converted to a float tensor.
        iters: number of L-BFGS restart rounds; must be >= 1.
        tol: threshold on the maximum squared per-voxel error of a channel.
        device: torch device the fit runs on.  None (default) uses a
            module-level `device` global if one is defined, else 'cuda'.

    Returns:
        generate.MolGrid of the fitted atoms, with the fitted
        generate.MolStruct attached as `src_struct`.

    Raises:
        ValueError: if iters < 1.
    '''
    if device is None:
        # Backward compatibility: this function used to read `device` as a
        # module-level global; keep honoring it when present.
        device = globals().get('device', 'cuda')
    if iters < 1:
        # the restart loop must run at least once to produce best_coords
        raise ValueError('iters must be at least 1, got {}'.format(iters))
    t_start = time.time()

    # mtr22 - match the input API of generate.AtomFitter.fit
    mgrid = generate.MolGrid(
        values=torch.as_tensor(mgrid.values, device=device),
        channels=mgrid.channels,
        center=mgrid.center,
        resolution=mgrid.resolution,
    )

    # for every channel, select some coordinates and set up the type/radius vectors
    initcoords = []
    typevecs = []
    radii = []
    typeindices = []
    numatoms = 0
    tcnts = {}  # channel index -> number of atoms placed in that channel
    types_est = []  # mtr22 - number of start positions found per channel
    for (t, G) in enumerate(mgrid.values):
        ch = mgrid.channels[t]
        coords = select_atom_starts(mgrid, G, ch.atomic_radius)
        if coords:
            tvec = np.zeros(len(mgrid.channels))
            tvec[t] = 1.0
            tcnt = len(coords)
            numatoms += tcnt
            types_est.append(tcnt)  # mtr22

            r = mgrid.channels[t].atomic_radius
            initcoords += coords
            typevecs += [tvec] * tcnt
            typeindices += [t] * tcnt
            radii += [r] * tcnt
            tcnts[t] = tcnt
        else:
            types_est.append(0)  # mtr22

    typevecs = np.array(typevecs)
    initcoords = np.array(initcoords)
    typeindices = np.array(typeindices)

    # mtr22 - for computing type_diff metrics in returned molstruct
    types_true = torch.tensor(types, dtype=torch.float32, device=device)
    types_est = torch.tensor(types_est, dtype=torch.float32, device=device)

    # set up gridder
    gridder = molgrid.Coords2Grid(molgrid.GridMaker(
        dimension=mgrid.dimension,
        resolution=mgrid.resolution,
        gaussian_radius_multiple=-1.5),
                                  center=tuple(mgrid.center.astype(float)))
    mgrid.values = mgrid.values.to(device)

    # having set up input coordinates, optimize with BFGS
    coords = torch.tensor(initcoords,
                          dtype=torch.float32,
                          requires_grad=True,
                          device=device)
    types = torch.tensor(typevecs, dtype=torch.float32, device=device)
    radii = torch.tensor(radii, dtype=torch.float32, device=device)
    best_loss = np.inf
    best_coords = None
    best_typeindices = typeindices  # save in case number of atoms changes
    goodcoords = False

    for inum in range(iters):
        optimizer = torch.optim.LBFGS([coords],
                                      max_iter=20000,
                                      tolerance_grad=1e-9,
                                      line_search_fn='strong_wolfe')

        def closure():
            optimizer.zero_grad()
            agrid = gridder.forward(coords, types, radii)
            loss = torch.square(agrid - mgrid.values).sum() / numatoms
            loss.backward()
            return loss

        optimizer.step(closure)
        final_loss = optimizer.state_dict()['state'][0][
            'prev_loss']  # TODO - check for convergence?

        print('iter {} (loss={}, n_atoms={})'.format(inum, final_loss,
                                                     len(best_typeindices)))

        if final_loss < best_loss:
            best_loss = final_loss
            # clone: detach() alone shares storage with `coords`, which is
            # modified in place by restarts and later optimizer steps, so the
            # "best" snapshot would silently track the current coordinates
            best_coords = coords.detach().clone()

        if inum == iters - 1:  # stick with these coordinates
            break
        # otherwise, try different starting coordinates for only those
        # atom types that have errors
        goodcoords = True
        with torch.no_grad():
            offset = 0
            agrid = gridder.forward(coords, types, radii)
            t = 0
            while offset < len(typeindices):
                t = typeindices[offset]
                # eval max error - mse will downplay a single atom of many being off
                maxerr = float(torch.square(agrid[t] - mgrid.values[t]).max())
                if maxerr > tol:
                    goodcoords = False
                    ch = mgrid.channels[t]
                    newcoords = select_atom_starts(mgrid, mgrid.values[t],
                                                   ch.atomic_radius)
                    for (i, coord) in enumerate(newcoords):
                        coords[i + offset] = torch.tensor(coord,
                                                          dtype=torch.float)
                offset += tcnts[t]
        if goodcoords:
            break

    numfixes = 0
    fix_iter = 0
    if not goodcoords:
        # try to fix up an atom at a time
        offset = 0
        # reset coords to best found so far
        with torch.no_grad():
            coords[:] = best_coords
            agrid = gridder.forward(coords, types, radii)
        t = 0
        while offset < len(typeindices):
            t = typeindices[offset]
            maxerr = float(torch.square(agrid[t] - mgrid.values[t]).max())
            per_atom_volume = float(radii[offset])**3 * ((2 * np.pi)**1.5)
            while maxerr > tol:
                # identify the atom of this type closest to the place with too much density
                # and move it to the location with too little density
                tcoords = coords[offset:offset + tcnts[t]].detach().cpu(
                ).numpy()  # coordinates for this type

                diff = agrid[t] - mgrid.values[t]
                possum = float(diff[diff > 0].sum())
                negsum = float(diff[diff < 0].sum())
                maxdiff = float(diff.max())
                mindiff = float(diff.min())
                missing_density = -(negsum + possum)
                if missing_density > .75 * per_atom_volume:  # add atom
                    print("Missing density - not enough atoms?")
                    numfixes += 1
                    minpos = int((agrid[t] - mgrid.values[t]).argmin())
                    minpos = grid_to_xyz(
                        np.unravel_index(minpos, agrid[t].shape), mgrid)
                    # add atom: change coords, types, radii, typeindices and tcnts, numatoms
                    numatoms += 1
                    typeindices = np.insert(typeindices, offset, t)
                    tcnts[t] += 1
                    with torch.no_grad():
                        newcoord = torch.tensor([minpos],
                                                device=coords.device,
                                                dtype=coords.dtype,
                                                requires_grad=True)
                        coords = torch.cat(
                            (coords[:offset], newcoord, coords[offset:]))
                        radii = torch.cat(
                            (radii[:offset], radii[offset:offset + 1],
                             radii[offset:]))
                        types = torch.cat(
                            (types[:offset], types[offset:offset + 1],
                             types[offset:]))

                        coords.requires_grad_(True)
                        radii.requires_grad_(True)
                        types.requires_grad_(True)

                elif mindiff**2 < tol:
                    print("No significant density underage - too many atoms?")
                    break
                    # TODO: remove atom
                else:  # move an atom
                    numfixes += 1
                    maxpos = int((agrid[t] - mgrid.values[t]).argmax())
                    minpos = int((agrid[t] - mgrid.values[t]).argmin())
                    maxpos = grid_to_xyz(
                        np.unravel_index(maxpos, agrid[t].shape), mgrid)
                    minpos = grid_to_xyz(
                        np.unravel_index(minpos, agrid[t].shape), mgrid)

                    dists = np.square(tcoords - maxpos).sum(axis=1)
                    closesti = np.argmin(dists)
                    with torch.no_grad():
                        coords[offset + closesti] = torch.tensor(minpos)

                # reoptimize (the closure reads the current coords/types/radii)
                optimizer = torch.optim.LBFGS([coords],
                                              max_iter=20000,
                                              tolerance_grad=1e-9,
                                              line_search_fn='strong_wolfe')
                # TODO: only optimize this grid
                optimizer.step(closure)
                final_loss = optimizer.state_dict()['state'][0][
                    'prev_loss']  # TODO - check for convergence?
                # recompute grid; only its values are needed, not a graph
                with torch.no_grad():
                    agrid = gridder.forward(coords, types, radii)

                # if maxerr hasn't improved, give up
                newerr = float(torch.square(agrid[t] - mgrid.values[t]).max())
                fix_iter += 1
                print(
                    'fix_iter {} (loss={}, n_atoms={}, newerr={}, numfixes={})'
                    .format(fix_iter, final_loss, len(typeindices), newerr,
                            numfixes))

                if newerr >= maxerr:
                    break
                else:
                    maxerr = newerr
                    best_loss = final_loss
                    best_coords = coords.detach().clone()
                    best_typeindices = typeindices.copy()

                # otherwise update coordinates and repeat

            offset += tcnts[t]

    # mtr22 - match the output API of generate.AtomFitter.fit
    n_atoms = len(best_typeindices)
    n_channels = len(mgrid.channels)
    best_types = torch.zeros((n_atoms, n_channels),
                             dtype=torch.float32,
                             device=device)
    best_radii = torch.zeros((n_atoms, ), dtype=torch.float32, device=device)
    for i, t in enumerate(best_typeindices):
        ch = mgrid.channels[t]
        best_types[i, t] = 1.0
        best_radii[i] = ch.atomic_radius

    # create struct and grid from coordinates
    # NOTE(review): type_diff compares estimated vs. fitted per-type counts
    # and est_type_diff compares true vs. estimated counts; confirm this
    # pairing matches generate.AtomFitter.fit.
    struct_best = generate.MolStruct(
        xyz=best_coords.cpu().numpy(),
        c=best_typeindices,
        channels=mgrid.channels,
        loss=float(best_loss),
        type_diff=(types_est - best_types.sum(dim=0)).abs().sum().item(),
        est_type_diff=(types_true - types_est).abs().sum().item(),
        time=time.time() - t_start,
        n_steps=numfixes,
    )

    grid_pred = generate.MolGrid(
        values=gridder.forward(best_coords, best_types,
                               best_radii).cpu().detach().numpy(),
        channels=mgrid.channels,
        center=mgrid.center,
        resolution=mgrid.resolution,
        visited_structs=[],
        src_struct=struct_best,
    )

    return grid_pred