Example 1
def create_labelled_samples(args):
    logger = get_named_logger('Prepare')
    regions = get_regions(args.bam, args.regions)
    reg_str = '\n'.join(['\t\t\t{}'.format(r) for r in regions])
    logger.info('Got regions:\n{}'.format(reg_str))

    labels_counter = Counter()

    no_data = False
    with DataStore(args.output, 'w') as ds:
        # write feature options to file
        logger.info("Writing meta data to file.")
        with DataStore(args.model) as model:
            meta = {
                k: model.meta[k]
                for k in ('medaka_features_kwargs', 'medaka_feature_decoding')
            }
        ds.update_meta(meta)
        # TODO: this parallelism would be better in `SampleGenerator.bams_to_training_samples`
        #       since training alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.threads) as executor:
            # break up overly long chunks
            MAX_SIZE = int(1e6)
            regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
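            # one worker job is submitted per (sub-)region; completed results
            # are consumed below as they finish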
            futures = [
                executor.submit(_labelled_samples_worker, args, reg)
                for reg in regions
            ]
            for fut in concurrent.futures.as_completed(futures):
                if fut.exception() is None:
                    samples, region, fencoder_args, fencoder_decoder = fut.result()
                    logger.info("Writing {} samples for region {}".format(
                        len(samples), region))
                    for sample in samples:
                        ds.write_sample(sample)
                else:
                    logger.info(fut.exception())
                fut._result = None  # release the stored result to free memory (see Python issue 27144)
        no_data = ds.n_samples == 0

    if no_data:
        logger.critical(
            "No training data was written to file, deleting output.")
        os.remove(args.output)
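
A minimal invocation sketch for the function above, assuming the medaka package is importable. The attribute names mirror those read from args in the function body; the file paths and thread count are placeholders, and _labelled_samples_worker may read further attributes that are not shown here.

import argparse

# Hypothetical values for illustration only; the worker may require
# additional attributes beyond those accessed directly in the function above.
args = argparse.Namespace(
    bam='calls_to_draft.bam',        # alignments used to build features
    regions=['contig1:0-1000000'],   # region strings parsed by get_regions
    output='training_features.hdf',  # DataStore written by this run
    model='trained_model.hdf',       # model whose feature meta data is copied
    threads=4,                       # ProcessPoolExecutor worker count
)
create_labelled_samples(args)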
Example 2
def predict(args):
    """Inference program."""
    logger_level = logging.getLogger(__package__).level
    if logger_level > logging.DEBUG:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(
        str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w', verify_on_close=False) as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(
        K.tf.Session(config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(
        len(regions)))
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than chunk_len
    remainder_regions = run_prediction(args.output,
                                       args.bam,
                                       regions,
                                       model,
                                       args.model,
                                       args.rle_ref,
                                       args.read_fraction,
                                       args.chunk_len,
                                       args.chunk_ovlp,
                                       batch_size=args.batch_size,
                                       save_features=args.save_features,
                                       tag_name=args.tag_name,
                                       tag_value=args.tag_value,
                                       tag_keep_missing=args.tag_keep_missing)

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(
            len(remainder_regions)))
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            new_remainders = run_prediction(
                args.output,
                args.bam,
                [region[0]],
                model,
                args.model,
                args.rle_ref,
                args.read_fraction,
                args.chunk_len,
                args.chunk_ovlp,  # these won't be used
                batch_size=args.batch_size,
                save_features=args.save_features,
                tag_name=args.tag_name,
                tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing,
                enable_chunking=False)
            if len(new_remainders) > 0:
                # shouldn't get here
                ignored = [x[0] for x in new_remainders]
                n_ignored = len(ignored)
                logger.warning("{} regions were not processed: {}.".format(
                    n_ignored, ignored))

    logger.info("Finished processing all regions.")

    if args.check_output:
        logger.info("Validating and finalising output data.")
        with DataStore(args.output, 'a') as ds:
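            # reopening the file in append mode and closing it again presumably
            # runs the DataStore's verify-on-close checks, which were skipped
            # earlier via verify_on_close=False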
            pass
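
As with Example 1, a hedged sketch of how predict might be driven programmatically; every attribute listed is one the function reads, but the values are illustrative placeholders rather than the defaults of the real command line.

import argparse

# Placeholder values only; the real CLI presumably supplies sensible defaults.
args = argparse.Namespace(
    bam='calls_to_draft.bam', model='trained_model.hdf',
    output='consensus_probs.hdf', regions=['contig1'],
    threads=2, chunk_len=10000, chunk_ovlp=1000, batch_size=100,
    rle_ref=None, read_fraction=None, save_features=False,
    tag_name=None, tag_value=None, tag_keep_missing=False,
    check_output=True,
)
predict(args)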
Example 3
def predict(args):
    """Inference program."""
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
    from keras.models import load_model
    from keras import backend as K

    args.regions = get_regions(args.bam, region_strs=args.regions)
    logger = get_named_logger('Predict')
    logger.info('Processing region(s): {}'.format(' '.join(
        str(r) for r in args.regions)))

    # write class names to output
    with DataStore(args.model) as ds:
        meta = ds.meta
    with DataStore(args.output, 'w') as ds:
        ds.update_meta(meta)

    logger.info("Setting tensorflow threads to {}.".format(args.threads))
    K.tf.logging.set_verbosity(K.tf.logging.ERROR)
    K.set_session(K.tf.Session(
        config=K.tf.ConfigProto(
            intra_op_parallelism_threads=args.threads,
            inter_op_parallelism_threads=args.threads)
    ))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices, then segregate those which cannot be
    #   batched.
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    long_regions = []
    short_regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        for r in regs:
            if r.size > args.chunk_len:
                long_regions.append(r)
            else:
                short_regions.append(r)
    logger.info("Found {} long and {} short regions.".format(
        len(long_regions), len(short_regions)))

    if len(long_regions) > 0:
        logger.info("Processing long regions.")
        model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
        run_prediction(
            args.output, args.bam, long_regions, model, args.model,
            args.rle_ref, args.read_fraction,
            args.chunk_len, args.chunk_ovlp,
            batch_size=args.batch_size, save_features=args.save_features,
            tag_name=args.tag_name, tag_value=args.tag_value,
            tag_keep_missing=args.tag_keep_missing
        )

    # short regions must be done individually since they have differing lengths
    #   TODO: consider masking (it appears slow to apply wholesale), maybe
    #         step down args.chunk_len by a constant factor until nothing remains.
    if len(short_regions) > 0:
        logger.info("Processing short regions")
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in short_regions:
            chunk_len = region.size // 2
            chunk_ovlp = chunk_len // 10  # still need overlap as features will be longer
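            # e.g. a region spanning 800 pileup columns gives chunk_len=400 and chunk_ovlp=40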
            run_prediction(
                args.output, args.bam, [region], model, args.model,
                args.rle_ref, args.read_fraction,
                chunk_len, chunk_ovlp,
                batch_size=args.batch_size, save_features=args.save_features,
                tag_name=args.tag_name, tag_value=args.tag_value,
                tag_keep_missing=args.tag_keep_missing
            )
    logger.info("Finished processing all regions.")