コード例 #1
0
ファイル: remove_sample.py プロジェクト: connor-lab/snapper3
def check_cluster_integrity(cur,
                            sample_id,
                            snad,
                            distances,
                            levels=[0, 2, 5, 10, 25, 50, 100, 250]):
    """
    Check whether removing sample_id from its clusters forces any of them to split.

    For every (cluster, level) pair obtained by zipping snad with levels, the
    samples within the level threshold of the removee are collected. Any two of
    them that are further apart than the threshold only stay together if a
    chain of remaining cluster members with steps <= threshold still links
    them; pairs without such a chain are reported.

    Parameters
    ----------
    cur: obj
        database cursor
    sample_id: int
        id of sample to remove
    snad: list of int
        snp address: the cluster id of the sample on each level, in the same
        order as levels (the two are zipped, so surplus items in the longer
        one are ignored)
    distances: dict
        distance cache, read as distances[a][b]; every distance looked up
        here is handed to remember_distance() together with this dict
    levels: list of int
        cluster thresholds, better not change this.
        The default list is only read here, never modified.

    Returns
    -------
    splits: dict
        one entry per processed level:
        splits[level] = None if no split is required on that level, else
        splits[level] = [(c, a, b), ...] <- no longer connected pair in cluster c
    """

    splits = {}

    for clu, lvl in zip(snad, levels):

        t_lvl = 't%i' % (lvl)

        logging.info("Checking cluster integrity for cluster %s on level %s.",
                     clu, t_lvl)

        # get all other members of the cluster apart from the removee
        mems = get_all_cluster_members(cur, clu, t_lvl)
        mems.remove(sample_id)

        # get distances of the removee to them, remember every one of them and
        # keep the samples that are connected to the removee with d <= t
        connected_mems = []
        for (sa, di) in get_distances(cur, sample_id, mems):
            if di <= lvl:
                connected_mems.append(sa)
            remember_distance(distances, sample_id, sa, di)

        logging.debug("Samples connected via removee: %s",
                      sorted(connected_mems))

        # investigate all pw distances between connected members (each
        # unordered pair once)
        potentially_broken_pairs = []
        for i, a in enumerate(connected_mems):
            for b in connected_mems[i + 1:]:
                try:
                    pwd = distances[a][b]
                except KeyError:
                    # not cached yet: fetch from the database and remember it
                    pwd = get_all_pw_dists(cur, [a, b])[0]
                    remember_distance(distances, a, b, pwd)
                # if the pw distance between the two samples is bigger than the
                # threshold, the link between them might be broken, unless there
                # is another sample (or chain of samples) connecting them
                if pwd > lvl:
                    potentially_broken_pairs.append((a, b))

        # all pairs that were connected through the removee are also directly
        # connected, happy days
        if not potentially_broken_pairs:
            splits[lvl] = None
            continue

        logging.debug(
            "Samples potentially no longer connected via removee: %s",
            potentially_broken_pairs)

        # check if there is another path to get from a to b with only steps <= t
        for a, b in potentially_broken_pairs:
            broken = False
            logging.debug(
                "Checking if there is another way to connect %s and %s.", a, b)
            # list of samples connectable to a (directly or over multiple
            # nodes); it grows while it is being walked with idx
            rel_conn_sams_to_a = [a]
            idx = 0
            # as soon as b shows up in the list we are done
            while b not in rel_conn_sams_to_a:
                # pivot is the one currently investigated
                pivot = rel_conn_sams_to_a[idx]
                # get all the members of the current cluster except the pivot
                all_mems_but_pivot = [x for x in mems if x != pivot]
                # get all the distances from the pivot to those members
                d = get_distances_from_memory(cur, distances, pivot,
                                              all_mems_but_pivot)
                # all new samples that are connectable to the pivot
                # two conditions: a) the sample is connected to the pivot with d<=t
                #         b) we don't have this sample yet in the ones we
                #            already know are connected to a
                rel_conn_to_pivot = [
                    sa for (sa, di) in d
                    if (di <= lvl) and (sa not in rel_conn_sams_to_a)
                ]
                # there are no new samples connected to the pivot and the last
                # sample has been considered but b is not yet found to be
                # connected => cluster is broken
                if not rel_conn_to_pivot and pivot == rel_conn_sams_to_a[-1]:
                    broken = True
                    break
                # otherwise add any potential new ones to the list and check
                # the next one
                rel_conn_sams_to_a += rel_conn_to_pivot
                idx += 1
            # we need to remember what was broken for updating later; carry on
            # with the next pair, there might be more than one broken pair
            if broken:
                splits.setdefault(lvl, []).append((clu, a, b))

        # we checked all pairs and always found b somehow, cluster is fine
        # (dict.has_key() no longer exists in Python 3, 'in' works everywhere)
        if lvl not in splits:
            splits[lvl] = None

    return splits
コード例 #2
0
def get_stats_for_merge(cur, oMerge):
    """
    Build the stats for the cluster that results from merging several clusters.

    The cluster with the most members (ignoring samples flagged ignore_zscore)
    becomes the merge target and its id is stored in oMerge.final_name.
    either: the target has more than one member according to cluster_stats:
            start from its stored stats and add the members of the other
            cluster(s) one at a time
    or: (the target has at most one member) get all pw distances between all
        members of all clusters to merge and create the stats object from that

    Parameters
    ----------
    cur: obj
        database cursor
    oMerge: obj
        ClusterMerge object; org_clusters and t_level are read,
        final_name, stats and final_members are set

    Returns
    -------
    final_members: list of ints
        list of members of the merged cluster (also stored in
        oMerge.final_members)
    (None, None) if there is not exactly one cluster_stats row for the
        biggest cluster
    """

    level = oMerge.t_level

    # the level column name is spliced into the statement, the cluster id is a
    # bound parameter; samples that fail zscore are left out
    member_sql = "SELECT c.fk_sample_id FROM sample_clusters c, samples s WHERE c." + level + "=%s AND s.pk_id=c.fk_sample_id AND s.ignore_zscore IS FALSE;"

    # members[clu_id] = [list of sample ids]
    members = {}
    for cluster_id in oMerge.org_clusters:
        cur.execute(member_sql, (cluster_id, ))
        members[cluster_id] = [rec['fk_sample_id'] for rec in cur.fetchall()]

    # cluster ids ordered by member count, largest cluster first
    by_size = sorted(members, key=lambda cid: len(members[cid]), reverse=True)
    biggest = by_size[0]
    oMerge.final_name = biggest

    # get stats for the biggest cluster from database
    stats_sql = "SELECT nof_members, nof_pairwise_dists, mean_pwise_dist, stddev FROM cluster_stats WHERE cluster_level=%s AND cluster_name=%s"
    cur.execute(stats_sql, (level, biggest))
    if cur.rowcount != 1:
        return None, None
    stats_row = cur.fetchone()

    if row_has_many_members(stats_row) if False else stats_row['nof_members'] > 1:

        # make a stats obj from what the db holds for the biggest cluster
        oMerge.stats = ClusterStats(members=stats_row['nof_members'],
                                    stddev=stats_row['stddev'],
                                    mean=stats_row['mean_pwise_dist'])

        # the member list of the biggest cluster grows as samples are added
        merged = members[biggest]

        # every sample of the smaller clusters has to be added to it
        to_add = [
            sample for cluster_id in by_size[1:]
            for sample in members[cluster_id]
        ]

        # feed the distances of each newcomer to the samples already in the
        # merged cluster into the stats object, then count it as a member
        for newcomer in to_add:
            oMerge.stats.add_member(
                [dist for (_sam, dist) in get_distances(cur, newcomer, merged)])
            merged.append(newcomer)

    else:
        # if the biggest cluster has only one member, get all members of all
        # clusters to be merged and all pw distances - shouldn't be many
        merged = []
        for sample_ids in members.values():
            merged.extend(sample_ids)

        oMerge.stats = ClusterStats(members=len(merged),
                                    dists=get_all_pw_dists(cur, merged))

    oMerge.final_members = merged

    return oMerge.final_members
コード例 #3
0
ファイル: remove_sample.py プロジェクト: connor-lab/snapper3
def update_cluster_stats_post_removal(cur, sid, clu, lvl, distances, split,
                                      zscr_flag):
    """
    Update the cluster stats and the sample stats for removing the sample from the cluster.

    If the cluster has to be split, the largest subcluster keeps the cluster
    name; every other subcluster gets a new cluster name, its own cluster_stats
    row and freshly calculated per-sample means.

    Parameters
    ----------
    cur: obj
        database cursor
    sid: int
        pk id of sample to remove
    clu: int
        cluster id to remove from
    lvl: int
        cluster level of removal
    distances: dict
        distance cache, read as distances[sid][member] and passed on to
        get_distances_from_memory() and split_clusters()
    split: list of tuples or None
        [(c, a, b), ...] <- no longer connected pair in cluster c
        None if the cluster does not need splitting on this level
    zscr_flag: bool
        True if the sample to remove is a known outlier that was never part
        of the stats calculation

    Returns
    -------
    0 if fine
    None if fail
    """

    t_lvl = "t%i" % (lvl)

    logging.info("Updating stats for cluster %s on level %s.", clu, t_lvl)

    # get the cluster stats from the database
    sql = "SELECT nof_members, mean_pwise_dist, stddev FROM cluster_stats WHERE cluster_level=%s AND cluster_name=%s"
    cur.execute(sql, (
        t_lvl,
        clu,
    ))
    if cur.rowcount == 0:
        if zscr_flag == True:
            logging.info(
                "Sample is a known outlier and the only member of level %s cluster %s. So there are no stats to update.",
                t_lvl, clu)
            return 0
        else:
            logging.error(
                "Cluster stats for level %s and cluster %s not found.", t_lvl,
                clu)
            return None
    row = cur.fetchone()

    # when deleting the last member of this cluster there is no need to update anything, just get rid of it
    if row['nof_members'] <= 1:
        logging.debug(
            "This is the last member of cluster %s on level %s. Deleting cluster stats.",
            clu, t_lvl)
        sql = "DELETE FROM cluster_stats WHERE cluster_level=%s AND cluster_name=%s"
        cur.execute(sql, (
            t_lvl,
            clu,
        ))
        return 0

    # create cluster stats object from the information in the database
    oStats = ClusterStats(members=row['nof_members'],
                          stddev=row['stddev'],
                          mean=row['mean_pwise_dist'])

    # get the members of this cluster that take part in the stats (known
    # outliers are filtered out by the query), we know it must be at least one
    sql = "SELECT c.fk_sample_id AS samid FROM sample_clusters c, samples s WHERE c." + t_lvl + "=%s AND c.fk_sample_id=s.pk_id AND s.ignore_zscore IS FALSE"
    cur.execute(sql, (clu, ))
    members = [r['samid'] for r in cur.fetchall()]
    logging.debug("Got the follwoing members: %s", members)
    try:
        logging.debug("Removing %s from list: %s", sid, members)
        members.remove(sid)
    except ValueError:
        if zscr_flag == True:
            logging.debug(
                "Could not find %s as a member of %s cluster %s, but it's OK because it's a known outlier.",
                sid, t_lvl, clu)
        else:
            logging.error(
                "Bizzare data inconsistency for sample id %s and %s cluster %s.",
                sid, t_lvl, clu)
            return None

    # if the sample was previously ignored do not remove it from the stats
    # object, because it was never considered when calculating the stats
    if zscr_flag == False:
        # get all distances from the sample to be removed to all other members of the cluster
        # and update the stats object with this information
        this_di = [distances[sid][m] for m in members]
        assert oStats.members == (len(this_di) + 1)
        oStats.remove_member(this_di)
        # remember which members have been removed from the stats object
        removed_members = [sid]
    else:
        removed_members = []

    # if there is a split on this level
    if split is not None:
        logging.info("Cluster %s need to be split.", clu)
        groups = split_clusters(cur, sid, split, lvl, distances)
        logging.debug("It will be split into these subclusters: %s", groups)

        # groups[a] = [1,2,3]
        # groups[b] = [4,5,6]

        knwntlrs = set()

        # put the largest subcluster at the front of the list of subclusters
        group_lists = sorted(groups.values(), key=len, reverse=True)
        logging.debug("These are the group lists: %s", group_lists)
        # for every group except the largest one: take its members out of the
        # original cluster, which keeps the largest group
        for grli in group_lists[1:]:
            # for all members of this group
            logging.debug("Current group list: %s", grli)
            for m in grli:
                # remove from members, from stats object and remember that you removed it in that list
                logging.debug("Removing %s from %s", m, members)
                if m in members:
                    members.remove(m)
                    this_di = [
                        d for (s, d) in get_distances_from_memory(
                            cur, distances, m, members)
                    ]
                    assert oStats.members == (len(this_di) + 1)
                    # i.e. turn the oStats object into the stats object for the largest cluster after the split
                    oStats.remove_member(this_di)
                    removed_members.append(m)
                else:
                    # not in the stats-relevant member list => remember it so
                    # it is left out of the new subcluster below
                    knwntlrs.add(m)
                    logging.debug(
                        "Could not remove %s from members, but it's probably a kown outlier.",
                        m)

        # for the other subclusters: create them as new clusters
        for grli in group_lists[1:]:
            # make a new stats object based on the list of members and all pw distances between them
            # remove known outliers previously encountered from consideration
            # NOTE(review): if a subcluster consists only of known outliers,
            # grli is empty from here on - confirm the helpers and the
            # "IN %s" update below cope with that
            grli = list(set(grli).difference(knwntlrs))
            all_pw_grdi = get_all_pw_dists(cur, grli)
            oStatsTwo = ClusterStats(members=len(grli), dists=all_pw_grdi)
            # the new cluster name is the highest name in use on this level + 1
            sql = "SELECT max(" + t_lvl + ") AS m FROM sample_clusters"
            cur.execute(sql)
            row = cur.fetchone()
            new_clu_name = row['m'] + 1
            sql = "INSERT INTO cluster_stats (cluster_level, cluster_name, nof_members, nof_pairwise_dists, mean_pwise_dist, stddev) VALUES (%s, %s, %s, %s, %s, %s)"
            cur.execute(sql, (
                t_lvl,
                new_clu_name,
                oStatsTwo.members,
                oStatsTwo.nof_pw_dists,
                oStatsTwo.mean_pw_dist,
                oStatsTwo.stddev_pw_dist,
            ))

            # document the upcoming change in the sample history
            update_sample_history(cur, t_lvl, new_clu_name, grli)

            # put all members of this subcluster in the new cluster in the database
            sql = "UPDATE sample_clusters SET " + t_lvl + "=%s WHERE fk_sample_id IN %s"
            cur.execute(sql, (
                new_clu_name,
                tuple(grli),
            ))

            # calculate the mean distance to all other members from scratch for all members of this newly
            # created subcluster and update in the database
            for nm in grli:
                targets = [x for x in grli if x != nm]
                alldis = [
                    di for (sa, di) in get_distances_from_memory(
                        cur, distances, nm, targets)
                ]
                try:
                    mean = sum(alldis) / float(len(alldis))
                except ZeroDivisionError:
                    # single member subcluster, there is no distance to average
                    mean = None
                sql = "UPDATE sample_clusters SET " + t_lvl + "_mean=%s WHERE fk_sample_id=%s"
                cur.execute(sql, (
                    mean,
                    nm,
                ))

    # then update the cluster stats in the database with the info from the object
    sql = "UPDATE cluster_stats SET (nof_members, nof_pairwise_dists, mean_pwise_dist, stddev) = (%s, %s, %s, %s) WHERE cluster_level=%s AND cluster_name=%s"
    cur.execute(sql, (
        oStats.members,
        oStats.nof_pw_dists,
        oStats.mean_pw_dist,
        oStats.stddev_pw_dist,
        t_lvl,
        clu,
    ))

    # for all other members of this cluster
    for mem in members:

        # get the mean distance to all other members
        sql = "SELECT " + t_lvl + "_mean FROM sample_clusters WHERE fk_sample_id=%s"
        cur.execute(sql, (mem, ))
        if cur.rowcount == 0:
            logging.error("Sample %s not found in sample_clusters table.",
                          mem)
            return None
        row = cur.fetchone()
        p_mean = row[t_lvl + '_mean']
        n_mean = p_mean

        # update this mean by removing one distance at a time and update the database table
        # if there was no split removed_members will have only one member in it
        # else there are more
        # If there was no split on this level and the sample was previously ignored,
        # it will be empty
        #
        # number of samples the current mean is taken over: the members that
        # stay (all of 'members' but mem itself) plus the ones still to be
        # taken out; it shrinks by one with every removed distance.
        # assumes the stored mean covered exactly these samples
        n_others = len(members) - 1 + len(removed_members)
        for remomem in removed_members:
            x = get_distances_from_memory(cur, distances, mem, [remomem])[0][1]
            try:
                n_mean = ((p_mean * n_others) - x) / float(n_others - 1)
            except ZeroDivisionError:
                # mem is the only sample left, there is nothing to average
                n_mean = None
            p_mean = n_mean
            n_others -= 1

        sql = "UPDATE sample_clusters SET " + t_lvl + "_mean=%s WHERE fk_sample_id=%s"
        cur.execute(sql, (
            n_mean,
            mem,
        ))

    return 0