コード例 #1
0
    def calculate_per_member_stats(self, cur):
        """
        Compute, for every member of the newly merged cluster, the mean
        distance to all other members, and store the results in
        self.member_stats keyed by sample id.

        Parameters
        ----------
        self: obj
            that's me
        cur: obj
            database cursor

        Returns
        -------
        0
        """

        self.member_stats = {}

        logging.info(
            "Calculating mean distance of all members of merging cluster %s on level %s.",
            self.final_name, self.t_level)

        for member in self.final_members:
            # everyone in the merged cluster except the current member
            rest = [other for other in self.final_members if other != member]
            pairwise = get_distances(cur, member, rest)
            values = [dist for (_sid, dist) in pairwise]
            self.member_stats[member] = sum(values) / float(len(values))

        return 0
コード例 #2
0
def get_mean_distance_for_merged_cluster(cur, samid, mems):
    """
    Get the mean distance of a sample (samid) to all other samples in the
    mems list. The sample itself must be a member of the cluster.

    Parameters
    ----------
    cur: obj
        database cursor
    samid: int
        sample_id
    mems: list of int
        all members of this cluster (must include samid)

    Returns
    -------
    m: float
        mean distance
    """

    assert samid in mems
    # everyone in the cluster except the sample itself
    neighbours = [mem for mem in mems if mem != samid]
    pairwise = get_distances(cur, samid, neighbours)
    values = [dist for (_sid, dist) in pairwise]
    # every neighbour must have produced exactly one distance
    assert len(values) == len(neighbours)
    return sum(values) / float(len(values))
コード例 #3
0
    def get_samples_below_threshold(self, sam_name, dis, levels=[0, 2, 5, 10, 25, 50, 100, 250]):
        """
        Get all samples that are below or equal to a given distance from the query sample.

        Parameters
        ----------
        sam_name: str
            query sample name
        dis: int
            distance threshold
        levels: list of ints
            default: [0, 2, 5, 10, 25, 50, 100, 250]
            better don't change it

        Returns
        -------
        result_samples: list of tuples
            sorted [(sample_name, distance), (sample_name, distance), (sample_name, distance), ..]

        Raises
        ------
        SnapperDBInterrogationError
            if no clustering information is found for the query sample
        """

        # get the snp address of the query sample
        sql = "SELECT s.pk_id, c.t0, c.t2, c.t5, c.t10, c.t25, c.t50, c.t100, c.t250 FROM sample_clusters c, samples s WHERE s.pk_id=c.fk_sample_id AND s.sample_name=%s"
        self.cur.execute(sql,(sam_name, ))
        if self.cur.rowcount < 1:
            raise SnapperDBInterrogationError("No clustering information found for sample %s" % (sam_name))
        row = self.cur.fetchone()
        snad = [row['t0'], row['t2'], row['t5'], row['t10'], row['t25'], row['t50'], row['t100'], row['t250']]
        samid = row['pk_id']

        # smallest clustering level that covers dis, presumably None when
        # dis exceeds the largest level -- TODO confirm against
        # get_closest_threshold
        ct = get_closest_threshold(dis)
        if ct != None:
            # selected distance <250 -> use only samples in associated cluster for calculation
            t_ct = 't%i' % (ct)
            # the query sample's cluster name at that level
            cluster = snad[levels.index(ct)]
            sql = "SELECT s.sample_name AS samname, c.fk_sample_id AS samid FROM sample_clusters c, samples s WHERE c."+t_ct+"=%s AND s.pk_id=c.fk_sample_id"
            self.cur.execute(sql, (cluster, ))
        else:
            # selected distance >250 -> use all samples that have been clustered and are not ignored for calculation
            sql = "SELECT s.sample_name AS samname, c.fk_sample_id AS samid FROM sample_clusters c, samples s WHERE s.pk_id=c.fk_sample_id AND s.ignore_sample IS FALSE AND s.sample_name<>%s"
            self.cur.execute(sql, (sam_name, ))

        # map candidate sample ids to their names; the query sample itself is
        # excluded from the neighbour list
        id2name = {}
        rows = self.cur.fetchall()
        neighbours = []
        for r in rows:
            id2name[r['samid']] = r['samname']
            if r['samid'] != samid:
                neighbours.append(r['samid'])

        if len(neighbours) <= 0:
            logging.info("No samples found this close to the query sample.")
            return []
        else:
            logging.info("Calculating distances to %i samples.", len(neighbours))
            # NOTE(review): the 'sorted' claim in the docstring relies on
            # get_distances returning (sample_id, distance) tuples sorted by
            # distance - confirm against its implementation
            distances = get_distances(self.cur, samid, neighbours)
            result_samples = [(id2name[s], d) for (s, d) in distances if d <= dis]
            if len(result_samples) <= 0:
                logging.info("No samples found this close to the query sample.")
            return result_samples
コード例 #4
0
def get_tree_samples_set(cur, t5_name):
    """
    Find out which samples need to go into the tree for this t5 cluster.

    Parameters
    ----------
    cur: obj
        database cursor
    t5_name: int
        the name of the t5 cluster

    Returns
    -------
    sample_set: set
        it's in the name
    t50_size: int
        nof members in the t50 cluster
    """

    logging.debug("Processing t5 cluster %s", t5_name)

    # get cluster members WITHOUT 'known outlier'
    sql = "select c.fk_sample_id AS samid, c.t50 AS tfifty FROM sample_clusters c, samples s WHERE c.t5=%s AND c.fk_sample_id=s.pk_id AND s.ignore_zscore=FALSE"
    cur.execute(sql, (t5_name, ))
    rows = cur.fetchall()
    t5_members = [row['samid'] for row in rows]
    t50_cluster = set(row['tfifty'] for row in rows)

    logging.debug("t5 %s has %i members.", t5_name, len(t5_members))

    # all members of the t5 are definitely in the tree
    sample_set = set(t5_members)

    # every t5 member must sit in the same t50 cluster
    assert len(t50_cluster) == 1, (
        "Why are not all members of t5 %s in the same t50?" % (t5_name))

    t50_name = t50_cluster.pop()

    logging.debug("t5 %s sits within t50 %s", t5_name, t50_name)

    sql = "select c.fk_sample_id AS samid FROM sample_clusters c, samples s WHERE c.t50=%s AND c.fk_sample_id=s.pk_id AND s.ignore_zscore=FALSE"
    cur.execute(sql, (t50_name, ))
    t50_members = set(row['samid'] for row in cur.fetchall())
    t50_size = len(t50_members)

    logging.debug("t50 %s has %i members.", t50_name, t50_size)

    # pull in every t50 member that sits within 50 of any t5 member,
    # only ever checking samples not already collected
    for member in t5_members:
        unchecked = t50_members.difference(sample_set)
        close = get_distances(cur, member, list(unchecked))
        sample_set.update(sid for (sid, dist) in close if dist <= 50)

    return sample_set, t50_size
コード例 #5
0
ファイル: remove_sample.py プロジェクト: connor-lab/snapper3
def get_distances_from_memory(cur, distances, a, targets):
    """
    Get all the distances from 'a' to the target list. Distances already
    cached in 'distances' are reused; the rest are calculated and written
    back into the cache.

    Parameters
    ----------
    cur: obj
        database cursor
    distances: dict
        distances[a][b] = d
        distances[b][a] = d
    a: int
        sample id
    targets: list of int
        list of sample ids

    Returns
    -------
    result: list of tuples
        sorted list of tuples with (sample_id, distance) with closest sample first
        e.g. [(298, 0), (37, 3), (55, 4)]
    """

    cached = []
    uncached = []
    for tgt in targets:
        if a in distances and tgt in distances[a]:
            cached.append((tgt, distances[a][tgt]))
        else:
            uncached.append(tgt)

    # calculate only what the cache could not answer, then remember it
    if len(uncached) > 0:
        fresh = get_distances(cur, a, uncached)
        for (sid, dist) in fresh:
            remember_distance(distances, a, sid, dist)
        cached += fresh

    return sorted(cached, key=lambda pair: pair[1])
コード例 #6
0
    def get_closest_samples(self, sam_name, neighbours, levels=[0, 2, 5, 10, 25, 50, 100, 250]):
        """
        Get the closest n samples.

        Parameters
        ----------
        sam_name: str
            name of the query sample
        neighbours: int
            number of neighbours; samples tied with the last returned
            distance are also included, so the result can be longer
        levels: list of int
            default: [0, 2, 5, 10, 25, 50, 100, 250]
            better don't change it

        Returns
        -------
        result_samples: list of tuples
            sorted [(sample_name, distance), (sample_name, distance), (sample_name, distance), ..]

        Raises
        ------
        SnapperDBInterrogationError
            if no clustering information is found for the query sample
        """

        # get the snp address of the query sample
        sql = "SELECT s.pk_id, c.t0, c.t2, c.t5, c.t10, c.t25, c.t50, c.t100, c.t250 FROM sample_clusters c, samples s WHERE s.pk_id=c.fk_sample_id AND s.sample_name=%s"
        self.cur.execute(sql,(sam_name, ))
        if self.cur.rowcount < 1:
            raise SnapperDBInterrogationError("No clustering information found for sample %s" % (sam_name))
        row = self.cur.fetchone()
        snad = [row['t0'], row['t2'], row['t5'], row['t10'], row['t25'], row['t50'], row['t100'], row['t250']]
        samid = row['pk_id']

        # widen the search level by level (t0, t2, t5, ...) until the cluster
        # around the query sample holds at least 'neighbours' candidates
        close_samples = set()
        id2name = {}
        for clu, lvl in zip(snad, levels):

            t_lvl = 't%i' % (lvl)
            sql = "SELECT s.sample_name, c.fk_sample_id FROM sample_clusters c, samples s WHERE c."+t_lvl+"=%s AND s.pk_id=c.fk_sample_id"
            self.cur.execute(sql, (clu, ))
            rows = self.cur.fetchall()
            for r in rows:
                id2name[r['fk_sample_id']] = r['sample_name']
                if r['fk_sample_id'] != samid:
                    close_samples.add(r['fk_sample_id'])

            logging.info("Number of samples in same %s cluster: %i.", t_lvl, len(close_samples))

            if len(close_samples) >= neighbours:
                break

        distances = None
        if len(close_samples) < neighbours:
            # even the widest cluster is too small: fall back to distances
            # against the whole database and rebuild the full name map
            distances = get_relevant_distances(self.cur, samid)
            sql = "SELECT pk_id, sample_name FROM samples"
            self.cur.execute(sql)
            id2name = {r['pk_id']: r['sample_name'] for r in self.cur.fetchall()}
        else:
            distances = get_distances(self.cur, samid, list(close_samples))
        # take the top n; presumably 'distances' is sorted ascending by
        # distance - TODO confirm in get_distances/get_relevant_distances
        result_samples = distances[:neighbours]

        # include everything tied with the n-th best distance
        # NOTE(review): if neighbours is 0 (or distances is empty while the
        # tail is not), result_samples[-1] raises IndexError - confirm callers
        # always pass neighbours >= 1
        for (sa, di) in distances[neighbours:]:
            if di == result_samples[-1][1]:
                result_samples.append((sa, di))

        result_samples = [(id2name[sa], di) for (sa, di) in result_samples if sa != samid]
        return result_samples
コード例 #7
0
def get_stats_for_merge(cur, oMerge):
    """
    Get a stats object for two (or more) clusters after they have been merged:
    either: get the biggest cluster and get the stats from the database,
            then add one member at a time from the other cluster(s)
    or: (if we're merging clusters with only one member) get all pw distances
        in the merged cluster and create stats object with that

    Parameters
    ----------
    cur: obj
        database cursor
    oMerge: obj
        ClusterMerge object

    Returns
    -------
    oStats: obj
        ClusterStats object (also stored on oMerge.stats)
    current_mems: list of ints
        list of members of the merged cluster (also stored on
        oMerge.final_members)
    (None, None) if problem
    """

    # get the members for each of the clusters to merge and put them in dict
    # members[clu_id] = [list of sample ids]
    clu_to_merge = oMerge.org_clusters
    members = {}
    t_lvl = oMerge.t_level
    for ctm in clu_to_merge:
        # get the members for all cluster that need merging, ignoring the clusters that fail zscore
        sql = "SELECT c.fk_sample_id FROM sample_clusters c, samples s WHERE c." + t_lvl + "=%s AND s.pk_id=c.fk_sample_id AND s.ignore_zscore IS FALSE;"
        cur.execute(sql, (ctm, ))
        rows = cur.fetchall()
        ctm_mems = [r['fk_sample_id'] for r in rows]
        members[ctm] = ctm_mems

    # this now has the name of the largest cluster first; the merged cluster
    # keeps that name
    clu_to_merge = sorted(members, key=lambda k: len(members[k]), reverse=True)
    oMerge.final_name = clu_to_merge[0]

    # get stats for the biggest cluster from database
    sql = "SELECT nof_members, nof_pairwise_dists, mean_pwise_dist, stddev FROM cluster_stats WHERE cluster_level=%s AND cluster_name=%s"
    cur.execute(sql, (
        t_lvl,
        clu_to_merge[0],
    ))
    if cur.rowcount != 1:
        return None, None
    row = cur.fetchone()

    if row['nof_members'] > 1:

        # make a stats obj from the stuff for the biggest cluster from the db
        oMerge.stats = ClusterStats(members=row['nof_members'],
                                    stddev=row['stddev'],
                                    mean=row['mean_pwise_dist'])

        # get members of biggest cluster
        current_mems = members[clu_to_merge[0]]
        # get all the samples that need adding to it to facilitate the merge
        new_members = []
        for ctm in clu_to_merge[1:]:
            new_members += members[ctm]

        # get all distances for new members and update stats obj iteratively
        for nm in new_members:
            dists = get_distances(cur, nm, current_mems)
            all_dists_to_new_member = [d for (s, d) in dists]
            oMerge.stats.add_member(all_dists_to_new_member)
            current_mems.append(nm)

    else:
        # if the biggest cluster has only one member, get all members of all clusters to be merged
        # and get all pw distances - shouldn't be many

        # make a flat list out of the values in members which are lists
        current_mems = [
            item for sublist in members.values() for item in sublist
        ]

        all_pw_dists = get_all_pw_dists(cur, current_mems)
        oMerge.stats = ClusterStats(members=len(current_mems),
                                    dists=all_pw_dists)

    oMerge.final_members = current_mems

    # BUGFIX: the success path previously returned only final_members while
    # the failure path returned (None, None); return the two-tuple documented
    # in the docstring so all code paths are consistent and unpackable.
    return oMerge.stats, oMerge.final_members
コード例 #8
0
def update_an_existing_tree(cur, conn, tree_row_id, t5_name, t5_members,
                            tree_sample_set, mod_time, t50_size, args):
    """
    Updates an existing tree in the database

    Parameters
    ----------
    cur: obj
        database cursor object
    conn: obj
        database connection object
    tree_row_id: int
        pk_id in the trees table
    t5_name: int
        name of the t5 cluster
    t5_members: set
        set of members for this t5 cluster
    tree_sample_set: set
        set of samples in the tree (updated in place with new members)
    mod_time: datetime.datetime
        time tree was last updated
    t50_size: int
        size of the t50 cluster last time it was updated
    args: dict
        as passed to main function

    Returns
    -------
    1 or None if fail
    """

    logging.debug("=== Checking if tree for t5 cluster %s needs updating. ===",
                  t5_name)

    t50_name = get_t50_cluster(cur, t5_name, t5_members)
    logging.debug("t5 %s sits within t50 %s. It has %i members", t5_name,
                  t50_name, len(t5_members))

    t50_members = get_members(cur, 't50', t50_name)
    logging.debug("t50 %s has %i members.", t50_name, len(t50_members))

    # if the surrounding t50 has not grown since the last update, the tree
    # cannot have gained any candidate samples
    logging.debug("t50 size at last update was %i, t50 size now is %i",
                  t50_size, len(t50_members))
    if len(t50_members) <= t50_size:
        logging.debug("Tree for t5 cluster %s does not need updating.",
                      t5_name)
        return 1

    needs_update = False

    # get the maximum t0 cluster number in the previous tree
    sql = "SELECT max(t0) FROM sample_clusters WHERE fk_sample_id IN %s"
    cur.execute(
        sql,
        (tuple(tree_sample_set), ),
    )
    tree_t0_max = cur.fetchone()[0]

    logging.debug("Max t0 in this tree is: %i", tree_t0_max)

    # set with t5 members that are not in the tree yet
    new_t5_members = t5_members.difference(tree_sample_set)
    # set with t5 members already in the tree
    old_t5_members = t5_members.intersection(tree_sample_set)
    # for all t5 members that are not in the tree yet
    # check whether there are members in the t50 that are not in the tree yet, but should be
    logging.debug("There are %i new members in this t5 cluster.",
                  len(new_t5_members))
    for new_t5_mem in new_t5_members:
        # definitely needs updating when there is new t5 members
        needs_update = True
        # check only the distances to t50 cluster samples that are not in the tree yet
        check_samples = t50_members.difference(tree_sample_set)
        logging.debug(
            "Checking distances from new t5 member %s to %i t50 members that are not in the tree yet.",
            new_t5_mem, len(check_samples))
        dists = get_distances(cur, new_t5_mem, list(check_samples))
        new_members = [sa for (sa, di) in dists if di <= 50]
        if len(new_members) > 0:
            logging.debug("These samples need to be in the tree now: %s.",
                          str(new_members))
            tree_sample_set.update(new_members)

    # reduce the list of members in the t50 to those REALLY needing checking
    filtered_t50_members = filter_samples_to_be_checked(
        cur, t50_members, tree_t0_max)
    logging.debug("Reduced nof t50 members to be checked from %i to %i.",
                  len(t50_members), len(filtered_t50_members))

    for old_t5_mem in old_t5_members:
        # check only the distances to t50 cluster samples that are not in the tree yet
        check_samples = filtered_t50_members.difference(tree_sample_set)
        logging.debug(
            "Checking distances from old t5 member %s to %i t50 members that are not in the tree yet.",
            old_t5_mem, len(check_samples))
        dists = get_distances(cur, old_t5_mem, list(check_samples))
        new_members = [sa for (sa, di) in dists if di <= 50]
        if len(new_members) > 0:
            logging.debug("These samples need to be in the tree now: %s.",
                          str(new_members))
            tree_sample_set.update(new_members)
            needs_update = True

    if needs_update == True:

        # lock table row during tree update
        sql = "UPDATE trees SET nwkfile=%s, sample_set=%s, lockdown=%s WHERE pk_id=%s"
        cur.execute(sql, (
            None,
            None,
            True,
            tree_row_id,
        ))
        conn.commit()

        logging.info(
            "The tree for t5 cluster %s needs updating and will now contain %i samples.",
            t5_name, len(tree_sample_set))

        sample_names = get_sample_names(cur, tree_sample_set)

        # if the reference is part of the tree we need to remove this here
        # it is always part of all trees anyway
        # NOTE(review): except KeyError implies sample_names is a set
        # (list.remove raises ValueError) - confirm get_sample_names
        try:
            sample_names.remove(args['refname'])
        except KeyError:
            pass

        # make a tree now using SnapperDBInterrogation interface
        nwktree = None
        with SnapperDBInterrogation(conn_string=args['db']) as sdbi:
            try:
                nwktree = sdbi.get_tree(list(sample_names),
                                        None,
                                        'ML',
                                        ref=args['ref'],
                                        refname=args['refname'],
                                        rmref=True,
                                        overwrite_max=True)
            except SnapperDBInterrogationError as e:
                # NOTE(review): the row stays locked (lockdown=True) when tree
                # building fails - presumably recovered elsewhere; confirm
                logging.error(e)
                return None
            else:
                logging.info("Tree calculation completed successfully.")

        nownow = datetime.now()

        # update the database - unlock the table row
        sql = "UPDATE trees SET nwkfile=%s, lockdown=%s, mod_date=%s, sample_set=%s, t50_size=%s WHERE pk_id=%s"
        cur.execute(sql, (
            nwktree,
            False,
            nownow,
            list(tree_sample_set),
            len(t50_members),
            tree_row_id,
        ))
        conn.commit()

    else:
        # only remember the new t50 size so the size check above stays correct
        # NOTE(review): no conn.commit() on this branch, unlike the branch
        # above - presumably the caller commits; confirm
        sql = "UPDATE trees SET t50_size=%s WHERE pk_id=%s"
        cur.execute(sql, (
            len(t50_members),
            tree_row_id,
        ))
        logging.debug("Tree for t5 cluster %i does not need updating.",
                      t5_name)

    return 1
コード例 #9
0
def make_a_new_tree(cur, t5_name, t5_members, args):
    """
    Makes a new tree for a cluster that previously did not have one.

    Parameters
    ----------
    cur: obj
        database cursor
    t5_name: int
        name of the t5 cluster
    t5_members: set
        set of members of the t5 cluster
    args: dict
        as passed to main function

    Returns
    -------
    0
    """

    t50_name = get_t50_cluster(cur, t5_name, t5_members)
    logging.debug("t5 %s sits within t50 %s", t5_name, t50_name)

    t50_members = get_members(cur, 't50', t50_name)
    logging.debug("t50 %s has %i members.", t50_name, len(t50_members))

    # all t5 members go into the tree
    sample_set = set()
    sample_set.update(list(t5_members))

    # plus every t50 member within 50 of any t5 member; only samples not
    # already collected are checked on each pass
    for t5_mem in t5_members:
        check_samples = t50_members.difference(sample_set)
        dists = get_distances(cur, t5_mem, list(check_samples))
        sample_set.update([sa for (sa, di) in dists if di <= 50])

    logging.info("The tree for t5 cluster %s will contain %i samples.",
                 t5_name, len(sample_set))

    sample_names = get_sample_names(cur, sample_set)

    # the reference is always part of all trees anyway, so drop it here
    # NOTE(review): except KeyError implies sample_names is a set
    # (list.remove raises ValueError) - confirm get_sample_names
    try:
        sample_names.remove(args['refname'])
    except KeyError:
        pass

    # build the tree via the SnapperDBInterrogation interface
    nwktree = None
    with SnapperDBInterrogation(conn_string=args['db']) as sdbi:
        try:
            nwktree = sdbi.get_tree(list(sample_names),
                                    None,
                                    'ML',
                                    ref=args['ref'],
                                    refname=args['refname'],
                                    rmref=True,
                                    overwrite_max=True)
        except SnapperDBInterrogationError as e:
            # NOTE(review): on failure nwktree stays None and is still
            # inserted below - presumably intentional placeholder; confirm
            logging.error(e)
        else:
            logging.info("Tree calculation completed successfully.")

    nownow = datetime.now()

    # NOTE(review): no commit here and no conn available - the caller
    # presumably commits this INSERT; confirm
    sql = "INSERT INTO trees (nwkfile, t5_name, t50_size, sample_set, mod_date, created_at, lockdown) VALUES (%s, %s, %s, %s, %s, %s, %s)"
    cur.execute(sql, (
        nwktree,
        t5_name,
        len(t50_members),
        list(sample_set),
        nownow,
        nownow,
        False,
    ))

    return 0
コード例 #10
0
ファイル: remove_sample.py プロジェクト: connor-lab/snapper3
def check_cluster_integrity(cur,
                            sample_id,
                            snad,
                            distances,
                            levels=[0, 2, 5, 10, 25, 50, 100, 250]):
    """
    Check whether the removal of sample_id from any of its clusters necessitates
    the split of the cluster.

    Parameters
    ----------
    cur: obj
        database cursor
    sample_id: int
        id of sample to remove
    snad: list of 8 int
        snp address
    distances: dict
        distances[a][b] = d
        distances[b][a] = d
    levels: list of 8 int
        better not change this
        [0, 2, 5, 10, 25, 50, 100, 250]

    Returns
    -------
    splits: dict
        splits[level] = [(c, a, b), ...] <- no longer connected pair in cluster c
        splits[level] = None when no split is required on that level
    """

    splits = {}

    for clu, lvl in zip(snad, levels):

        t_lvl = 't%i' % (lvl)

        logging.info("Checking cluster integrity for cluster %s on level %s.",
                     clu, t_lvl)

        # get all other members of the cluster apart from the removee
        mems = get_all_cluster_members(cur, clu, t_lvl)
        mems.remove(sample_id)

        # get distances of the removee to them
        d = get_distances(cur, sample_id, mems)
        connected_mems = []
        for (sa, di) in d:
            # get all samples that are connected to the removee with d <= t
            if di <= lvl:
                connected_mems.append(sa)
            remember_distance(distances, sample_id, sa, di)

        logging.debug("Samples connected via removee: %s",
                      sorted(connected_mems))

        # investigate all pw distances between connected members
        potentially_broken_pairs = []
        for i, a in enumerate(connected_mems):
            for j, b in enumerate(connected_mems):
                if i < j:
                    pwd = None
                    try:
                        pwd = distances[a][b]
                    except KeyError:
                        pwd = get_all_pw_dists(cur, [a, b])[0]
                        remember_distance(distances, a, b, pwd)
                    # if pw distance between the two samples is bigger than the threshold,
                    # the link between the samples might be broken, unless there is another sample
                    # (or chain of samples) connecting them
                    if pwd > lvl:
                        potentially_broken_pairs.append((a, b))

        # all pairs that were connected through the removee are also directly connected, happy days
        if len(potentially_broken_pairs) == 0:
            splits[lvl] = None
            continue

        logging.debug(
            "Samples potentially no longer connected via removee: %s",
            potentially_broken_pairs)

        # check if there is another path to get from a to b with only steps <= t
        for a, b in potentially_broken_pairs:
            broken = False
            logging.debug(
                "Checking if there is another way to connect %s and %s.", a, b)
            # list of samples connectable to a (directly or over multiple nodes)
            rel_conn_sams_to_a = [a]
            idx = 0
            # when b is connected to a we're done
            while b not in rel_conn_sams_to_a:
                # pivot is the one currently investigated
                pivot = rel_conn_sams_to_a[idx]
                # get all the members of the current cluster except the pivot
                all_mems_but_pivot = [x for x in mems if x != pivot]
                # get all the distances from the pivot to those members
                d = get_distances_from_memory(cur, distances, pivot,
                                              all_mems_but_pivot)
                # all new samples that are connectable to the pivot
                # two conditions: a) the sample is connected to the pivot with d<=t
                #         b) we don't have this sample yet in the ones we already know are connected to a
                rel_conn_to_pivot = [
                    sa for (sa, di) in d
                    if (di <= lvl) and (sa not in rel_conn_sams_to_a)
                ]
                # there are no new samples connected to the pivot and the last sample has been considered
                # but b has not yet been found to be connected => cluster is broken
                if len(rel_conn_to_pivot
                       ) == 0 and pivot == rel_conn_sams_to_a[-1]:
                    broken = True
                    break
                else:
                    # otherwise add any potential new ones to the list and check the next one
                    rel_conn_sams_to_a += rel_conn_to_pivot
                    idx += 1
            # we need to remember what was broken for updating later
            if broken == True:
                try:
                    splits[lvl].append((clu, a, b))
                except KeyError:
                    splits[lvl] = [(clu, a, b)]
                # go to next broken pair, there might be more than one and
                # we want to know for updating later

        # we checked all pairs and always found b somehow, cluster is fine
        # BUGFIX: dict.has_key() was removed in Python 3; the 'in' operator
        # is equivalent and works on Python 2 and 3.
        if lvl not in splits:
            splits[lvl] = None

    return splits