Example #1
0
    def prepare_trajectory(self, trajectory):
        """Prepare a trajectory for distance calculations based on the contact map.

        Each frame in the trajectory will be represented by a vector where
        each entry represents the distance between two residues in the
        structure. Depending on which contacts you pick, this can be a
        'native biased' picture or not.

        The residue pairs come from ``self.contacts``: either the string
        ``'all'`` (every pair of residues more than two positions apart in
        sequence) or an array of shape (n_contacts, 2) holding zero-based
        residue indices. How each pair is measured is chosen by
        ``self.scheme``:

        - ``'ca'``: the CA atom of each residue is passed to
          ``_contactcalc.atom_distances``.
        - ``'closest'``: all atoms of each residue are passed to
          ``_contactcalc.residue_distances``.
        - ``'closest-heavy'``: as ``'closest'``, but only atoms whose name,
          after stripping leading digits, does not start with 'H'.

        Parameters
        ----------
        trajectory : msmbuilder.Trajectory
            The trajectory to prepare. Its 'XYZList', 'AtomNames' and
            'ResidueID' (one-based) entries are read, together with
            GetNumberOfResidues() and GetNumberOfAtoms().

        Returns
        -------
        pairwise_distances : ndarray
            float64 array of residue-residue distances, one vector per frame
            (presumably shape (n_frames, n_contacts); the exact layout is
            determined by ``_contactcalc`` -- TODO confirm).

        Raises
        ------
        ValueError
            If ``self.contacts`` is not an (n, 2) array of valid zero-based
            residue indices, or if ``self.scheme`` is not recognised.
        """
        xyzlist = trajectory['XYZList']
        num_residues = trajectory.GetNumberOfResidues()
        num_atoms = trajectory.GetNumberOfAtoms()

        # Test for the string explicitly: comparing an ndarray to 'all' is an
        # elementwise comparison in modern NumPy, and using that array in an
        # `if` raises "truth value is ambiguous".
        if isinstance(self.contacts, str) and self.contacts == 'all':
            # Every pair (a, b), a < b, more than two residues apart. The
            # reshape keeps the shape (0, 2) when there are < 4 residues.
            contacts = np.array(
                [(a, b)
                 for a, b in itertools.combinations(range(num_residues), 2)
                 if b > a + 2],
                dtype=np.int32).reshape(-1, 2)
        else:
            contacts = np.asarray(self.contacts)
            if contacts.ndim != 2 or contacts.shape[1] != 2:
                raise ValueError('contacts must be width 2')
            # An empty map, or any index outside [0, num_residues), means the
            # caller did not supply zero-based residue indices.
            if (len(contacts) == 0
                    or not np.all((0 <= contacts) & (contacts < num_residues))):
                raise ValueError(
                    'contacts should refer to zero-based indexing of the residues'
                )

        atom_names = trajectory['AtomNames']
        residue_ids = trajectory['ResidueID']

        if self.scheme == 'ca':
            # Map each (zero-based) residue to the index of its CA atom.
            # ResidueID is one-based, hence the "- 1".
            # NOTE(review): not all residues have a CA; such residues keep the
            # default index 0, so their contacts silently measure distances to
            # atom 0 -- confirm this is intended.
            residue_to_alpha = np.zeros(num_residues, dtype=np.int32)
            for i in range(num_atoms):
                if atom_names[i] == 'CA':
                    residue_to_alpha[residue_ids[i] - 1] = i
            atom_contacts = residue_to_alpha[contacts]
            output = _contactcalc.atom_distances(xyzlist, atom_contacts)

        elif self.scheme in ('closest', 'closest-heavy'):
            if self.scheme == 'closest':
                # All atom indices belonging to each residue.
                residue_membership = [np.where(residue_ids == i + 1)[0]
                                      for i in range(num_residues)]
            else:
                # Heavy atoms only: leading digits are stripped so that names
                # like '1HB' are still recognised as hydrogens.
                residue_membership = [[] for _ in range(num_residues)]
                for i in range(num_atoms):
                    if not atom_names[i].lstrip('0123456789').startswith('H'):
                        residue_membership[residue_ids[i] - 1].append(i)
            output = _contactcalc.residue_distances(xyzlist,
                                                    residue_membership,
                                                    contacts)
        else:
            raise ValueError('This is not supposed to happen!')

        return np.double(output)