def test_res_contacts(self):
    """Parsing only 'sb' contacts (presumably salt bridges) yields two residue-level contacts."""
    atomic, _n_frames = ct.parse_contacts(self.input_lines, {'sb'})
    residue_level = ct.res_contacts(atomic)

    self.assertEqual(len(residue_level), 2)
    first, second = residue_level[0], residue_level[1]
    # Each residue contact starts with its frame index, followed by the two residue keys
    self.assertEqual(first[0], 0)
    self.assertEqual(second[0], 1)
    self.assertEqual(second[1], "A:ARG:76")
    self.assertEqual(second[2], "A:GLU:82")
Example #2
0
def create_flare(contacts, resi_labels, resi_colors):
    """
    Creates a flare from a contact-list and residue labels. If `resi_labels` isn't `None` then the "trees" and "tracks"
    will be generated as well.

    Parameters
    ----------
    contacts : list of list
        Each entry specifies a frame-number, an interaction type, and 2 to 4 atom-tuples depending on the interaction
        type. Water mediated and water-water mediated interactions will have waters in the third and fourth tuples.
        The atomic contacts are collapsed to residue level with `res_contacts` before any edges are built.

    resi_labels : dict of (str : str) or None
        Maps a residue key (the residue identifiers found in the `res_contacts` output) to a dot-separated tree path.
        The last path component is used as the node name. When this dict is non-empty, contacts involving a residue
        that has no label are omitted (a message is printed for each one).

    resi_colors : dict of (str : str)
        Maps every residue key in `resi_labels` to the color of its node in the default track. Only read when
        `resi_labels` is not `None`.

    Returns
    -------
    dict of str : list
        The types of the list contents varies depending on the key, but the format corresponds to the specifications of
        jsons used as input for flareplot. Edge "frames" are sorted and free of duplicates. For example:
            {
              "edges": [
                {"name1": "ARG1", "name2": "ALA3", "frames": [0,4,10]},
                {"name1": "ALA3", "name2": "THR2", "frames": [1,2]}
              ],
              "trees": [
                {"treeLabel": "DefaultTree", "treePaths": ["Group1.ARG1", "Group1.THR2", "Group2.ALA3", "Group2.CYS4"]}
              ],
              "tracks": [
                {"trackLabel": "DefaultTrack", "trackProperties": [
                  {"nodeName": "ARG1", "color": "red", "size": 1.0},
                  {"nodeName": "THR2", "color": "red", "size": 1.0},
                  {"nodeName": "ALA3", "color": "blue", "size": 1.0},
                  {"nodeName": "CYS4", "color": "blue", "size": 1.0}
                ]}
              ]
            }
    """
    ret = {
        "edges": []
    }

    # Collapse atomic contacts to residue contacts: entry[0] is the frame, entry[1] and entry[2] the residue keys.
    rcontacts = res_contacts(contacts)

    resi_edges = {}
    for contact in rcontacts:
        a1_key = contact[1]
        a2_key = contact[2]
        # A tuple key cannot collide the way plain string concatenation could
        # (e.g. "A:ARG:7" + "6A:GLU:82" == "A:ARG:76" + "A:GLU:82").
        # NOTE(review): the key is order-sensitive; this relies on res_contacts emitting each residue pair in a
        # canonical order — confirm.
        contact_key = (a1_key, a2_key)

        # Look up labels
        if resi_labels:
            if a1_key not in resi_labels or a2_key not in resi_labels:
                print("create_flare: Omitting contact "+str(contact)+" as it doesn't appear in flare-label file")
                continue
            a1_label = resi_labels[a1_key].split('.')[-1]
            a2_label = resi_labels[a2_key].split('.')[-1]
        else:
            a1_label = a1_key
            a2_label = a2_key

        # Create the edge the first time this residue pair is seen
        edge = resi_edges.get(contact_key)
        if edge is None:
            edge = {"name1": a1_label, "name2": a2_label, "frames": []}
            resi_edges[contact_key] = edge
            ret["edges"].append(edge)

        edge["frames"].append(int(contact[0]))

    # Several atomic contacts of one residue pair may share a frame; keep each frame once, in order
    for edge in ret["edges"]:
        edge["frames"] = sorted(set(edge["frames"]))

    # Create "trees" and "tracks" sections if resi_labels specified
    if resi_labels is not None:
        tree = {"treeLabel": "DefaultTree", "treePaths": []}
        ret["trees"] = [tree]

        track = {"trackLabel": "DefaultTrack", "trackProperties": []}
        ret["tracks"] = [track]

        for res, path in resi_labels.items():
            tree["treePaths"].append(path)
            track["trackProperties"].append({
                "nodeName": path.split(".")[-1],
                "color": resi_colors[res],
                "size": 1.0
            })

    return ret
Example #3
0
def main(argv=None):
    """
    Command-line entry point: segment a contact trajectory into clusters with TICC and write the results.

    Reads atomic contacts and converts them to residue contacts. It then featurizes them (dimensionality reduction)
    and runs TICC. Finally it writes the per-frame cluster assignment (--tab_output) and/or one residue-frequency file
    per cluster (--frequency_output), whichever outputs were requested.

    Parameters
    ----------
    argv : list of str or None
        Command-line arguments; `None` lets argparse read them from `sys.argv`.
    """
    # Parse arguments
    parser = ap.PrintUsageParser(__doc__)
    parser.add_argument("--input_contacts",
                        type=argparse.FileType('r'),
                        required=True,
                        metavar="FILE",
                        help="Path to contact file")
    parser.add_argument("--clusters",
                        type=int,
                        required=False,
                        nargs="+",
                        default=[2, 5, 10],
                        metavar="INT",
                        help="Number of clusters [default: 2 5 10]")
    parser.add_argument("--tab_output",
                        type=str,
                        required=False,
                        metavar="FILE",
                        help="Path to TICC output file (tab-separated time/cluster indicators)")
    parser.add_argument("--frequency_output",
                        type=str,
                        required=False,
                        metavar="FILE",
                        help="Prefix to TICC output files (one res-frequency file for each cluster)")
    parser.add_argument("--beta",
                        type=int,
                        required=False,
                        nargs="+",
                        default=[10, 50, 100],
                        metavar="INT",
                        help="Beta parameter [default: 10 50 100]")
    parser.add_argument("--max_dimension",
                        type=int,
                        required=False,
                        default=50,
                        metavar="INT",
                        help="Max number of dimensions [default: 50]")
    args = parser.parse_args(argv)

    # At least one output must be requested; otherwise the clustering result would be discarded
    if args.tab_output is None and args.frequency_output is None:
        parser.error("--tab_output or --frequency_output must be specified")

    print("Reading atomic contacts from " + args.input_contacts.name)
    # `with` closes the input even if parse_contacts raises
    with args.input_contacts as contacts_file:
        atomic_contacts, _ = parse_contacts(contacts_file)

    print("Converting atomic contacts to residue contacts")
    residue_contacts = res_contacts(atomic_contacts)

    print("Performing dimensionality reduction")
    time_matrix = featurize_contacts(residue_contacts, args.max_dimension)

    print("Running TICC (clustered time-segmentation)")
    segmentation = run_ticc(time_matrix, cluster_number=args.clusters, beta=args.beta)
    # Per-frame cluster assignments, one value per frame (converted with int() before writing)
    frame_clusters = segmentation[0][0]

    if args.tab_output is not None:
        print("Writing time-segments to " + args.tab_output)
        with open(args.tab_output, "w") as f:
            f.writelines(str(int(label)) + "\n" for label in frame_clusters)

    if args.frequency_output is not None:
        k = segmentation[0][2][2]
        for c in range(k):
            cluster_frames = {frame for frame, cluster in enumerate(frame_clusters) if cluster == c}
            cluster_contacts = [contact for contact in residue_contacts if contact[0] in cluster_frames]
            # NOTE(review): an empty cluster gives 0 frames; confirm gen_frequencies tolerates that.
            cluster_num_frames = len(cluster_frames)

            counts = gen_counts(cluster_contacts)
            total_frames, frequencies = gen_frequencies([(cluster_num_frames, counts)])

            fname = "%s_resfreq_cluster%03d.tsv" % (args.frequency_output, c)
            print("Writing frequency-flare to " + fname)
            with open(fname, "w") as output_file:
                output_file.write('#\ttotal_frames:%d\tinteraction_types:all\n' % total_frames)
                # The header lists exactly the three columns written per row below (no frame count is emitted)
                output_file.write('#\tColumns:\tresidue_1\tresidue_2\tcontact_frequency\n')
                for (res1, res2), (_count, frequency) in frequencies.items():
                    output_file.write('\t'.join([res1, res2, "%.3f" % frequency]) + "\n")