Exemple #1
0
def main():
    """Run 1D basecalling from the command line.

    Parses the command-line arguments, builds one worker function per CPU
    job and/or OpenCL platform, maps the workers over the input fast5
    files, writes every basecall through ``FastaWrite`` and finally
    reports read/base/event counts and accumulated timings on stderr.

    Exits with status 0 after listing OpenCL platforms when requested, and
    with status 1 when no section can be determined or when no worker is
    configured.

    Raises:
        ImportError: if OpenCL platforms were requested but ``cl`` is None.
    """
    # With no arguments at all, behave as if help was requested.
    if len(sys.argv) == 1:
        sys.argv.append("-h")
    args = get_parser().parse_args()

    if args.list_platforms:
        list_opencl_platforms()
        sys.exit(0)

    modelfile = os.path.abspath(args.model)
    if args.section is None:
        # NOTE(review): ``.item()`` suggests the model file holds a pickled
        # object -- only load model files from trusted sources.
        try:
            args.section = np.load(modelfile).item().meta['section']
        except Exception:
            # Deliberately broad (any load/lookup failure means "no
            # section"), but no longer swallows SystemExit or
            # KeyboardInterrupt as a bare ``except:`` did.
            sys.stderr.write(
                "No 'section' found in modelfile, try specifying --section.\n")
            sys.exit(1)

    #TODO: handle case where there are pre-existing files.
    if args.watch is not None:
        # An optional component
        from nanonet.watcher import Fast5Watcher
        initial_jobs = iterate_fast5(args.input, paths=True)
        fast5_files = Fast5Watcher(args.input,
                                   timeout=args.watch,
                                   initial_jobs=initial_jobs)
    else:
        sort_by_size = 'desc' if args.platforms is not None else None
        fast5_files = iterate_fast5(args.input,
                                    paths=True,
                                    strand_list=args.strand_list,
                                    limit=args.limit,
                                    sort_by_size=sort_by_size)

    # Arguments shared by every worker function.
    fix_args = [modelfile]
    fix_kwargs = {
        a: getattr(args, a)
        for a in ('min_len', 'max_len', 'section', 'event_detect',
                  'fast_decode', 'write_events', 'ed_params', 'sloika_model')
    }

    # Define worker functions as (callable, files-per-call) pairs.
    workers = []
    if not args.exc_opencl:
        cpu_function = partial(process_read, *fix_args, **fix_kwargs)
        workers.extend([(cpu_function, None)] * args.jobs)
    if args.platforms is not None:
        if cl is None:
            raise ImportError('pyopencl is not installed, install with pip.')
        for platform in args.platforms:
            vendor, device_id, n_files = platform.split(':')
            pa = ProcessAttr(use_opencl=True,
                             vendor=vendor,
                             device_id=int(device_id))
            fargs = fix_args + [pa]
            opencl_function = partial(process_read_opencl, *fargs,
                                      **fix_kwargs)
            workers.append((opencl_function, int(n_files)))

    if not workers:
        # Previously this surfaced as an opaque IndexError on workers[0].
        sys.stderr.write(
            "No workers configured (CPU workers excluded or zero jobs, "
            "and no OpenCL platforms given).\n")
        sys.exit(1)

    # Select how to spread load
    if args.platforms is None:
        # just CPU
        worker, n_files = workers[0]
        mapper = tang_imap(worker,
                           fast5_files,
                           threads=args.jobs,
                           unordered=True)
    elif len(workers) == 1:
        # single opencl device
        #    need to wrap files in lists, and unwrap results
        worker, n_files = workers[0]
        fast5_files = group_by_list(fast5_files, [n_files])
        # A generator expression is as lazy as itertools.imap but also
        # exists on Python 3.
        mapper = itertools.chain.from_iterable(
            worker(batch) for batch in fast5_files)
    else:
        # Heterogeneous compute
        mapper = JobQueue(fast5_files, workers)

    def rate(count, seconds):
        # Thousands of items per second; 0.0 when no time was recorded so
        # the summary cannot fail with ZeroDivisionError.
        return count / 1000.0 / seconds if seconds > 0 else 0.0

    # Off we go
    n_reads = 0
    n_bases = 0
    n_events = 0
    timings = [0.0, 0.0]  # accumulated (network, decoding) times
    t0 = now()
    with FastaWrite(args.output, args.fastq) as fasta:
        for result in mapper:
            if result is None:
                continue
            data, read_timings = result
            fname, call_data, _, n_ev = data
            name, _ = short_names(fname)
            basecall, quality = call_data
            if args.fastq:
                fasta.write(name, basecall, quality)
            else:
                fasta.write(name, basecall)
            n_reads += 1
            n_bases += len(basecall)
            n_events += n_ev
            timings = [x + y for x, y in zip(timings, read_timings)]
    t1 = now()
    sys.stderr.write(
        'Basecalled {} reads ({} bases, {} events) in {}s (wall time)\n'.
        format(n_reads, n_bases, n_events, t1 - t0))
    if n_reads > 0:
        network, decoding = timings
        sys.stderr.write(
            'Run network: {:6.2f} ({:6.3f} kb/s, {:6.3f} kev/s)\n'
            'Decoding:    {:6.2f} ({:6.3f} kb/s, {:6.3f} kev/s)\n'.format(
                network,
                rate(n_bases, network),
                rate(n_events, network),
                decoding,
                rate(n_bases, decoding),
                rate(n_events, decoding),
            ))
Exemple #2
0
def main():
    """Run 2D basecalling from the command line.

    Parses the command-line arguments, maps ``process_read_2d`` over the
    input fast5 files with ``tang_imap`` and writes template, complement
    and 2D basecalls to three ``FastaWrite`` outputs. Output names are
    derived from ``args.output_prefix`` when it is given; otherwise '-' is
    used for all three. A timing summary is written to stderr at the end.
    """
    # With no arguments at all, behave as if help was requested.
    if len(sys.argv) == 1:
        sys.argv.append("-h")
    args = get_parser().parse_args()

    logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s', datefmt='%H:%M:%S', level=logging.DEBUG)
    if not args.debug:
        # logging.disable() expects a level, not a logger name. Passing
        # 'root' only silenced logging on Python 2 through int/str
        # ordering and raises ValueError on current Python 3.
        logging.disable(logging.CRITICAL)
    logging.info('Starting 2D basecalling.')

    modelfiles = {
        'template': os.path.abspath(args.template_model),
        'complement': os.path.abspath(args.complement_model)
    }

    #TODO: handle case where there are pre-existing files.
    if args.watch is not None:
        # An optional component
        from nanonet.watcher import Fast5Watcher
        fast5_files = Fast5Watcher(args.input, timeout=args.watch)
    else:
        fast5_files = iterate_fast5(
            args.input, paths=True, strand_list=args.strand_list,
            limit=args.limit, sort_by_size=None
        )

    # Arguments shared by every call of the worker function.
    fix_args = [
        modelfiles
    ]
    fix_kwargs = {a: getattr(args, a) for a in (
        'min_len', 'max_len', 'section',
        'event_detect', 'fast_decode',
        'write_events', 'opencl_2d', 'ed_params',
        'sloika_model'
    )}

    # Define worker functions
    mapper = tang_imap(
        process_read_2d, fast5_files,
        fix_args=fix_args, fix_kwargs=fix_kwargs,
        threads=args.jobs, unordered=True
    )

    def rate(count, seconds):
        # Thousands of items per second; 0.0 when no time was recorded so
        # the summary cannot fail with ZeroDivisionError.
        return count / 1000.0 / seconds if seconds > 0 else 0.0

    # Off we go
    n_reads = 0
    n_bases = 0
    n_events = 0
    n_bases_2d = 0
    timings = [0.0, 0.0, 0.0]  # accumulated (network, decoding, 2D) times
    t0 = now()
    sections = ('template', 'complement', '2d')
    if args.output_prefix is not None:
        ext = 'fastq' if args.fastq else 'fasta'
        filenames = ['{}_{}.{}'.format(args.output_prefix, x, ext) for x in sections]
    else:
        filenames = ['-'] * 3

    with FastaWrite(filenames[0], args.fastq) as fasta_temp, \
            FastaWrite(filenames[1], args.fastq) as fasta_comp, \
            FastaWrite(filenames[2], args.fastq) as fasta_2d:
        for result in mapper:
            # Without a template call nothing is written for this read.
            if result['template'] is None:
                continue
            data, read_timings = result['template']
            fname, basecall, _, n_ev = data
            basecall, quality = basecall
            name, _ = short_names(fname)
            if args.fastq:
                fasta_temp.write(name, basecall, quality)
            else:
                fasta_temp.write(name, basecall)
            n_reads += 1
            n_bases += len(basecall)
            n_events += n_ev
            # Pad the per-read timings with a zero for the 2D slot.
            timings = [x + y for x, y in zip(timings, read_timings + (0.0,))]

            # A 2D call is only considered when the complement exists.
            if result['complement'] is None:
                continue
            data, _ = result['complement']
            _, basecall, _, _ = data
            basecall, quality = basecall
            if args.fastq:
                fasta_comp.write(name, basecall, quality)
            else:
                fasta_comp.write(name, basecall)

            if result['2d'] is None:
                continue
            basecall, time_2d = result['2d']
            basecall, quality = basecall
            if args.fastq:
                fasta_2d.write(name, basecall, quality)
            else:
                fasta_2d.write(name, basecall)
            n_bases_2d += len(basecall)
            timings[2] += time_2d
    t1 = now()

    sys.stderr.write('Processed {} reads in {}s (wall time)\n'.format(n_reads, t1 - t0))
    if n_reads > 0:
        network, decoding, call_2d = timings
        sys.stderr.write(
            '1D Run network: {:6.2f} ({:6.3f} kb/s, {:6.3f} kev/s)\n'
            '1D Decoding:    {:6.2f} ({:6.3f} kb/s, {:6.3f} kev/s)\n'
            '2D calling:     {:6.2f} ({:6.3f} kb/s)\n'
            .format(
                network, rate(n_bases, network), rate(n_events, network),
                decoding, rate(n_bases, decoding), rate(n_events, decoding),
                call_2d, rate(n_bases_2d, call_2d)
            )
        )
Exemple #3
0
def main():
    """Train a network with currennt from the command line.

    Steps, in order:
      1. Prepare the training and validation input files in the workspace.
      2. Fill the ``<section>``, ``<n_features>`` and ``<n_states>``
         placeholders of the model template and write the result to the
         workspace.
      3. Write the currennt configuration file.
      4. Save the model metadata, then run currennt.
      5. Add the metadata to currennt's best network and save a numpy
         version of it.

    Raises:
        RuntimeError: if an input file is prepared with zero chunks.
    """
    # With no arguments at all, behave as if help was requested.
    if len(sys.argv) == 1:
        sys.argv.append("-h")
    args = get_parser().parse_args()

    if not args.cuda:
        # NOTE(review): nseqs is not read anywhere in this function while
        # args.parallel_sequences is what gets written to the config --
        # confirm this override still has an effect.
        args.nseqs = 1

    if not os.path.exists(args.workspace):
        os.makedirs(args.workspace)

    # file names for training
    tag = random_string()
    modelfile = os.path.abspath(args.model)
    outputfile = os.path.abspath(args.output)
    temp_name = os.path.abspath(os.path.join(
        args.workspace, 'nn_data_{}_'.format(tag)
    ))
    config_name = os.path.abspath(os.path.join(
        args.workspace, 'nn_{}.cfg'.format(tag)
    ))

    # Create currennt training input files
    trainfile = '{}{}'.format(temp_name, 'train.netcdf')
    valfile = '{}{}'.format(temp_name, 'validation.netcdf')
    inputs = (
        (args.train, args.train_list, trainfile),
        (args.val, args.val_list, valfile),
    )
    fix_kwargs = {
        'window': args.window,
        'kmer_len': args.kmer_length,
        'alphabet': args.bases,
        'callback_kwargs': {'section': args.section, 'kmer_len': args.kmer_length}
    }
    # n_features and out_kmers of the last result yielded are used below.
    for results in tang_imap(prepare_input_file, inputs, fix_kwargs=fix_kwargs, threads=2):
        n_chunks, n_features, out_kmers = results
        if n_chunks == 0:
            raise RuntimeError("No training data written.")

    # fill-in templated items in model
    n_states = len(out_kmers)
    with open(modelfile, 'r') as model:
        mod = model.read()
    mod = mod.replace('<section>', args.section)
    mod = mod.replace('<n_features>', str(n_features))
    mod = mod.replace('<n_states>', str(n_states))
    try:
        mod_meta = json.loads(mod)['meta']
    except (ValueError, KeyError, TypeError):
        # Best effort: start from empty metadata when the template is not
        # valid JSON or carries no usable 'meta' entry.
        mod_meta = dict()
    mod_meta['n_features'] = n_features
    mod_meta['kmers'] = out_kmers
    mod_meta['window'] = args.window

    modelfile = os.path.abspath(os.path.join(
        args.workspace, 'input_model.jsn'
    ))
    with open(modelfile, 'w') as model:
        model.write(mod)
    final_network = "{}_final.jsn".format(outputfile)
    best_network_prefix = "{}_auto".format(outputfile)
    # currennt appends some bits here

    # currennt cfg entries, written in this order
    settings = [
        # IO
        ("cache_path", args.cache_path),
        ("network", modelfile),
        ("train_file", trainfile),
        ("val_file", valfile),
        ("save_network", final_network),
        ("autosave_prefix", best_network_prefix),
        # Tunable parameters
        ("max_epochs", args.max_epochs),
        ("max_epochs_no_best", args.max_epochs_no_best),
        ("validate_every", args.validate_every),
        ("parallel_sequences", args.parallel_sequences),
        ("learning_rate", args.learning_rate),
        ("momentum", args.momentum),
        # Fixed parameters
        ("train", "true"),
        ("weights_dist", "normal"),
        ("weights_normal_sigma", "0.1"),
        ("weights_normal_mean", "0"),
        ("stochastic", "true"),
        ("input_noise_sigma", "0.0"),
        ("shuffle_fractions", "false"),
        ("shuffle_sequences", "true"),
        ("autosave_best", "true"),
    ]
    with open(config_name, 'w') as currennt_cfg:
        if not args.cuda:
            currennt_cfg.write(conf_line('cuda', 'false'))
        for key, value in settings:
            currennt_cfg.write(conf_line(key, value))

    # autosave metadata in case currennt crashes or is terminated
    meta_name = "{}.meta.jsn".format(best_network_prefix)
    print("Saving metadata as {}".format(meta_name))
    # Close the file explicitly so the metadata is on disk before
    # currennt starts.
    with open(meta_name, 'w') as meta_file:
        json.dump(mod_meta, meta_file)

    # run currennt
    print("\n\nRunning currennt with: {}".format(config_name))
    try:
        run_currennt_noisy(config_name, device=args.device)
    except KeyboardInterrupt:
        # in case the user exits currennt, still save the numpy network
        pass

    # Currennt won't pass through our meta in the model, amend the output
    # and write out a numpy version of the network
    best_network = "{}.best.jsn".format(best_network_prefix)
    best_network_numpy = "{}_best.npy".format(outputfile)

    print("Adding model meta to currennt best network: {}".format(best_network))
    with open(best_network, 'r') as network_file:
        mod = json.load(network_file)
    mod['meta'] = mod_meta
    with open(best_network, 'w') as network_file:
        json.dump(mod, network_file)
    print("Transforming network to numpy pickle: {}".format(best_network_numpy))
    mod = network_to_numpy(mod)
    np.save(best_network_numpy, mod)