Example #1
def find_and_extract_cm_model(args, analyzed_hits, rfam=None, timeout=None):
    if rfam is None:
        rfam = RfamInfo()

    ml.info('Infer homology - searching RFAM for best matching model.')
    cmscan_results = get_cm_model_table(args.blast_query,
                                        threads=args.threads,
                                        rfam=rfam,
                                        timeout=timeout)
    best_matching_cm_model = select_best_matching_model_from_cmscan(
        cmscan_results)
    analyzed_hits.best_matching_model = best_matching_cm_model

    if best_matching_cm_model is not None:
        ml.info('Infer homology - best matching RFAM model: {}'.format(
            analyzed_hits.best_matching_model['target_name']))

    try:
        if args.cm_file:
            # use provided cm file
            ml.info('Infer homology - using provided CM file: {}'.format(
                args.cm_file))
            fd, cm_model_file = mkstemp(prefix='rba_',
                                        suffix='_27',
                                        dir=CONFIG.tmpdir)
            os.close(fd)
            ml.debug('Making a copy of provided cm model file to: {}'.format(
                cm_model_file))
            shutil.copy(args.cm_file, cm_model_file)

        elif args.use_rfam:
            ml.info('Infer homology - using RFAM CM file as reference model.')
            if analyzed_hits.best_matching_model is None:
                ml.error(
                    'No RFAM model was matched with score > 0. Nothing to build homology to.'
                )
                raise exceptions.MissingCMexception

            cm_model_file = run_cmfetch(
                rfam.file_path,
                analyzed_hits.best_matching_model['target_name'],
                timeout=timeout)
        else:
            ml.info('Infer homology - using RSEARCH to build model')
            # default to using RSEARCH
            cm_model_file = build_cm_model_rsearch(analyzed_hits.query,
                                                   CONFIG.rsearch_ribosum,
                                                   timeout=timeout)

        return cm_model_file, analyzed_hits
    except exceptions.SubprocessException as e:
        ml.error("Can't obtain covariance model.")
        ml.error(str(e))
        raise
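A side note on the user-CM branch above: it uses the common tempfile idiom of reserving a unique path with mkstemp, closing the descriptor, and copying into it. A minimal standalone sketch of that pattern (copy_to_tempfile is an illustrative helper, not part of the module):

import os
import shutil
from tempfile import mkstemp

def copy_to_tempfile(src, tmpdir=None):
    # mkstemp reserves a unique file and returns an open descriptor plus its path
    fd, path = mkstemp(prefix='rba_', dir=tmpdir)
    os.close(fd)  # only the reserved path is needed, not the descriptor
    shutil.copy(src, path)
    return path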
Example #2
def wrapped_ending_with_prediction(
    args_inner,
    analyzed_hits,
    pred_method=None,
    method_params=None,
    used_cm_file=None,
    multi_query=False,
    iteration=0,
):
    """
    wrapper for prediction of secondary structures
    :param args_inner: Namespace of input arguments
    :param analyzed_hits: BlastSearchRecompute object
    :param pred_method:
    :param method_params:
    :param used_cm_file: cmfile if cmfile is known (user given or computed)
    :return:
    """
    ml.debug(fname())
    exec_time = {}
    msg = 'Entering structure prediction...'
    # logging.INFO == 20: echo to stdout only when the logger would suppress INFO
    if ml.level > 20:
        print(msg)
    ml.info(msg)

    if pred_method is None:
        pred_method = args_inner.prediction_method

    if isinstance(pred_method, str):
        pred_method = (pred_method, )

    if method_params is None:
        method_params = args_inner.pred_params

    # ======= filter if needed =======
    # do the filtering based on e-value or bitscore
    # homologous hits still get used for prediction

    # annotate ambiguous bases
    query = BA_support.annotate_ambiguos_base(analyzed_hits.query)

    # copy the list before filtering
    all_hits_list = [i.extension for i in analyzed_hits.get_all_hits()]

    if args_inner.filter_by_eval is not None:
        hits2predict = filter_by_eval(analyzed_hits.get_all_hits(),
                                      BA_support.blast_hit_getter_from_subseq,
                                      args_inner.filter_by_eval)
        _hits = HitList()
        for h in hits2predict:
            _hits.append(h)
        analyzed_hits.hits = _hits
    elif args_inner.filter_by_bitscore is not None:
        hits2predict = filter_by_bits(analyzed_hits.get_all_hits(),
                                      BA_support.blast_hit_getter_from_subseq,
                                      args_inner.filter_by_bitscore)
        _hits = HitList()
        for h in hits2predict:
            _hits.append(h)
        analyzed_hits.hits = _hits
    else:
        analyzed_hits.hits = analyzed_hits.get_all_hits()

    # if used_cm_file is provided, do not override it with a CM from RFAM
    # if the use_rfam flag was given, then used_cm_file is already the best matching model
    # if analyzed_hits.best_matching_model is None, no matching model was found in RFAM
    #  and the RFAM-based methods should fail (i.e. not predict anything)
    delete_cm = False
    if used_cm_file is None and analyzed_hits.best_matching_model is not None:
        rfam = RfamInfo()
        used_cm_file = run_cmfetch(
            rfam.file_path, analyzed_hits.best_matching_model['target_name'])
        delete_cm = True

    fd, seqs2predict_fasta = mkstemp(prefix='rba_',
                                     suffix='_83',
                                     dir=CONFIG.tmpdir)
    with os.fdopen(fd, 'w') as fah:
        for hit in analyzed_hits.hits:
            if len(hit.extension.seq) == 0:
                continue
            fah.write('>{}\n{}\n'.format(hit.extension.id,
                                         str(hit.extension.seq)))

    if not isinstance(method_params, dict):
        raise Exception('prediction method parameters must be a python dict')

    # prediction methods present in analyzed_hits
    #  might have been loaded from an intermediate file

    # check that structures of a method are predicted for all required hits
    # and that the prediction parameters of that method were the same

    prediction_msgs = []
    # compute prediction methods which were not computed
    for pkey in set(pred_method):
        # fingerprint this method's parameters (order-independent SHA-1 of sorted params)
        nh = sha1()
        nh.update(str(sorted(method_params.get(pkey, {}).items())).encode())
        current_hash = nh.hexdigest()

        # skip when every hit already carries this method's structure and all
        #  stored parameter hashes match the current one
        if all(pkey in h.extension.letter_annotations for h in analyzed_hits.hits) and \
                len(
                    {
                        h.extension.annotations.get('sha1', {}).get(pkey, None) for h in analyzed_hits.hits
                    } | {current_hash, }
                ) == 1:
            msg_skip = 'All structures already computed for {}. Skipping...'.format(
                pkey)
            ml.info(msg_skip)
            if ml.level > 20:
                print(msg_skip, flush=True)
            continue

        msg_run = 'Running: {}...'.format(pkey)
        ml.info(msg_run)

        if ml.level > 20:
            print(msg_run, flush=True)

        structures, etime, msgs = repredict_structures_for_homol_seqs(
            query,
            seqs2predict_fasta,
            args_inner.threads,
            prediction_method=pkey,
            pred_method_params=method_params,
            all_hits_list=all_hits_list,
            seqs2predict_list=[i.extension for i in analyzed_hits.hits],
            use_cm_file=used_cm_file,
        )

        exec_time[pkey] = etime

        if structures is None:
            msg = 'Structures not predicted with {} method'.format(pkey)
            ml.info(msg)
            if ml.level > 20:
                print('STATUS: ' + msg)

        else:
            for i, hit in enumerate(analyzed_hits.hits):
                assert str(hit.extension.seq) == str(structures[i].seq)
                hit.extension.annotations['sss'] += [pkey]

                hit.extension.annotations['msgs'] += structures[
                    i].annotations.get('msgs', [])

                # expects "predicted" in annotations - for now, if not given, default is True, as not all prediction
                #  methods implement "predicted" in their output
                if structures[i].annotations.get('predicted', True):
                    hit.extension.letter_annotations[pkey] = structures[
                        i].letter_annotations['ss0']

                if 'sha1' not in hit.extension.annotations:
                    hit.extension.annotations['sha1'] = dict()
                hit.extension.annotations['sha1'][pkey] = current_hash

                try:
                    del hit.extension.letter_annotations['ss0']
                except KeyError:
                    pass
                try:
                    hit.extension.annotations['sss'].remove('ss0')
                except ValueError:
                    pass

            analyzed_hits.update_hit_stuctures()

        # check if msgs are not empty
        if msgs:
            prediction_msgs.append('{}: {}'.format(pkey, '\n'.join(msgs)))

        analyzed_hits.msgs = prediction_msgs

        with open(args_inner.blast_in + '.r-' + args_inner.sha1[:10],
                  'r+') as f:
            all_saved_data = json.load(f)
            all_saved_data[iteration] = blastsearchrecompute2dict(
                analyzed_hits)
            f.seek(0)
            f.truncate()
            json.dump(all_saved_data, f, indent=2)

    # remove structures predicted by different methods (which might be saved from previous computation)
    for hit in analyzed_hits.hits:
        for pkey in set(hit.extension.letter_annotations.keys()):
            if pkey not in pred_method:
                del hit.extension.letter_annotations[pkey]
                try:
                    hit.extension.annotations['sss'].remove(pkey)
                except ValueError:
                    pass

    BA_support.remove_one_file_with_try(seqs2predict_fasta)

    if delete_cm:
        BA_support.remove_one_file_with_try(used_cm_file)

    add_loc_to_description(analyzed_hits)

    # write html if requested
    if args_inner.html:
        html_file = iter2file_name(args_inner.html, multi_query, iteration)
        ml.info('Writing html to {}.'.format(html_file))
        with open(html_file, 'wb') as h:
            h.write(write_html_output(analyzed_hits))

    # write csv file if requested
    if args_inner.csv:
        csv_file = iter2file_name(args_inner.csv, multi_query, iteration)
        ml.info('Writing csv to {}.'.format(csv_file))
        analyzed_hits.to_csv(csv_file)

    # write json if requested
    if args_inner.json:
        json_file = iter2file_name(args_inner.json, multi_query, iteration)
        ml.info('Writing json to {}.'.format(json_file))
        j_obj = json.dumps(blastsearchrecompute2dict(analyzed_hits), indent=2)
        if getattr(args_inner, 'zip_json', False):
            with open(json_file + '.gz', 'wb') as ff:
                ff.write(gzip.compress(j_obj.encode()))
        else:
            with open(json_file, 'w') as ff:
                ff.write(j_obj)

    if args_inner.pandas_dump:
        pickle_file = iter2file_name(args_inner.pandas_dump, multi_query,
                                     iteration)
        ml.info('Writing pandas pickle to {}.'.format(pickle_file))
        pandas.to_pickle(analyzed_hits.pandas, pickle_file)

    if args_inner.dump:
        dump_file = iter2file_name(args_inner.dump, multi_query, iteration)
        ml.info('Writing dump files base: {}.'.format(dump_file))
        with open(dump_file, 'wb') as pp:
            pickle.dump(analyzed_hits, pp, pickle.HIGHEST_PROTOCOL)

        with open(dump_file + '.time_dump', 'wb') as pp:
            pickle.dump(exec_time, pp, pickle.HIGHEST_PROTOCOL)

    return analyzed_hits
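The skip-if-already-computed check in this function keys each method's stored structures on a SHA-1 of its sorted parameters. A standalone sketch of that fingerprinting idea (names and parameter keys are illustrative only):

from hashlib import sha1

def params_fingerprint(params: dict) -> str:
    # sorting the items makes the hash independent of dict insertion order
    return sha1(str(sorted(params.items())).encode()).hexdigest()

a = params_fingerprint({'max_span': 120, 'msa_alg': 'clustalo'})
b = params_fingerprint({'msa_alg': 'clustalo', 'max_span': 120})
assert a == b  # same parameters, same fingerprint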
Example #3
def lunch_computation(args_inner, shared_list=None):
    ml.debug(fname())
    if not shared_list:
        shared_list = []

    # update params if different config is requested
    CONFIG.override(tools_paths(args_inner.config_file))

    p_blast = BA_support.blast_in(args_inner.blast_in, b=args_inner.b_type)
    query_seqs = list(SeqIO.parse(args_inner.blast_query, 'fasta'))

    if len(p_blast) != len(query_seqs):
        ml.error(
            'Number of query sequences in provided BLAST output file ({}) does not match number of query sequences'
            ' in query FASTA file ({}).'.format(len(p_blast), len(query_seqs)))
        sys.exit(1)

    # check that the BLAST output does not contain unexpected sequence characters
    validate_args.check_blast(p_blast)

    # create list of correct length if needed
    all_saved_data = [None] * len(query_seqs)
    saved_file = '{}.r-{}'.format(args_inner.blast_in, args_inner.sha1[:10])
    with open(saved_file, 'r+') as f:
        _saved = json.load(f)
        if _saved is None:
            f.seek(0)
            f.truncate()
            json.dump(all_saved_data, f)
        else:
            msg = "Loading backup data."
            print('STATUS: ' + msg)
            ml.info(msg + ' file: ' + saved_file)
            all_saved_data = _saved

            for saved_data in all_saved_data:
                # we can have partially computed data
                if saved_data is None:
                    continue
                if saved_data['args']['sha1'] != args_inner.sha1:
                    msg = "Input argument hash does not match the saved argument hash. "
                    if saved_data['args']['sha1'][:10] == args_inner.sha1[:10]:
                        msg += "This is caused by truncating hashes to the first 10 characters. "
                    msg += "Please remove the '{}' file.".format(saved_file)
                    ml.error(msg)
                    sys.exit(1)

    multi_query = len(p_blast) > 1

    # this is done for each query
    ml_out_line = []
    all_analyzed = []
    for iteration, (bhp, query, saved_data) in enumerate(
            zip(p_blast, query_seqs, all_saved_data)):
        if saved_data is None:
            print('STATUS: processing query: {}'.format(query.id))
            validate_args.verify_query_blast(blast=bhp, query=query)

            analyzed_hits = BlastSearchRecompute(args_inner, query, iteration)
            analyzed_hits.multi_query = multi_query

            # build the CM model first
            # this fails fast if RFAM was selected and no matching model was found
            ih_model, analyzed_hits = find_and_extract_cm_model(
                args_inner, analyzed_hits)

            # collect all BLAST hits
            all_blast_hits = BA_support.blast_hsps2list(bhp)

            if len(all_blast_hits) == 0:
                ml.error('No hits found in {} - {}. Nothing to do.'.format(
                    args_inner.blast_in, bhp.query))
                continue

            # the filters here only verify that some hits would remain;
            #  the actual filtering is applied later in wrapped_ending_with_prediction
            if args_inner.filter_by_eval is not None:
                tmp = filter_by_eval(all_blast_hits,
                                     BA_support.blast_hit_getter_from_hits,
                                     args_inner.filter_by_eval)
                if len(tmp) == 0 and len(all_blast_hits) != 0:
                    ml.error(
                        'The requested filter removed all BLAST hits {} - {}. Nothing to do.'
                        .format(args_inner.blast_in, bhp.query))
                    continue
            elif args_inner.filter_by_bitscore is not None:
                tmp = filter_by_bits(all_blast_hits,
                                     BA_support.blast_hit_getter_from_hits,
                                     args_inner.filter_by_bitscore)
                if len(tmp) == 0 and len(all_blast_hits) != 0:
                    ml.error(
                        'The requested filter removed all BLAST hits {} - {}. Nothing to do.'
                        .format(args_inner.blast_in, bhp.query))
                    continue

            all_short = all_blast_hits

            # now this is different for each mode
            if args_inner.mode == 'simple':
                analyzed_hits, homology_prediction, homol_seqs, cm_file_rfam_user = extend_simple_core(
                    analyzed_hits, query, args_inner, all_short, multi_query,
                    iteration, ih_model)
            elif args_inner.mode == 'locarna':
                analyzed_hits, homology_prediction, homol_seqs, cm_file_rfam_user = extend_locarna_core(
                    analyzed_hits, query, args_inner, all_short, multi_query,
                    iteration, ih_model)
            elif args_inner.mode == 'meta':
                analyzed_hits, homology_prediction, homol_seqs, cm_file_rfam_user = extend_meta_core(
                    analyzed_hits, query, args_inner, all_short, multi_query,
                    iteration, ih_model)
            else:
                raise ValueError(
                    'Unknown mode - should have been caught by argparse.')

            if len(analyzed_hits.hits) == 0:
                ml.error(
                    "Extension failed for all sequences. Please see the error message. You can also try '--mode simple'."
                )
                sys.exit(1)

            analyzed_hits.copy_hits()

            with open(args_inner.blast_in + '.r-' + args_inner.sha1[:10],
                      'r+') as f:
                all_saved_data = json.load(f)
                all_saved_data[iteration] = blastsearchrecompute2dict(
                    analyzed_hits)
                f.seek(0)
                f.truncate()
                json.dump(all_saved_data, f, indent=2)

        else:
            print(
                'STATUS: extended sequences loaded from backup file for query {}'
                .format(query.id))
            analyzed_hits = blastsearchrecomputefromdict(saved_data)

            # overwrite the saved args with the current ones
            # this updates the prediction methods used and other non-essential settings
            analyzed_hits.args = args_inner

            if analyzed_hits.args.cm_file:
                cm_file_rfam_user = analyzed_hits.args.cm_file
            else:
                cm_file_rfam_user = None

        all_analyzed.append(analyzed_hits)

        # write all hits to fasta
        fda, all_hits_fasta = mkstemp(prefix='rba_',
                                      suffix='_22',
                                      dir=CONFIG.tmpdir)
        os.close(fda)
        analyzed_hits.write_results_fasta(all_hits_fasta)

        out_line = []
        # multiple prediction params
        if args_inner.dev_pred:
            dp_list = []
            # accommodate multiple dev pred outputs
            dpfile = None
            # str.strip() removes *characters*, not a suffix, so trim the
            #  option-name extension explicitly (assumes the path ends with it)
            if getattr(args_inner, 'dump', False):
                dpfile = args_inner.dump[:-len('dump')]
            if getattr(args_inner, 'pandas_dump', False):
                dpfile = args_inner.pandas_dump[:-len('pandas_dump')]
            if getattr(args_inner, 'json', False):
                dpfile = args_inner.json[:-len('json')]

            # optimization: fetch the RFAM CM file only once
            if cm_file_rfam_user is None and 'rfam' in ''.join(
                    args_inner.prediction_method):
                best_model = get_cm_model(args_inner.blast_query,
                                          threads=args_inner.threads)
                rfam = RfamInfo()
                cm_file_rfam_user = run_cmfetch(rfam.file_path, best_model)

            for method in args_inner.prediction_method:
                # cycle through the prediction method settings
                # get the set of params for each prediction
                selected_pred_params = [
                    kk for kk in args_inner.pred_params if method in kk
                ]
                shuffle(selected_pred_params)
                for i, method_params in enumerate(selected_pred_params):
                    ah = deepcopy(analyzed_hits)

                    random_flag = BA_support.generate_random_name(
                        8, shared_list)
                    shared_list.append(random_flag)

                    pname = re.sub(' ', '', str(method))
                    flag = '|pred_params|' + random_flag

                    # rebuild the args with only the prediction settings actually used
                    ah.args.prediction_method = method
                    ah.args.pred_params = method_params

                    if getattr(args_inner, 'dump', False):
                        spa = args_inner.dump.split('.')
                        ah.args.dump = '.'.join(
                            spa[:-1]) + flag + '.' + spa[-1]
                    if getattr(args_inner, 'pandas_dump', False):
                        spa = args_inner.pandas_dump.split('.')
                        ah.args.pandas_dump = '.'.join(
                            spa[:-1]) + flag + '.' + spa[-1]
                    if getattr(args_inner, 'pdf_out', False):
                        spa = args_inner.pdf_out.split('.')
                        ah.args.pdf_out = '.'.join(
                            spa[:-1]) + flag + '.' + spa[-1]
                    if getattr(args_inner, 'json', False):
                        spa = args_inner.json.split('.')
                        ah.args.json = '.'.join(
                            spa[:-1]) + flag + '.' + spa[-1]

                    wrapped_ending_with_prediction(
                        args_inner=ah.args,
                        analyzed_hits=ah,
                        pred_method=method,
                        method_params=method_params,
                        used_cm_file=cm_file_rfam_user,
                        multi_query=multi_query,
                        iteration=iteration,
                    )
                    success = True
                    out_line.append(to_tab_delim_line_simple(ah.args))

                    dp_list.append((i, method_params, success, flag, pname,
                                    random_flag, args_inner.pred_params))

            if dpfile is not None:
                with open(dpfile + 'devPredRep', 'wb') as devf:
                    pickle.dump(dp_list, devf)
        else:
            wrapped_ending_with_prediction(
                args_inner=args_inner,
                analyzed_hits=analyzed_hits,
                used_cm_file=cm_file_rfam_user,
                multi_query=multi_query,
                iteration=iteration,
            )
            out_line.append(to_tab_delim_line_simple(args_inner))

        ml_out_line.append('\n'.join(out_line))

        if cm_file_rfam_user is not None and os.path.exists(cm_file_rfam_user):
            BA_support.remove_one_file_with_try(cm_file_rfam_user)

        BA_support.remove_one_file_with_try(all_hits_fasta)
    return '\n'.join(ml_out_line), all_analyzed
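Both wrapped_ending_with_prediction and lunch_computation update the per-query backup JSON in place with the same read-modify-write pattern: open with 'r+', load, mutate one slot, seek to the start, truncate, and dump. Isolated as a sketch (update_backup is a hypothetical helper, not part of the module):

import json

def update_backup(path, iteration, record):
    with open(path, 'r+') as f:
        data = json.load(f)        # whole list, one slot per query
        data[iteration] = record   # replace this query's slot
        f.seek(0)
        f.truncate()               # discard the old (possibly longer) content
        json.dump(data, f, indent=2)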