Example 1
def main(args):
    global db
    db = mongodb.get_db(args.db, args.ip, args.port, args.user, args.password)
    collection_groups = get_collection_groups(args)
    print_start_info(collection_groups, args)

    for i, collection_group in enumerate(collection_groups):
        print_collection_group_info(collection_group, i)
        seqs = get_sequences(collection_group, args)
        seq_count = len(seqs)
        if len(seqs) == 0:
            continue
        mr_db = build_mr_db(seqs, args)

        # Get the number of cores -- if multiprocessing, just mp.cpu_count();
        # if Celery, need to get the number of worker processes
        if args.num_map_workers > 0:
            cores = args.num_map_workers
        elif args.celery:
            app = celery.Celery()
            app.config_from_object('abtools.celeryconfig')
            stats = app.control.inspect().stats()
            cores = sum(v['pool']['max-concurrency'] for v in stats.values())
        else:
            cores = mp.cpu_count()
        logger.debug('CORES: {}'.format(cores))

        # Divide the input sequences into JSON subfiles
        chunksize = int(math.ceil(float(len(seqs)) / cores))
        logger.debug('CHUNKSIZE: {}'.format(chunksize))
        json_files = []
        seq_ids = []
        for seq_chunk in chunker(seqs, chunksize):
            seq_ids.append([s['seq_id'] for s in seq_chunk])
            json_files.append(
                Cluster.pretty_json(seq_chunk,
                                    as_file=True,
                                    temp_dir=args.temp))
        json_db = build_json_db(seq_ids, json_files, args)

        # Run clonify jobs via multiprocessing or Celery; returns a list
        # of Cluster objects for each JSON subfile
        print_clonify_input_building(seq_count)
        clusters = clonify(json_files, mr_db, json_db, args)

        if args.output:
            print_output(args)
            name = collection_group[0] if len(collection_group) == 1 else str(i)
            write_output(clusters, mr_db, args, collection_group=name)
        if args.update:
            cluster_sizes = update_db(clusters, collection_group)
        else:
            cluster_sizes = [c.size for c in clusters]
        print_finished(cluster_sizes)
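
The chunker() helper used above is not shown in this example. A minimal sketch consistent with how it is called (yielding consecutive slices of at most chunksize sequences) might look like this:

def chunker(seqs, chunksize):
    # yield consecutive slices of at most `chunksize` items; a sketch of
    # the helper used above, which is not defined in this example
    for i in range(0, len(seqs), chunksize):
        yield seqs[i:i + chunksize]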
Example 2
def get_freq_df(db,
                collections,
                value,
                chain='heavy',
                match=None,
                normalize=False):

    # accept either a pymongo Database object or a database name
    if isinstance(db, pymongo.database.Database):
        DB = db
    elif isinstance(db, str):
        DB = mongodb.get_db(db)
    else:
        raise ValueError('db must be a pymongo Database or a database name string')

    if not isinstance(collections, list):
        collections = [collections]

    if not match:
        match = {'chain': chain, 'prod': 'yes'}
    group = {'_id': '${}'.format(value), 'count': {'$sum': 1}}

    # initialize a dictionary that will hold the counts for each collection
    data = {}

    # iterate through each of the collections in the subject group
    for collection in collections:
        data[collection] = {}

        # get the aggregation data from MongoDB
        res = DB[collection].aggregate([{'$match': match}, {'$group': group}])

        # convert the MongoDB aggregation result into a dictionary of counts
        for r in res:
            data[collection][r['_id']] = r['count']


    # construct a DataFrame from the count dictionaries, one column per collection
    df = pd.DataFrame(data)

    if normalize:
        # normalize each collection's counts to frequencies
        df = df / df.sum()
        df = df.dropna(axis=0)

    return df
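
A hypothetical call, with placeholder database and collection names, that builds a normalized table of heavy-chain V-gene frequencies:

# 'my_database', 'subject1' and 'subject2' are placeholders
df = get_freq_df('my_database',
                 ['subject1', 'subject2'],
                 'v_gene.gene',
                 chain='heavy',
                 normalize=True)
print(df.head())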
Example 3
def assign_sample_metadata(database,
                           collection,
                           ip='localhost',
                           port=27017,
                           user=None,
                           password=None,
                           subjects=None,
                           groups=None,
                           experiments=None,
                           timepoints=None):
    db = mongodb.get_db(database, ip=ip, port=port, user=user, password=password)
    if subjects is not None:
        assign_subjects(db, collection, subjects)
    if groups is not None:
        assign_groups(db, collection, groups)
    if experiments is not None:
        assign_experiments(db, collection, experiments)
    if timepoints is not None:
        assign_timepoints(db, collection, timepoints)
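
A hypothetical call with placeholder names; the exact structure that assign_subjects() and the other helpers expect is not shown in this example:

# 'my_database' and 'my_collection' are placeholders
assign_sample_metadata('my_database',
                       'my_collection',
                       subjects='subject1',
                       timepoints='week04')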
Example 4
def main(args, logfile=None):
    global logger
    logger = log.get_logger('demultiplex')
    print_start_info()
    if all([args.index is None, args.index_file is None]):
        err = 'Indexes must be provided, either using --index or --index-file'
        raise RuntimeError(err)
    log_options(args, logfile=logfile)
    make_directories(args)
    # create (or truncate) the output file
    open(args.output, 'w').close()
    db = mongodb.get_db(args.db,
                        ip=args.ip,
                        port=args.port,
                        user=args.user,
                        password=args.password)
    plate_map = parse_plate_map(args.plate_map)
    collections = mongodb.get_collections(db,
                                          args.collection,
                                          prefix=args.collection_prefix,
                                          suffix=args.collection_suffix)
    for collection in collections:
        if collection not in plate_map:
            logger.info(
                '\n\n{} was not found in the supplied plate map file.'.format(
                    collection))
            continue
        plate_names = plate_map[collection]
        for plate_num, plate_name in enumerate(plate_names):
            if plate_name is None:
                continue
            print_plate_info(plate_name, collection)
            indexes = get_indexes(args.index, args.index_file,
                                  args.index_length, plate_num)
            for chain in ['heavy', 'kappa', 'lambda']:
                plate_seqs = []
                logger.info('')
                logger.info('Querying for {} chain sequences'.format(chain))
                score_cutoff = args.score_cutoff_heavy if chain == 'heavy' else args.score_cutoff_light
                sequences = get_sequences(db, collection, chain, score_cutoff)
                logger.info(
                    'QUERY RESULTS: {} {} chain sequences met the quality threshold'
                    .format(len(sequences), chain.lower()))
                bins = bin_by_index(sequences, indexes, args.index_length,
                                    args.index_position,
                                    args.index_reverse_complement,
                                    args.raw_seq_field)
                if args.minimum_well_size == 'relative':
                    min_well_size = int(
                        len(sequences) / float(args.minimum_well_size_denom))
                else:
                    min_well_size = int(args.minimum_well_size)
                min_max_well_size = max(min_well_size,
                                        args.minimum_max_well_size)
                if max(len(b) for b in bins.values()) < int(min_max_well_size):
                    logger.info(
                        'The biggest well had fewer than {} sequences, so the plate was not processed'
                        .format(min_max_well_size))
                    continue
                for b in sorted(bins.keys()):
                    if len(bins[b]) < 25:
                        continue
                    print_bin_info(b)
                    if args.raw_sequence_dir is not None:
                        # write the raw (pre-consensus) sequences for this well
                        rs_file = os.path.join(
                            args.raw_sequence_dir,
                            '{}-{}_{}'.format(plate_name, b, chain))
                        with open(rs_file, 'w') as rs_handle:
                            rs_handle.write('\n'.join(
                                ['>{}\n{}'.format(s[0], s[1]) for s in bins[b]]))
                    consentroid = cdhit_clustering(
                        bins[b], b, plate_name, args.temp_dir, len(sequences),
                        args.minimum_well_size, args.minimum_well_size_denom,
                        args.minimum_cluster_fraction, args.raw_sequence_dir,
                        args.alignment_pixel_dir, args.consensus,
                        args.cdhit_threshold, chain)
                    if consentroid:
                        consentroid_name = '{}-{}'.format(plate_name, b)
                        plate_seqs.append((consentroid_name, consentroid))
                log_output(bins, plate_seqs, min_well_size)
                write_output(plate_seqs, args.output)
                logger.info('')
    logger.info('')
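
The parse_plate_map() helper is not shown. Based on how its result is used (a dict mapping each collection to a list of plate names, with None entries skipped), a minimal sketch under those assumptions might be:

def parse_plate_map(plate_map_file):
    # sketch of a plate map parser: each whitespace-delimited line maps a
    # collection to one or more plate names, with 'None' marking an empty
    # slot; the real file format may differ
    plate_map = {}
    with open(plate_map_file) as f:
        for line in f:
            fields = line.strip().split()
            if not fields:
                continue
            plate_map[fields[0]] = [None if p == 'None' else p
                                    for p in fields[1:]]
    return plate_map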
Example 5
def QuickDataCheck(db, collections=None, index=False, values=None, match=None):
    # quickly check the sequencing data in a database by plotting frequency
    # distributions for one or more fields
    if isinstance(db, pymongo.database.Database):
        DB = db
    elif isinstance(db, str):
        DB = mongodb.get_db(db)
    else:
        raise ValueError('db must be a pymongo Database or a database name string')

    if collections is None:
        colls = mongodb.get_collections(DB)
    else:
        colls = collections

    # index the collections, if requested
    if index:
        print('Indexing Collections...')
        for collection in tqdm(colls):
            DB[collection].create_index([('chain', 1), ('prod', 1),
                                         ('v_gene.gene', 1), ('cdr3_len', 1)],
                                        name='productive heavychain cdr3_len',
                                        default_language='english')
    # use the provided values, or fall back to a default set
    if not values:
        values = ['v_gene.gene', 'cdr3_len']
    print('Getting data...')
    dfs = [
        get_freq_df(DB, colls, value, normalize=True, match=match)
        for value in values
    ]

    # now plot the figures for each value
    for df, value in zip(dfs, values):
        print('-----------')
        print(value)
        print('-----------')
        for collection in df.columns:
            print(collection)
            # try to plot the value; skip if the requested value is invalid
            try:
                df2 = pd.DataFrame(df[collection]).reset_index().melt(
                    id_vars='index', value_vars=[collection])
                try:
                    # group bars by gene family (the part before the '-')
                    fam = [d.split('-')[0] for d in df2['index']]
                    df2['fam'] = fam
                except AttributeError:
                    pass

                plt.figure(figsize=[12, 4])
                try:
                    g = sns.barplot(data=df2,
                                    x='index',
                                    y='value',
                                    hue='fam',
                                    dodge=False)
                except ValueError:
                    g = sns.barplot(data=df2,
                                    x='index',
                                    y='value',
                                    dodge=False)
                try:
                    g.get_legend().remove()
                except AttributeError:
                    pass
                plt.xticks(rotation=90)
                plt.tight_layout()
                plt.show()
                print(' ')
            except ValueError:
                print('The value you requested is invalid')
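
A hypothetical call with placeholder names, indexing the collections and then plotting normalized V-gene and CDR3-length distributions:

# 'my_database', 'subject1' and 'subject2' are placeholders
QuickDataCheck('my_database',
               collections=['subject1', 'subject2'],
               index=True,
               values=['v_gene.gene', 'cdr3_len'])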