Пример #1
0
    def setUp(self):
        self.p2c = PDB2CoordsUnordered()
        pdb_text = """ATOM      0    N GLN A   0       0.000   0.000   0.000
ATOM      1   CA GLN A   0       0.526   0.000  -1.362
ATOM      2   CB GLN A   0       1.349  -1.258  -1.588
ATOM      3   CG GLN A   0       1.488  -1.511  -3.080
ATOM      4   CD GLN A   0       2.281  -0.380  -3.716
ATOM      5  OE1 GLN A   0       3.338  -0.604  -4.304
ATOM      6  NE2 GLN A   0       1.770   0.842  -3.597
ATOM      7    C GLN A   0      -0.471   0.000  -2.516
ATOM      8    O GLN A   0      -1.468  -1.003  -2.607
ATOM      9    N THR A   1      -0.434   0.934  -3.463
ATOM     10   CA THR A   1      -1.494   0.728  -4.446
ATOM     11   CB THR A   1      -2.120   2.065  -4.806
ATOM     12  OG1 THR A   1      -1.960   2.961  -3.719
ATOM     13  CG2 THR A   1      -1.286   2.785  -5.863
ATOM     14    C THR A   1      -1.115   0.124  -5.794
ATOM     15    O THR A   1      -0.118   0.730  -6.599
ATOM     16    N ALA A   2      -1.703  -0.979  -6.250
ATOM     17   CA ALA A   2      -1.162  -1.365  -7.550
ATOM     18   CB ALA A   2      -1.839  -2.641  -8.022
ATOM     19    C ALA A   2      -1.340  -0.399  -8.717
ATOM     20    O ALA A   2      -2.433   0.502  -8.744
ATOM     21    N ALA A   3      -0.482  -0.375  -9.733
ATOM     22   CA ALA A   3      -0.858   0.638 -10.714
ATOM     23   CB ALA A   3      -0.125   0.377 -12.020
ATOM     24    C ALA A   3      -2.324   0.730 -11.127
ATOM     25    O ALA A   3      -3.026   1.956 -11.020"		
"""
        with open("test.pdb", "w") as fout:
            fout.write(pdb_text)
	def __init__(self, box_size=80, resolution=1.5):
		"""Store the grid geometry and build the coordinate helpers.

		Args:
			box_size: box edge in grid units.
			resolution: grid spacing; the box edge length in coordinate
				units is ``box_size * resolution``.
		"""
		self.box_size = box_size
		self.resolution = resolution
		self.box_length = box_size*resolution

		# Rotation and translation helpers for coordinate tensors.
		self.rotate = CoordsTransform.CoordsRotate()
		self.translate = CoordsTransform.CoordsTranslate()

		# Box centre: a (1, 3) double CPU tensor whose three components all
		# equal half of the box length.
		half_length = self.box_length/2.0
		self.box_center = torch.full((1, 3), half_length, dtype=torch.double, device='cpu')

		self.pdb2coords = PDB2CoordsUnordered()
Пример #3
0
    def __init__(self,
                 model,
                 loss,
                 lr=0.001,
                 lr_decay=0.0001,
                 box_size=120,
                 resolution=1.0,
                 add_neg=False,
                 neg_weight=0.5,
                 add_zero=False,
                 zero_weight=1.0,
                 randomize_rot=True):
        """Set up the model, optimiser, learning-rate schedule and helpers.

        Args:
            model: network whose ``parameters()`` are optimised with Adam.
            loss: loss object, stored as ``self.loss``.
            lr: initial Adam learning rate.
            lr_decay: inverse-time decay coefficient. ``LambdaLR`` scales the
                base rate by ``1 / (1 + epoch * self.lr_decay)``. The value is
                read from ``self.lr_decay`` at every scheduler step.
            box_size: grid size passed to ``TypedCoords2Volume``.
            resolution: grid resolution passed to ``TypedCoords2Volume``;
                ``box_length = box_size * resolution``.
            add_neg: flag for the negative-score condition.
            neg_weight: weight for the negative-score condition.
            add_zero: flag for the zero-score condition.
            zero_weight: weight for the zero-score condition.
            randomize_rot: stored as ``self.randomize_rot`` (consumed
                elsewhere).
        """
        self.model = model
        self.loss = loss
        self.lr = lr
        # Must be set before LambdaLR is built: the scheduler evaluates the
        # decay function once during construction.
        self.lr_decay = lr_decay
        self.log = None
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        def inverse_time_decay(epoch):
            return 1.0 / (1.0 + epoch * self.lr_decay)

        self.lr_scheduler = LambdaLR(self.optimizer, inverse_time_decay)

        # Zero-score and negative-score training conditions.
        self.add_zero, self.zero_weight = add_zero, zero_weight
        self.add_neg, self.neg_weight = add_neg, neg_weight

        # Grid geometry.
        self.box_length = box_size * resolution
        self.box_size, self.resolution = box_size, resolution

        # Structure parsing, atom typing, rigid transforms and volume
        # projection helpers.
        self.pdb2coords = PDB2CoordsUnordered()
        self.assignTypes = Coords2TypedCoords()
        self.translate = CoordsTranslate()
        self.rotate = CoordsRotate()
        self.project = TypedCoords2Volume(self.box_size, self.resolution)

        self.relu = nn.ReLU()
        self.randomize_rot = randomize_rot

        # Run self.cleanup at normal interpreter shutdown.
        atexit.register(self.cleanup)
Пример #4
0
    def __init__(self,
                 docking_model,
                 angle_inc=15.0,
                 box_size=80,
                 resolution=1.25,
                 max_conf=1000,
                 randomize_rot=False):
        """Configure the docking search and its coordinate/volume helpers.

        Args:
            docking_model: model stored as ``self.docking_model``.
            angle_inc: angular step passed to ``Rotations``.
            box_size: grid size passed to ``TypedCoords2Volume``.
            resolution: grid resolution; ``box_length = box_size * resolution``.
            max_conf: stored as ``self.max_conf``. Presumably the number of
                top conformations kept; confirm against callers.
            randomize_rot: when true, one random rotation is drawn into
                ``self.randR`` and printed.
        """
        self.docking_model, self.log = docking_model, None

        # Grid geometry.
        self.box_size, self.resolution = box_size, resolution
        self.box_length = box_size * resolution

        self.max_conf = max_conf
        # Rotation set built with step ``angle_inc`` (see Rotations).
        self.rot = Rotations(angle_inc=angle_inc)

        self.rotate = CoordsTransform.CoordsRotate()
        self.translate = CoordsTransform.CoordsTranslate()
        self.project = TypedCoords2Volume(self.box_size, self.resolution)
        self.convolve = VolumeConvolution()

        # Box centre: a (1, 3) double CPU tensor filled with half the box length.
        self.box_center = torch.full((1, 3), self.box_length / 2.0,
                                     dtype=torch.double, device='cpu')

        self.pdb2coords = PDB2CoordsUnordered()
        self.assignTypes = Coords2TypedCoords()
        # A second translation helper is kept alongside self.translate; other
        # code may reference either attribute.
        self.translation = CoordsTranslate()
        self.vol_rotate = VolumeRotation()

        self.randomize_rot = randomize_rot
        if randomize_rot:
            self.randR = getRandomRotation(1)
            print("Adding random rotation to the receptor:", self.randR)

        # Run self.cleanup at normal interpreter shutdown.
        atexit.register(self.cleanup)
	def save_clusters(self, complex_filename, receptor_path, ligand_path, num_clusters=10):
		"""Write one receptor+ligand PDB file per cluster representative.

		Conformations in ``self.top_list`` are scanned in order. The first
		one whose ``self.cluster`` id equals the next expected id (0, 1, 2,
		...) is applied to the ligand: its rotation is applied first, then its
		translation. The receptor and the placed ligand are then written to
		``<complex_filename>_<id>.pdb``.

		Args:
			complex_filename: output path prefix.
			receptor_path: receptor structure path, passed to ``self.load_protein``.
			ligand_path: ligand structure path, passed to ``self.load_protein``.
			num_clusters: currently unused; every cluster id present in
				``self.cluster`` is written.

		Raises:
			Exception: if a cluster id appears before all lower ids have
				appeared.
		"""
		# Bound the scan by the conformations actually stored, so indexing
		# self.top_list cannot go out of range.
		num_conf = len(self.top_list)

		lcoords, lchainnames, lresnames, lresnums, latomnames, lnum_atoms = self.load_protein([ligand_path], True)
		rcoords, rchainnames, rresnames, rresnums, ratomnames, rnum_atoms = self.load_protein([receptor_path], True)

		cluster_num = 0
		for i in range(num_conf):
			if self.cluster[i] < cluster_num:
				# Cluster already written (or no cluster assigned).
				continue
			if self.cluster[i] != cluster_num:
				raise Exception("Wrong cluster number")

			r, t, _score = self.top_list[i]
			lcoords_rot = self.rotate(lcoords, r, lnum_atoms)
			lcoords_rot_trans = self.translate(lcoords_rot, t, lnum_atoms)
			output_filename = complex_filename + '_%d.pdb' % cluster_num
			# Receptor is written with rewrite=True, the ligand with
			# rewrite=False (presumably overwrite, then append; confirm in
			# writePDB).
			writePDB(output_filename, rcoords, rchainnames, rresnames, rresnums, ratomnames, rnum_atoms, add_model=False, rewrite=True)
			writePDB(output_filename, lcoords_rot_trans, lchainnames, lresnames, lresnums, latomnames, lnum_atoms, add_model=False, rewrite=False)

			cluster_num += 1
Пример #6
0
	def __init__(self, decoys_dir):
		"""Remember the decoy directory and create the coordinate helpers.

		Args:
			decoys_dir: directory path, stored as ``self.decoys_dir``.
		"""
		# Structure parser plus rotation/translation helpers.
		self.p2c = PDB2CoordsUnordered()
		self.rotate, self.translate = CoordsRotate(), CoordsTranslate()
		self.decoys_dir = decoys_dir
Пример #7
0
def get_irmsd(benchmark,
              parser,
              num_conf=1000,
              debug_plot=None):
    """Compute the interface RMSD of the first ``num_conf`` decoys per target.

    For every benchmark target:
    - The unbound receptor and ligand are reduced to CA atoms and cut down to
      the interface residues from ``benchmark.get_unbound_interfaces``.
    - The matching bound interface complexes are built from the bound
      structures.
    - Each decoy ``i`` moves the unbound ligand interface with
      ``parser.transform_ligand(..., i)``. The decoy's value is the minimum
      RMSD over all (unbound interface, bound interface) pairs.

    Args:
        benchmark: source of target names, structures and interface residues.
        parser: source of problematic targets, decoy output, protein loading
            and ligand transforms.
        num_conf: number of decoys evaluated per target.
        debug_plot: optional dict with keys ``"targets"`` (target names to
            plot), ``"num_decoys"`` (decoys plotted per target) and ``"dir"``
            (directory for the PNG files). Missing keys default to ``[]``,
            ``5`` and ``"Fig"``; ``None`` disables plotting.

    Returns:
        dict mapping each processed target name to a list of ``num_conf``
        minimum RMSD values. Skipped targets are absent.
    """
    # Merge caller options over the defaults (avoids a shared mutable default
    # argument and accepts partial dicts).
    plot_opts = {"targets": [], "num_decoys": 5, "dir": "Fig"}
    if debug_plot:
        plot_opts.update(debug_plot)

    result = {}
    p2c = PDB2CoordsUnordered()
    c2rmsd = Coords2RMSD()
    skip = parser.get_problematic_targets()

    for n, target_name in enumerate(benchmark.get_target_names()):
        if target_name in skip:
            print("Skipping prediction for", target_name, n)
            continue
        print("Processing prediction for", target_name, n)
        target = benchmark.get_target(target_name)

        bound_target, unbound_target = benchmark.parse_structures(target)
        interfaces = benchmark.get_unbound_interfaces(bound_target,
                                                      unbound_target)

        res = parser.parse_output(target_name, header_only=False)
        if res is None:
            print("Skipping prediction for", target_name, n)
            continue

        # The trailing-underscore structures keep every atom (used for
        # plotting). The plain names are reduced to CA atoms.
        unbound_receptor = parser.load_protein(
            [unbound_target["receptor"]["path"]])
        unbound_receptor_ = ProteinStructure(*unbound_receptor)
        unbound_receptor = ProteinStructure(*unbound_receptor)
        unbound_receptor.set(*unbound_receptor_.select_CA())

        unbound_ligand = parser.load_protein(
            [unbound_target["ligand"]["path"]])
        unbound_ligand_ = ProteinStructure(*unbound_ligand)
        unbound_ligand = ProteinStructure(*unbound_ligand)
        unbound_ligand.set(*unbound_ligand.select_CA())

        # Unbound interfaces; the ligand part is moved by each decoy below.
        unbound_interfaces = []
        for urec_sel, ulig_sel, _brec_sel, _blig_sel in interfaces:
            rec = ProteinStructure(
                *unbound_receptor.select_residues_list(urec_sel))
            lig = ProteinStructure(
                *unbound_ligand.select_residues_list(ulig_sel))
            unbound_interfaces.append((rec, lig))

        bound_receptor = ProteinStructure(
            *p2c([bound_target["receptor"]["path"]]))
        bound_receptor.set(*bound_receptor.select_CA())
        bound_ligand = ProteinStructure(*p2c([bound_target["ligand"]["path"]]))
        bound_ligand.set(*bound_ligand.select_CA())

        # Bound reference interfaces; these stay fixed.
        bound_interfaces = []
        for _urec_sel, _ulig_sel, brec_sel, blig_sel in interfaces:
            rec = ProteinStructure(
                *bound_receptor.select_residues_list(brec_sel))
            lig = ProteinStructure(
                *bound_ligand.select_residues_list(blig_sel))
            bound_interfaces.append(unite_proteins(rec, lig))

        if not bound_interfaces:
            # Without interfaces there is no RMSD to minimise; min() would
            # raise on the empty list.
            print("Skipping prediction (no interfaces) for", target_name, n)
            continue

        result[target_name] = []
        Nplotted = 0
        for i in range(num_conf):
            plot_this = (target_name in plot_opts["targets"]
                         and Nplotted < plot_opts["num_decoys"])

            # Plot the transformed full unbound structures.
            if plot_this:
                fig = plt.figure()
                # Axes3D(fig) no longer attaches itself to the figure in
                # recent Matplotlib; add_subplot(projection='3d') does.
                axis = fig.add_subplot(111, projection='3d')
                new_unbound_ligand_ = ProteinStructure(
                    *parser.transform_ligand(unbound_ligand_.get(), i))
                unbound_receptor_.plot_coords(axis,
                                              type='line',
                                              args={"color": "blue"})
                new_unbound_ligand_.plot_coords(axis,
                                                type='line',
                                                args={"color": "red"})

            all_rmsd = []
            for rec, lig in unbound_interfaces:
                new_lig = ProteinStructure(
                    *parser.transform_ligand(lig.get(), i))
                mobile_cplx = unite_proteins(rec, new_lig)
                for static_cplx in bound_interfaces:
                    all_rmsd.append(
                        c2rmsd(mobile_cplx.coords, static_cplx.coords,
                               static_cplx.numatoms).item())

                # Plot the mobile interface once per unbound interface.
                if plot_this:
                    mobile_cplx.plot_coords(axis,
                                            type='scatter',
                                            args={"color": "yellow"})

            if plot_this:
                os.makedirs(plot_opts["dir"], exist_ok=True)
                output_filename = os.path.join(
                    plot_opts["dir"], target_name + '%d.png' % Nplotted)
                fig.savefig(output_filename)
                # Release the figure; otherwise every plotted decoy leaks one.
                plt.close(fig)
                Nplotted += 1

            result[target_name].append(min(all_rmsd))

    return result
    def cluster_decoys(self,
                       ligand_path,
                       num_clusters=10,
                       cluster_threshold=15.0):
        """Greedily cluster the top conformations by ligand CA RMSD.

        Conformations are visited in ``self.top_list`` order. A value
        ``self.cluster[i] > -1`` marks conformation ``i`` as already assigned.
        ``self.cluster`` is assumed to be initialised by the caller, with
        ``-1`` for unassigned conformations.

        The first unassigned conformation becomes the focus of a new cluster.
        Every other unassigned conformation whose ligand CA RMSD to the focus
        is below ``cluster_threshold`` joins that cluster. At most
        ``num_clusters`` clusters are created; conformations left over keep
        their previous id.

        Args:
            ligand_path: ligand structure file. It is first translated by
                ``-(a + b) / 2``, where ``a, b`` come from ``getBBox``.
            num_clusters: maximum number of clusters to create.
            cluster_threshold: RMSD cut-off, in the coordinate units.
        """
        pdb2coords = PDB2CoordsUnordered()
        num_conf = len(self.top_list)

        lcoords, lchainnames, lresnames, lresnums, latomnames, lnum_atoms = pdb2coords(
            [ligand_path])
        a, b = getBBox(lcoords, lnum_atoms)
        lcoords = self.translate(lcoords, -(a + b) * 0.5, lnum_atoms)

        # Select CA atoms. Atom names are compared as character codes:
        # 'C' (67), 'A' (65), then a 0 terminator.
        N = lnum_atoms[0].item()
        is0C = torch.eq(latomnames[:, :, 0], 67).squeeze()
        is1A = torch.eq(latomnames[:, :, 1], 65).squeeze()
        is20 = torch.eq(latomnames[:, :, 2], 0).squeeze()
        isCA = is0C * is1A * is20
        num_ca_atoms = isCA.sum().item()
        num_atoms_single = torch.zeros(1, dtype=torch.int,
                                       device='cpu').fill_(num_ca_atoms)

        lcoords.resize_(1, N, 3)
        ca_x = torch.masked_select(lcoords[:, :, 0], isCA)[:num_ca_atoms]
        ca_y = torch.masked_select(lcoords[:, :, 1], isCA)[:num_ca_atoms]
        ca_z = torch.masked_select(lcoords[:, :, 2], isCA)[:num_ca_atoms]
        ca_coords = torch.stack([ca_x, ca_y, ca_z],
                                dim=1).resize_(1,
                                               num_ca_atoms * 3).contiguous()

        def _translation(ix, iy, iz):
            # Grid shift of a conformation in coordinate units. Indices at or
            # beyond box_size wrap to the negative shift index - 2 * box_size.
            t = torch.zeros(1, 3, dtype=torch.double, device='cpu')
            for axis, idx in enumerate((ix, iy, iz)):
                t[0, axis] = idx
                if idx >= self.box_size:
                    t[0, axis] = -(2 * self.box_size - idx)
            return t * self.resolution

        placed_ca = {}

        def _placed_ca(k):
            # Ligand CA coordinates under conformation k, computed once per k.
            # The same rotation table (self.rot.R) is used for every
            # conformation.
            if k not in placed_ca:
                ind, ix, iy, iz, _score = self.top_list[k]
                r = self.rot.R[ind, :, :].unsqueeze(dim=0)
                ca_rot = self.rotate(ca_coords, r, num_atoms_single)
                placed_ca[k] = self.translate(ca_rot, _translation(ix, iy, iz),
                                              num_atoms_single)
            return placed_ca[k]

        cluster_num = 0
        for i in range(num_conf):
            if self.cluster[i] > -1:
                continue
            self.cluster[i] = cluster_num
            print("Found %d cluster focus %d" % (cluster_num, i))

            ca_rot_trans_i = _placed_ca(i)
            for j in range(num_conf):
                if self.cluster[j] > -1:
                    continue
                diff = ca_rot_trans_i - _placed_ca(j)
                lrmsd = torch.sqrt((diff * diff).sum() / num_ca_atoms).item()
                if lrmsd < cluster_threshold:
                    self.cluster[j] = cluster_num

            cluster_num += 1
            # '>=' caps the result at num_clusters clusters.
            if cluster_num >= num_clusters:
                break
    def save_clusters(self,
                      complex_filename,
                      receptor_path,
                      ligand_path,
                      num_clusters=10):
        """Write one receptor+ligand PDB file per cluster focus.

        Receptor and ligand are each translated by ``-(a + b) / 2``, where
        ``a, b`` come from ``getBBox``. Conformations in ``self.top_list``
        are then scanned in order. The first one whose ``self.cluster`` id
        equals the next expected id (0, 1, 2, ...) places the ligand:
        - rotation ``self.rot.R[ind]``;
        - then the wrapped grid shift scaled by ``self.resolution``.
        The receptor and the placed ligand are written to
        ``<complex_filename>_<id>.pdb``.

        Args:
            complex_filename: output path prefix.
            receptor_path: receptor structure file.
            ligand_path: ligand structure file.
            num_clusters: currently unused; every cluster id present in
                ``self.cluster`` is written.

        Raises:
            Exception: if a cluster id appears before all lower ids have
                appeared.
        """
        pdb2coords = PDB2CoordsUnordered()
        num_conf = len(self.top_list)

        lcoords, lchainnames, lresnames, lresnums, latomnames, lnum_atoms = pdb2coords(
            [ligand_path])
        a, b = getBBox(lcoords, lnum_atoms)
        lcoords = self.translate(lcoords, -(a + b) * 0.5, lnum_atoms)

        rcoords, rchainnames, rresnames, rresnums, ratomnames, rnum_atoms = pdb2coords(
            [receptor_path])
        a, b = getBBox(rcoords, rnum_atoms)
        rcoords = self.translate(rcoords, -(a + b) * 0.5, rnum_atoms)

        cluster_num = 0
        for i in range(num_conf):
            if self.cluster[i] < cluster_num:
                # Cluster already written (or no cluster assigned).
                continue
            if self.cluster[i] != cluster_num:
                raise Exception("Wrong cluster number")

            ind, ix, iy, iz, score = self.top_list[i]
            # Rotation matrix taken from the Rotations table (self.rot).
            r = self.rot.R[ind, :, :].unsqueeze(dim=0)
            lcoords_rot = self.rotate(lcoords, r, lnum_atoms)

            # Indices at or beyond box_size wrap to the negative shift
            # index - 2 * box_size.
            t = torch.zeros(1, 3, dtype=torch.double, device='cpu')
            for axis, idx in enumerate((ix, iy, iz)):
                t[0, axis] = idx
                if idx >= self.box_size:
                    t[0, axis] = -(2 * self.box_size - idx)
            shift = t * self.resolution
            print(ind, shift, score)
            lcoords_rot_trans = self.translate(lcoords_rot, shift, lnum_atoms)

            output_filename = complex_filename + '_%d.pdb' % cluster_num
            # Receptor is written with rewrite=True, the ligand with
            # rewrite=False (presumably overwrite, then append; confirm in
            # writePDB).
            writePDB(output_filename,
                     rcoords,
                     rchainnames,
                     rresnames,
                     rresnums,
                     ratomnames,
                     rnum_atoms,
                     add_model=False,
                     rewrite=True)
            writePDB(output_filename,
                     lcoords_rot_trans,
                     lchainnames,
                     lresnames,
                     lresnums,
                     latomnames,
                     lnum_atoms,
                     add_model=False,
                     rewrite=False)

            cluster_num += 1
	# Concatenate the residue-name, residue-number and atom-name tensors along
	# dim 1 (the l* tensors first, then the r* tensors), and add the atom counts.
	cres_names = torch.cat([lres_names, rres_names], dim=1).contiguous()
	cres_nums = torch.cat([lres_nums, rres_nums], dim=1).contiguous()
	catom_names = torch.cat([latom_names, ratom_names], dim=1).contiguous()
	cnum_atoms = lnum_atoms + rnum_atoms

	# NOTE(review): `complex` shadows the Python builtin of the same name.
	complex = ProteinStructure(ccoords, cchains, cres_names, cres_nums, catom_names, cnum_atoms)
	return complex


if __name__=='__main__':
	benchmark_dir = os.path.join(DATA_DIR, "DockingBenchmarkV4")
	benchmark_list = os.path.join(benchmark_dir, "TableS1.csv")
	benchmark_structures = os.path.join(benchmark_dir, "structures")
	benchmark = DockingBenchmark(benchmark_dir, benchmark_list, benchmark_structures)

	p2c = PDB2CoordsUnordered()
	target_name = '1A2K'
	target = benchmark.get_target(target_name)
	bound_complex, unbound_complex = benchmark.parse_structures(target)
		
	print('Bound receptor:', target["complex"]["chain_rec"], ''.join(list(bound_complex["receptor"]["chains"].keys())))
	print('Bound ligand:', target["complex"]["chain_lig"], ''.join(list(bound_complex["ligand"]["chains"].keys())))
	print('Unbound receptor:', target["receptor"]["chain"], ''.join(list(unbound_complex["receptor"]["chains"].keys())))
	print('Unbound ligand:', target["ligand"]["chain"], ''.join(list(unbound_complex["ligand"]["chains"].keys())))
	
	bound_receptor = ProteinStructure(*p2c([bound_complex["receptor"]["path"]]))
	bound_receptor.set(*bound_receptor.select_CA())
	bound_ligand = ProteinStructure(*p2c([bound_complex["ligand"]["path"]]))
	bound_ligand.set(*bound_ligand.select_CA())

	fig = plt.figure()