示例#1
0
    def __getitem__(self, ndx):
        """Return the stored intra-molecular coordinates and energy indices.

        Args:
            ndx: Index into ``self.elems``.

        Returns:
            A 3-tuple ``(elems, energy_ndx_inp, energy_ndx_out)``.
            ``elems`` is ``(inp_positions_intra, out_positions_intra)`` exactly
            as stored in the sample (no rotation applied). Each
            ``energy_ndx_*`` is the 5-tuple of the sample's ``*_bond_ndx``,
            ``*_ang_ndx``, ``*_dih_ndx``, ``*_lj_intra_ndx`` and ``*_lj_ndx``
            entries.
        """
        # A random rotation is drawn only when augmentation is enabled and the
        # dataset is in training mode; otherwise the identity is used.
        R = (rand_rot_mtx(self.data.align)
             if self.rand_rot and self.train
             else np.eye(3, dtype=np.float32))

        sample = self.elems[ndx]

        # NOTE(review): the voxel features / target assembled below are not
        # part of the return value, and the intra input coordinates are
        # voxelized without applying R while the inter ones are rotated.
        # Confirm whether this is intentional before relying on it.
        xyz_inp_intra = sample['inp_positions_intra']
        blobs_inp_intra = voxelize_gauss(xyz_inp_intra, self.sigma_inp, self.grid)
        # Weight every per-row grid by its feature vector, then collapse axis 0.
        weights_intra = sample['inp_intra_featvec'][:, :, None, None, None]
        features = np.sum(weights_intra * blobs_inp_intra[:, None, :, :, :], 0)

        inp_coords = xyz_inp_intra
        xyz_out_inter = np.zeros((1, 3))
        if self.n_env_mols:
            xyz_inp_inter = np.dot(sample['inp_positions_inter'], R.T)
            blobs_inp_inter = voxelize_gauss(xyz_inp_inter, self.sigma_inp, self.grid)
            weights_inter = sample['inp_inter_featvec'][:, :, None, None, None]
            env_inp = np.sum(weights_inter * blobs_inp_inter[:, None, :, :, :], 0)
            features = np.concatenate((features, env_inp), 0)

            inp_coords = np.concatenate((xyz_inp_intra, xyz_inp_inter), 0)

            if self.out_env:
                xyz_out_inter = np.dot(sample['out_positions_inter'], R.T)
                blobs_out_inter = voxelize_gauss(xyz_out_inter, self.sigma_out, self.grid)
                weights_out = sample['out_inter_featvec'][:, :, None, None, None]
                env_out = np.sum(weights_out * blobs_out_inter[:, None, :, :, :], 0)
                features = np.concatenate((features, env_out), 0)

        xyz_out_intra = sample['out_positions_intra']
        target = voxelize_gauss(np.dot(sample['out_positions_intra'], R.T),
                                self.sigma_out, self.grid)

        inp_keys = ('inp_bond_ndx', 'inp_ang_ndx', 'inp_dih_ndx',
                    'inp_lj_intra_ndx', 'inp_lj_ndx')
        out_keys = ('out_bond_ndx', 'out_ang_ndx', 'out_dih_ndx',
                    'out_lj_intra_ndx', 'out_lj_ndx')
        energy_ndx_inp = tuple(sample[key] for key in inp_keys)
        energy_ndx_out = tuple(sample[key] for key in out_keys)

        return (xyz_inp_intra, xyz_out_intra), energy_ndx_inp, energy_ndx_out
示例#2
0
    def __getitem__(self, ndx):
        """Assemble the paired output/input sample stored at position ``ndx``.

        Args:
            ndx: Index into ``self.elems_out`` and ``self.elems_inp``.

        Returns:
            A 4-tuple ``(elems, initial, energy_ndx_out, energy_ndx_inp)``:

            * ``elems``: voxelized targets followed by the ``featvec`` and
              ``repl`` entries of the output and of the input record.
            * ``initial``: voxelized ``positions`` of the output and of the
              input record.
            * ``energy_ndx_out`` / ``energy_ndx_inp``: the ``bond_ndx``,
              ``angle_ndx``, ``dih_ndx`` and ``lj_ndx`` entries of each record.
        """
        # Augment with a random rotation only while training.
        if self.rand_rot and self.train:
            rot = rand_rot_mtx(self.data.align)
        else:
            rot = np.eye(3, dtype=np.float32)

        def to_grid(positions):
            # Rotate the coordinates, then hand them to voxelize_gauss.
            return voxelize_gauss(np.dot(positions, rot.T), self.sigma, self.grid)

        rec_out = self.elems_out[ndx]
        rec_inp = self.elems_inp[ndx]

        target_grid = to_grid(rec_out['targets'])
        grid_out = to_grid(rec_out['positions'])
        grid_inp = to_grid(rec_inp['positions'])

        elems = (
            target_grid,
            rec_out['featvec'],
            rec_out['repl'],
            rec_inp['featvec'],
            rec_inp['repl'],
        )
        initial = (grid_out, grid_inp)

        energy_keys = ('bond_ndx', 'angle_ndx', 'dih_ndx', 'lj_ndx')
        energy_ndx_out = tuple(rec_out[key] for key in energy_keys)
        energy_ndx_inp = tuple(rec_inp[key] for key in energy_keys)

        return elems, initial, energy_ndx_out, energy_ndx_inp
示例#3
0
    def __getitem__(self, ndx):
        """Assemble the training sample stored at position ``ndx``.

        Args:
            ndx: Index into ``self.elems``.

        Returns:
            A 3-tuple ``(elems, initial, energy_ndx)``:

            * ``elems``: the voxelized ``target_pos`` followed by the sample's
              ``target_type``, ``repl``, ``mask`` and the five ``*_atom``
              index entries.
            * ``initial``: ``(atom_grid, cg_features)``, i.e. the voxelized
              ``aa_pos`` and the feature-weighted sum of the ``cg_pos`` grids.
            * ``energy_ndx``: the ``bonds_ndx``, ``angles_ndx``, ``dihs_ndx``
              and ``ljs_ndx`` entries.
        """
        # Unlike a plain identity, the rotation is drawn whenever rand_rot is
        # set; this method does not consult a train flag.
        R = rand_rot_mtx(self.data.align) if self.rand_rot else np.eye(3)

        sample = self.elems[ndx]

        def rotated_grid(key):
            # Rotate the stored coordinates and pass them to voxelize_gauss.
            return voxelize_gauss(np.dot(sample[key], R.T), self.sigma, self.grid)

        target_atom = rotated_grid('target_pos')
        atom_grid = rotated_grid('aa_pos')
        bead_grid = rotated_grid('cg_pos')

        # (N_beads, N_chn, 1, 1, 1) * (N_beads, 1, N_x, N_y, N_z),
        # then summed over the leading (bead) axis.
        bead_weights = sample['cg_feat'][:, :, None, None, None]
        cg_features = np.sum(bead_weights * bead_grid[:, None, :, :, :], 0)

        passthrough_keys = (
            'target_type',
            'repl',
            'mask',
            'bonds_ndx_atom',
            'angles_ndx1_atom',
            'angles_ndx2_atom',
            'dihs_ndx_atom',
            'ljs_ndx_atom',
        )
        elems = (target_atom,) + tuple(sample[key] for key in passthrough_keys)
        initial = (atom_grid, cg_features)
        energy_ndx = tuple(
            sample[key]
            for key in ('bonds_ndx', 'angles_ndx', 'dihs_ndx', 'ljs_ndx')
        )

        return elems, initial, energy_ndx
示例#4
0
    def __getitem__(self, ndx):
        """Assemble the sample stored at position ``ndx``.

        Args:
            ndx: Index into ``self.elems``.

        Returns:
            A 3-tuple ``(elems, initial, energy_ndx)``:

            * ``elems``: ``(targets_grid, featvec, repl)``.
            * ``initial``: the voxelized ``positions`` grid itself. It is NOT
              wrapped in a tuple; callers receive the bare grid.
            * ``energy_ndx``: the ``bond_ndx``, ``angle_ndx``, ``dih_ndx`` and
              ``lj_ndx`` entries.
        """
        # Augment with a random rotation only while training.
        augment = self.rand_rot and self.train
        rot = rand_rot_mtx(self.data.align) if augment else np.eye(3, dtype=np.float32)

        record = self.elems[ndx]

        target_grid = voxelize_gauss(
            np.dot(record['targets'], rot.T), self.sigma, self.grid)
        position_grid = voxelize_gauss(
            np.dot(record['positions'], rot.T), self.sigma, self.grid)

        elems = (target_grid, record['featvec'], record['repl'])
        # Kept as a plain value (not a 1-tuple) to match what callers get today.
        initial = position_grid
        energy_ndx = tuple(
            record[key]
            for key in ('bond_ndx', 'angle_ndx', 'dih_ndx', 'lj_ndx')
        )

        return elems, initial, energy_ndx
示例#5
0
    def __getitem__(self, ndx):
        """Build generator input, critic input and coordinates for sample ``ndx``.

        All coordinate sets are multiplied by the transpose of the same
        rotation matrix ``R`` before use. ``R`` is random when both
        ``self.rand_rot`` and ``self.train`` are truthy and the identity
        otherwise.

        Args:
            ndx: Index into ``self.elems``.

        Returns:
            A 3-tuple ``(elems, energy_ndx_inp, energy_ndx_out)``.

            ``elems`` is the 7-tuple
            ``(gen_input, crit_input_real, inp_features_inter,
            inp_coords_intra, inp_coords, inp_coords_inter,
            out_coords_inter)`` where

            * ``gen_input`` concatenates the intra and inter input feature
              grids along axis 0,
            * ``crit_input_real`` concatenates the voxelized
              ``out_positions_intra`` with the inter output feature grid
              along axis 0,
            * ``inp_coords`` concatenates the rotated intra and inter input
              coordinates along axis 0.

            ``energy_ndx_inp`` / ``energy_ndx_out`` are the 4-tuples of the
            sample's ``*_bond_ndx``, ``*_ang_ndx``, ``*_dih_ndx`` and
            ``*_lj_ndx`` entries.
        """
        if self.rand_rot and self.train:
            R = rand_rot_mtx(self.data.align)
        else:
            R = np.eye(3, dtype=np.float32)

        d = self.elems[ndx]

        # Input, intra part: per-row grids weighted by their feature vector
        # and summed over axis 0.
        inp_coords_intra = np.dot(d['inp_positions_intra'], R.T)
        inp_blobbs_intra = voxelize_gauss(inp_coords_intra, self.sigma_inp,
                                          self.grid)
        inp_features_intra = (
            d['inp_intra_featvec'][:, :, None, None, None]
            * inp_blobbs_intra[:, None, :, :, :])
        gen_input = np.sum(inp_features_intra, 0)

        # Output, intra part: used un-weighted as the target grid.
        target = voxelize_gauss(
            np.dot(d['out_positions_intra'], R.T), self.sigma_out, self.grid)

        # Input, inter part: appended to the generator input.
        inp_coords_inter = np.dot(d['inp_positions_inter'], R.T)
        inp_blobbs_inter = voxelize_gauss(inp_coords_inter, self.sigma_inp,
                                          self.grid)
        inp_features_inter = (
            d['inp_inter_featvec'][:, :, None, None, None]
            * inp_blobbs_inter[:, None, :, :, :])
        inp_features_inter = np.sum(inp_features_inter, 0)
        gen_input = np.concatenate((gen_input, inp_features_inter), 0)

        inp_coords = np.concatenate((inp_coords_intra, inp_coords_inter), 0)

        # Output, inter part: appended to the target for the critic.
        out_coords_inter = np.dot(d['out_positions_inter'], R.T)
        out_blobbs_inter = voxelize_gauss(out_coords_inter, self.sigma_out,
                                          self.grid)
        out_features_inter = (
            d['out_inter_featvec'][:, :, None, None, None]
            * out_blobbs_inter[:, None, :, :, :])
        out_features_inter = np.sum(out_features_inter, 0)
        crit_input_real = np.concatenate((target, out_features_inter), 0)

        energy_ndx_inp = (d['inp_bond_ndx'], d['inp_ang_ndx'],
                          d['inp_dih_ndx'], d['inp_lj_ndx'])
        energy_ndx_out = (d['out_bond_ndx'], d['out_ang_ndx'],
                          d['out_dih_ndx'], d['out_lj_ndx'])

        elems = (gen_input, crit_input_real, inp_features_inter,
                 inp_coords_intra, inp_coords, inp_coords_inter,
                 out_coords_inter)

        return elems, energy_ndx_inp, energy_ndx_out
示例#6
0
文件: gan2.py 项目: mstieffe/deepMap
    def __getitem__(self, ndx):
        """Build feature grid, target grid and coordinates for sample ``ndx``.

        All coordinate sets are multiplied by the transpose of the same
        rotation matrix ``R`` before use. ``R`` is random when both
        ``self.rand_rot`` and ``self.train`` are truthy and the identity
        otherwise.

        Args:
            ndx: Index into ``self.elems``.

        Returns:
            A 3-tuple ``(elems, energy_ndx_aa, energy_ndx_cg)``.

            ``elems`` is ``(features, target, aa_coords_intra, aa_coords)``.
            When ``self.n_env_mols`` is truthy, ``features`` and ``aa_coords``
            additionally contain the inter ("aa_*_inter") part concatenated
            along axis 0, and ``target`` gets one extra leading-axis entry
            holding the sum of the ``cg_positions_inter`` grids. Otherwise
            only the intra parts are returned.

            ``energy_ndx_aa`` / ``energy_ndx_cg`` are the 5-tuples of the
            sample's ``*_bond_ndx``, ``*_ang_ndx``, ``*_dih_ndx``,
            ``*_lj_intra_ndx`` and ``*_lj_ndx`` entries.
        """
        if self.rand_rot and self.train:
            R = rand_rot_mtx(self.data.align)
        else:
            R = np.eye(3, dtype=np.float32)

        d = self.elems[ndx]

        # Intra "aa" part: per-row grids weighted by their feature vector and
        # summed over axis 0.
        aa_coords_intra = np.dot(d['aa_positions_intra'], R.T)
        aa_blobbs_intra = voxelize_gauss(aa_coords_intra, self.sigma_aa,
                                         self.grid)
        aa_features_intra = (
            d['aa_intra_featvec'][:, :, None, None, None]
            * aa_blobbs_intra[:, None, :, :, :])
        aa_features_intra = np.sum(aa_features_intra, 0)

        # Intra "cg" part: used un-weighted as (part of) the target.
        cg_positions_intra = voxelize_gauss(
            np.dot(d['cg_positions_intra'], R.T), self.sigma_cg, self.grid)

        if self.n_env_mols:
            # Inter "aa" part is appended to both the features and the
            # coordinate list.
            aa_coords_inter = np.dot(d['aa_positions_inter'], R.T)
            aa_coords = np.concatenate((aa_coords_intra, aa_coords_inter), 0)
            aa_blobbs_inter = voxelize_gauss(aa_coords_inter, self.sigma_aa,
                                             self.grid)
            aa_features_inter = (
                d['aa_inter_featvec'][:, :, None, None, None]
                * aa_blobbs_inter[:, None, :, :, :])
            aa_features_inter = np.sum(aa_features_inter, 0)
            features = np.concatenate((aa_features_intra, aa_features_inter),
                                      0)

            # Inter "cg" grids are collapsed into a single leading-axis entry
            # (keepdims=True) so they can be concatenated onto the target.
            cg_positions_inter = voxelize_gauss(
                np.dot(d['cg_positions_inter'], R.T), self.sigma_cg, self.grid)
            cg_features_inter = np.sum(cg_positions_inter, 0, keepdims=True)
            target = np.concatenate((cg_positions_intra, cg_features_inter), 0)
        else:
            features = aa_features_intra
            aa_coords = aa_coords_intra
            target = cg_positions_intra

        energy_ndx_aa = (d['aa_bond_ndx'], d['aa_ang_ndx'], d['aa_dih_ndx'],
                         d['aa_lj_intra_ndx'], d['aa_lj_ndx'])
        energy_ndx_cg = (d['cg_bond_ndx'], d['cg_ang_ndx'], d['cg_dih_ndx'],
                         d['cg_lj_intra_ndx'], d['cg_lj_ndx'])

        elems = (features, target, aa_coords_intra, aa_coords)

        return elems, energy_ndx_aa, energy_ndx_cg