Example 1
    def test_load_model_from_file_with_wrong_metadata(self, filename):
        metadata = {'reverse': True, 'standardize': False}
        filepath = os.path.join(MODELS_DIR, filename)
        with self.assertWarns(RuntimeWarning):
            net = helpers.load_model(filepath, model_metadata=metadata)
        self.assertEqual(metadata['reverse'], net.metadata['reverse'])
        self.assertEqual(metadata['standardize'], net.metadata['standardize'])
Example 2
def main():
    args = get_parser().parse_args()

    worker_kwarg_names = ['back_prob', 'localpen', 'minscore', 'trim']

    model = helpers.load_model(args.model)

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.read_dir, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    with helpers.open_file_or_stdout(args.output) as fh:
        for res in imap_mp(
                squiggle_match.worker, fast5_reads, threads=args.jobs,
                fix_kwargs=helpers.get_kwargs(args, worker_kwarg_names),
                unordered=True, init=squiggle_match.init_worker,
                initargs=[model, args.references]):
            if res is None:
                continue
            read_id, sig, score, path, squiggle, bases = res
            bases = bases.decode('ascii')
            fh.write('#{} {}\n'.format(read_id, score))
            for i, (s, p) in enumerate(zip(sig, path)):
                fh.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                    read_id, i, s, p, bases[p], squiggle[p, 0], squiggle[p, 1],
                    squiggle[p, 2]))
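The `fix_kwargs` dict above comes from `helpers.get_kwargs`, which gathers the named attributes out of the parsed arguments. A minimal sketch of that behaviour (this one-line implementation is an assumption, not taiyaki's source):

def get_kwargs(args, names):
    # Pick the listed attributes off an argparse Namespace, e.g.
    # get_kwargs(args, ['trim']) == {'trim': args.trim}
    return {name: getattr(args, name) for name in names}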
Example 3
def main(argv=None):
    """Main function to process mapping for each read using functions in prepare_mapping_funcs"""
    args = parser.parse_args(argv)
    print("Running prepare_mapping using flip-flop remapping")

    if not args.overwrite:
        if os.path.exists(args.output):
            print("Cowardly refusing to overwrite {}".format(args.output))
            sys.exit(1)

    # Make an iterator that yields all the reads we're interested in.
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder,
        limit=args.limit,
        strand_list=args.input_strand_list)

    # Set up arguments (kwargs) for the worker function for each read
    kwargs = helpers.get_kwargs(args,
                                ['alphabet', 'collapse_alphabet', 'device'])
    kwargs['per_read_params_dict'] = \
        prepare_mapping_funcs.get_per_read_params_dict_from_tsv(
            args.input_per_read_params)
    kwargs['references'] = helpers.fasta_file_to_dict(args.references)
    kwargs['model'] = helpers.load_model(args.model)
    workerFunction = prepare_mapping_funcs.oneread_remap  # remaps a single read using flip-flop network

    results = imap_mp(workerFunction,
                      fast5_reads,
                      threads=args.jobs,
                      fix_kwargs=kwargs,
                      unordered=True)

    # results is an iterable of dicts
    # each dict is a set of return values from a single read
    prepare_mapping_funcs.generate_output_from_results(results, args)
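For reference, `imap_mp` is used here as a multiprocessing map that streams worker results. A rough serial stand-in with the same call shape (assumed semantics; it ignores `threads`, `unordered` and worker initialisation) would be:

def imap_mp_serial(function, iterable, fix_kwargs=None, **_ignored):
    # Apply the worker to each job with the fixed keyword arguments,
    # yielding results one at a time like imap.
    fix_kwargs = fix_kwargs or {}
    for job in iterable:
        yield function(job, **fix_kwargs)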
Example 4
def main():
    args = parser.parse_args()
    model = load_model(args.model)

    json_out = model.json(args.params)

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
Example 5
def main():
    args = get_parser().parse_args()
    model_md5 = file_md5(args.model)
    model = load_model(args.model)

    json_out = model.json()
    json_out['md5sum'] = model_md5

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
Example 6
def main():
    args = parser.parse_args()
    model = load_model(args.model)
    print("Previous alphabet", model.sublayers[-1].output_alphabet)

    # Report the modified-base attributes of the final layer
    for attr in ["mod_bases", "mod_labels", "mod_name_conv",
                 "ordered_mod_long_names"]:
        print(attr, getattr(model.sublayers[-1], attr))

    model.sublayers[-1].output_alphabet = args.alphabet
    save_model(model, args.output)
Example 7
def main():
    args = parser.parse_args()

    predict_squiggle = helpers.load_model(args.model)

    with helpers.open_file_or_stdout(args.output) as fh:
        for seq in SeqIO.parse(args.input, 'fasta'):
            seqstr = str(seq.seq)
            embedded_seq_numpy = np.expand_dims(
                squiggle_match.embed_sequence(seqstr), axis=1)
            embedded_seq_torch = torch.tensor(embedded_seq_numpy,
                                              dtype=torch.float32)

            with torch.no_grad():
                squiggle = np.squeeze(
                    predict_squiggle(embedded_seq_torch).cpu().numpy(), axis=1)

            fh.write('base\tcurrent\tsd\tdwell\n')
            for base, (mean, logsd, dwell) in zip(seq.seq, squiggle):
                fh.write('{}\t{}\t{}\t{}\n'.format(base, mean, np.exp(logsd),
                                                   np.exp(-dwell)))
Example 8
def worker_init(device, modelname, chunk_size, overlap,
                read_params, alphabet, max_concurrent_chunks,
                fastq, qscore_scale, qscore_offset, beam, posterior,
                temperature):
    global all_read_params
    global process_read_partial

    all_read_params = read_params
    device = helpers.set_torch_device(device)
    model = load_model(modelname).to(device)
    stride = guess_model_stride(model)
    chunk_size = chunk_size * stride
    overlap = overlap * stride

    n_can_base = len(alphabet)
    n_can_state = nstate_flipflop(n_can_base)
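    # For the canonical 4-base alphabet the usual flip-flop parameterisation
    # gives 2 * 4 * (4 + 1) == 40 transition scores per time step (formula
    # assumed from the flip-flop CRF scheme, not from nstate_flipflop's
    # source).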

    def process_read_partial(read_filename, read_id, read_params):
        res = process_read(read_filename, read_id,
                           model, chunk_size, overlap, read_params,
                           n_can_state, stride, alphabet,
                           max_concurrent_chunks, fastq, qscore_scale,
                           qscore_offset, beam, posterior, temperature)
        return (read_id, *res)
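`worker_init` installs `model` and `process_read_partial` as process-level globals, the standard trick for giving multiprocessing workers expensive shared state. A sketch of the same pattern with the standard library (assuming `imap_mp` forwards `init`/`initargs` to the pool in the usual way; `jobs` and the `initargs` tuple are placeholders mirroring the signature above):

import multiprocessing

def call_partial(job):
    # process_read_partial exists in this process because worker_init ran
    # as the pool initializer.
    return process_read_partial(*job)

def run_pool(jobs, initargs, nproc=4):
    # initargs mirrors worker_init's signature; jobs yields
    # (read_filename, read_id, read_params) tuples.
    with multiprocessing.Pool(processes=nproc, initializer=worker_init,
                              initargs=initargs) as pool:
        yield from pool.imap_unordered(call_partial, jobs)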
Example 9
def main():
    """Main function to process mapping for each read using functions in prepare_mapping_funcs"""
    args = parser.parse_args()
    print("Running prepare_mapping using flip-flop remapping")

    if not args.overwrite:
        if os.path.exists(args.output):
            print("Cowardly refusing to overwrite {}".format(args.output))
            sys.exit(1)

    # Create alphabet and check for consistency
    modified_bases = [elt[0] for elt in args.mod]
    canonical_bases = [elt[1] for elt in args.mod]
    for b in modified_bases:
        assert len(b) == 1, \
            "Modified bases must be a single character, got {}".format(b)
        assert b not in args.alphabet, \
            "Modified base must not be a canonical base, got {}".format(b)
    for b in canonical_bases:
        assert len(b) == 1, \
            "Canonical coding for modified bases must be a single " \
            "character, got {}".format(b)
        assert b in args.alphabet, \
            "Canonical coding for modified base must be a canonical " \
            "base, got {}".format(b)
    full_alphabet = args.alphabet + ''.join(modified_bases)
    flat_alphabet = args.alphabet + ''.join(canonical_bases)
    modification_names = [elt[2] for elt in args.mod]

    alphabet_info = alphabet.AlphabetInfo(full_alphabet,
                                          flat_alphabet,
                                          modification_names,
                                          do_reorder=True)

    print("Converting references to labels using {}".format(
        str(alphabet_info)))

    # Make an iterator that yields all the reads we're interested in.
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder,
        limit=args.limit,
        strand_list=args.input_strand_list,
        recursive=args.recursive)

    # Set up arguments (kwargs) for the worker function for each read
    kwargs = {}
    kwargs['per_read_params_dict'] = \
        prepare_mapping_funcs.get_per_read_params_dict_from_tsv(
            args.input_per_read_params)
    kwargs['model'] = helpers.load_model(args.model)
    kwargs['alphabet_info'] = alphabet_info
    kwargs['max_read_length'] = args.max_read_length
    kwargs['localpen'] = args.localpen

    # remaps a single read using flip-flop network
    workerFunction = prepare_mapping_funcs.oneread_remap

    def iter_jobs():
        references = bio.fasta_file_to_dict(args.references,
                                            alphabet=full_alphabet)
        for fn, read_id in fast5_reads:
            yield fn, read_id, references.get(read_id, None)

    if args.limit is not None:
        # Aim for roughly two chunks of work per job, but keep the
        # multiprocessing chunksize between 1 and 50
        chunksize = args.limit // (2 * args.jobs)
        chunksize = int(np.clip(chunksize, 1, 50))
    else:
        chunksize = 50

    results = imap_mp(workerFunction,
                      iter_jobs(),
                      threads=args.jobs,
                      fix_kwargs=kwargs,
                      unordered=True,
                      chunksize=chunksize)

    # results is an iterable of dicts
    # each dict is a set of return values from a single read
    prepare_mapping_funcs.generate_output_from_results(results, args.output,
                                                       alphabet_info)
Example 10
    def test_load_model_from_file_with_metadata(self, filename):
        metadata = {'reverse': False, 'standardize': False}
        filepath = os.path.join(MODELS_DIR, filename)
        net = helpers.load_model(filepath, model_metadata=metadata)
        self.assertEqual(metadata['reverse'], net.metadata['reverse'])
        self.assertEqual(metadata['standardize'], net.metadata['standardize'])
Example 11
def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.\n')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(('Error: Output directory {} exists but ' +
                          '--overwrite is false\n').format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        sampling_chunk_len,
        args,
        log,
        chunk_log=chunk_log)

    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
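        # e.g. stride=5 and a draw of 3072 gives (3072 // 5) * 5 == 3070,
        # i.e. always a whole number of model output blocks (illustrative
        # numbers).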
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))
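        # e.g. with min_batch_size=64, chunk_len_max=4000 and a sampled
        # batch_chunk_len of 2000, target_batch_size is 128: halving the
        # chunk length doubles the chunk count, so samples per batch stay
        # roughly constant (illustrative numbers).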

        # If the logging threshold is 0 then we log all chunks, including those
        # rejected, so pass the log
        # object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # Chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
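        # flipflop_code (used below) is assumed to map each base onto
        # 2*nbase labels, alternating repeated bases between 'flip'
        # (0..nbase-1) and 'flop' (nbase..2*nbase-1) states so runs of the
        # same base remain distinguishable.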
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]),
                            device=device,
                            dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long,
                               device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we
        # log everything
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (
                ' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
Example 12
def main():
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        sys.exit()
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(args.input_folder,
                                       limit=args.limit,
                                       strand_list=args.input_strand_list,
                                       recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write("* Loading read scaling parameters from {}.\n".format(
            args.scaling))
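        # Assumed TSV layout: one row per read mapping read_id to scaling
        # parameters (in taiyaki, columns such as UUID, trim_start,
        # trim_end, shift and scale).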
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids
        ]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset('mod_long_names',
                               data=np.array(mod_long_names, dtype='S'),
                               dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                read_params = all_read_params.get(read_id)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
Example 13
        pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
        with open(pickle_name, 'wb') as fh:
            pickle.dump(seq_dict, fh)
        log.write('* Written pickle of processed references to {} for future use.\n'.format(pickle_name))


    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(args.alphabet)
    model_kwargs = {
        'size' : args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        'insize': 1,  # Number of input features to model e.g. was >1 for event-based models (level, std, dwell)
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(sum([p.nelement()
                                                           for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam, eps=args.eps)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
Example 14
             '{}W'.format(name),
             filterW.permute(0, 2, 1).reshape(-1, nf),
             nr=nf2 * winlen - nf2 + nf,
             nc=nfilter)
    cformatV(sys.stdout, '{}b'.format(name), convwrapper.conv.bias.reshape(-1))
    sys.stdout.write("#define {}stride  {}\n".format(name, convwrapper.stride))
    sys.stdout.write("""#define {}nfilter  {}
    #define {}winlen  {}
    """.format(name, nfilter, name, winlen))


if __name__ == '__main__':
    args = parser.parse_args()
    modelid = args.id + '_'

    network = helpers.load_model(args.model)

    if isinstance(network.sublayers[0], DeltaSample):
        sys.stderr.write('* Removing initial DeltaSample layer\n')
        network.sublayers = network.sublayers[1:]

    sys.stdout.write("""#pragma once
    #ifndef FLIPFLOP_{}MODEL_H
    #define FLIPFLOP_{}MODEL_H
    #include "../util.h"
    """.format(modelid.upper(), modelid.upper()))
    """ Convolution layers
    """
    conv1 = network.sublayers[0]
    print_convolution(conv1,
                      'conv1_rnnrf_flipflop5_{}'.format(modelid),
Example 15
def load_network(args, alphabet_info, res_info, log):
    log.write('* Reading network from {}\n'.format(args.model))
    if res_info.is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we
        # need a clone of the start network to use as a template for saving
        # checkpoints. Necessary because DistributedParallel makes the class
        # structure different.
        model_kwargs = {
            'stride': args.stride,
            'winlen': args.winlen,
            'insize': 1,
            'size': args.size,
            'alphabet_info': alphabet_info
        }
        model_metadata = {
            'reverse': args.reverse,
            'standardize': args.standardize
        }
        net_clone = helpers.load_model(args.model,
                                       model_metadata=model_metadata,
                                       **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum(p.nelement() for p in net_clone.parameters())))

        if not alphabet_info.is_compatible_model(net_clone):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain ' +
                'incompatible alphabet definitions (including modified ' +
                'bases).')
            sys.exit(1)
        if layers.is_cat_mod_model(net_clone):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.')
                sys.exit(1)
        if layers.is_delta_model(net_clone) and model_metadata['standardize']:
            log.write('*' * 60 + '\n* WARNING: Delta-scaling models trained ' +
                      'with --standardize are not compatible with Guppy.\n' +
                      '*' * 60)
        log.write('* Dumping initial model\n')
        helpers.save_model(net_clone, args.outdir, 0)
    else:
        net_clone = None

    if res_info.is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has
        # saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(res_info.device)
        network_metadata = parse_network_metadata(network)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        log.write('* Loading model onto device\n')
        network = net_clone.to(res_info.device)
        network_metadata = parse_network_metadata(network)
        net_clone = None

    log.write('* Estimating filter parameters from training data\n')
    stride = guess_model_stride(network)
    optimiser = torch.optim.AdamW(network.parameters(),
                                  lr=args.lr_max,
                                  betas=args.adam,
                                  weight_decay=args.weight_decay,
                                  eps=args.eps)

    lr_warmup = args.lr_min if args.lr_warmup is None else args.lr_warmup
    adam_beta1, _ = args.adam
    if args.warmup_batches >= args.niteration:
        sys.stderr.write('* Error: --warmup_batches must be < --niteration\n')
        sys.exit(1)
    warmup_fraction = args.warmup_batches / args.niteration
    # Pytorch OneCycleLR crashes if pct_start==1 (i.e. warmup_fraction==1)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimiser,
        args.lr_max,
        total_steps=args.niteration,
        # pct_start is really fractional, not percent
        pct_start=warmup_fraction,
        div_factor=args.lr_max / lr_warmup,
        final_div_factor=lr_warmup / args.lr_min,
        cycle_momentum=(args.min_momentum is not None),
        base_momentum=adam_beta1 if args.min_momentum is None \
        else args.min_momentum,
        max_momentum=adam_beta1
    )
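    # With OneCycleLR the schedule starts at lr_max / div_factor == lr_warmup
    # and finishes at that initial rate divided by final_div_factor,
    # i.e. lr_warmup / (lr_warmup / lr_min) == lr_min.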
    log.write(
        ('* Learning rate increases from {:.2e} to {:.2e} over {} ' +
         'iterations using cosine schedule.\n').format(lr_warmup, args.lr_max,
                                                       args.warmup_batches))
    log.write(('* Then learning rate decreases from {:.2e} to {:.2e} over ' +
               '{} iterations using cosine schedule.\n').format(
                   args.lr_max, args.lr_min,
                   args.niteration - args.warmup_batches))

    if args.gradient_clip_num_mads is None:
        log.write('* No gradient clipping\n')
        rolling_mads = None
    else:
        nparams = len([p for p in network.parameters() if p.requires_grad])
        if nparams == 0:
            rolling_mads = None
            log.write('* No gradient clipping due to missing parameters\n')
        else:
            rolling_mads = maths.RollingMAD(nparams,
                                            args.gradient_clip_num_mads)
            log.write((
                '* Gradients will be clipped (by value) at {:3.2f} MADs ' +
                'above the median of the last {} gradient maximums.\n').format(
                    rolling_mads.n_mads, rolling_mads.window))

    net_info = NETWORK_INFO(net=network,
                            net_clone=net_clone,
                            metadata=network_metadata,
                            stride=stride)
    optim_info = OPTIM_INFO(optimiser=optimiser,
                            lr_warmup=lr_warmup,
                            lr_scheduler=lr_scheduler,
                            rolling_mads=rolling_mads)

    return net_info, optim_info
Example 16
def main():
    args = parser.parse_args()
    is_multi_gpu = (args.local_rank is not None)
    is_lead_process = (not is_multi_gpu) or args.local_rank == 0

    if is_multi_gpu:
        # Use distributed parallel processing to run one process per GPU
        try:
            torch.distributed.init_process_group(backend='nccl')
        except Exception:
            raise Exception(
                "Unable to start multiprocessing group. " +
                "The most likely reason is that the script is running with " +
                "local_rank set but without the set-up for distributed " +
                "operation. local_rank should be used " +
                "only by torch.distributed.launch. See the README.")
        device = helpers.set_torch_device(args.local_rank)
        if args.seed is not None:
            # Make sure processes get different random picks of training data
            np.random.seed(args.seed + args.local_rank)
    else:
        device = helpers.set_torch_device(args.device)
        np.random.seed(args.seed)

    if is_lead_process:
        helpers.prepare_outdir(args.outdir, args.overwrite)
        if args.model.endswith('.py'):
            copyfile(args.model, os.path.join(args.outdir, 'model.py'))
        batchlog = helpers.BatchLog(args.outdir)
        logfile = os.path.join(args.outdir, 'model.log')
    else:
        logfile = None

    log = helpers.Logger(logfile, args.quiet)
    log.write(helpers.formatted_env_info(device))

    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)
    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_params = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write("* Sampled {} chunks".format(
        args.sample_nreads_before_filtering))
    log.write(": median(mean_dwell)={:.2f}".format(
        filter_params.median_meandwell))
    log.write(", mad(mean_dwell)={:.2f}\n".format(filter_params.mad_meandwell))
    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'alphabet_info': alphabet_info
    }

    if is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we
        # need a clone of the start network to use as a template for saving
        # checkpoints. Necessary because DistributedParallel makes the class
        # structure different.
        network_save_skeleton = helpers.load_model(args.model, **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum([p.nelement() for p in network_save_skeleton.parameters()])))
        if not alphabet_info.is_compatible_model(network_save_skeleton):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain incompatible '
                + 'alphabet definitions (including modified bases).')
            sys.exit(1)
        if is_cat_mod_model(network_save_skeleton):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.')
                sys.exit(1)
        log.write('* Dumping initial model\n')
        helpers.save_model(network_save_skeleton, args.outdir, 0)

    if is_multi_gpu:
        # So that processes 1,2,3... don't try to load before process 0 has saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(device)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        network = network_save_skeleton.to(device)
        network_save_skeleton = None

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    if args.lr_warmup is None:
        lr_warmup = args.lr_min
    else:
        lr_warmup = args.lr_warmup

    if args.lr_frac_decay is not None:
        lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_frac_decay,
                                          args.warmup_batches, lr_warmup)
        log.write('* Learning rate schedule lr_max*k/(k+t)')
        log.write(', k={}, t=iterations.\n'.format(args.lr_frac_decay))
    else:
        lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                    args.lr_cosine_iters,
                                                    args.warmup_batches,
                                                    lr_warmup)
        log.write('* Learning rate goes like cosine from lr_max to lr_min ')
        log.write('over {} iterations.\n'.format(args.lr_cosine_iters))
    log.write('* At start, train for {} '.format(args.warmup_batches))
    log.write('batches at warm-up learning rate {:3.2}\n'.format(lr_warmup))

    score_smoothed = helpers.WindowedExpSmoother()

    # Prepare modified base parameter tensors
    network_is_catmod = is_cat_mod_model(network)
    mod_factor_t = torch.tensor(args.mod_factor,
                                dtype=torch.float32).to(device)
    can_mods_offsets = (network.sublayers[-1].can_mods_offsets
                        if network_is_catmod else None)
    # mod cat inv freq weighting is currently disabled. Compute and set this
    # value to enable mod cat weighting
    mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    # Generate a list of batches for standard loss reporting
    reporting_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    reporting_batch_list = list(
        prepare_random_batches(device, read_data, reporting_chunk_len,
                               args.min_sub_batch_size,
                               args.reporting_sub_batches, alphabet_info,
                               filter_params, network, network_is_catmod, log))

    log.write(
        ('* Standard loss report: chunk length = {} & sub-batch size ' +
         '= {} for {} sub-batches.\n').format(reporting_chunk_len,
                                              args.min_sub_batch_size,
                                              args.reporting_sub_batches))

    # Set cap at a very large value (before we have any gradient stats).
    gradient_cap = constants.LARGE_VAL
    if args.gradient_cap_fraction is None:
        log.write('* No gradient capping\n')
    else:
        rolling_quantile = maths.RollingQuantile(args.gradient_cap_fraction)
        log.write('* Gradient L2 norm will be capped at the upper' +
                  ' {:3.2f} quantile of the last {} norms.\n'.format(
                      args.gradient_cap_fraction, rolling_quantile.window))

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride

        # We choose the size of a sub-batch so that the size of the data in
        # the sub-batch is about the same as args.min_sub_batch_size chunks of
        # length args.chunk_len_max
        sub_batch_size = int(args.min_sub_batch_size * args.chunk_len_max /
                             batch_chunk_len + 0.5)
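        # e.g. with min_sub_batch_size=96, chunk_len_max=4000 and a sampled
        # batch_chunk_len of 3000 this gives sub_batch_size=128: shorter
        # chunks mean proportionally more of them (illustrative numbers).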

        optimizer.zero_grad()

        main_batch_gen = prepare_random_batches(
            device, read_data, batch_chunk_len, sub_batch_size,
            args.sub_batches, alphabet_info, filter_params, network,
            network_is_catmod, log)

        chunk_count, fval, chunk_samples, chunk_bases, batch_rejections = \
            calculate_loss(network, network_is_catmod, main_batch_gen,
                           args.sharpen, can_mods_offsets, mod_cat_weights,
                           mod_factor_t, calc_grads=True)

        gradnorm_uncapped = torch.nn.utils.clip_grad_norm_(
            network.parameters(), gradient_cap)
        if args.gradient_cap_fraction is not None:
            gradient_cap = rolling_quantile.update(gradnorm_uncapped)

        optimizer.step()
        if is_lead_process:
            batchlog.record(
                fval, gradnorm_uncapped,
                None if args.gradient_cap_fraction is None else gradient_cap)

        total_chunks += chunk_count
        total_samples += chunk_samples
        total_bases += chunk_bases

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        score_smoothed.update(fval)

        if (i + 1) % args.save_every == 0 and is_lead_process:
            helpers.save_model(network, args.outdir,
                               (i + 1) // args.save_every,
                               network_save_skeleton)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:

            _, rloss, _, _, _ = calculate_loss(network, network_is_catmod,
                                               reporting_batch_list,
                                               args.sharpen, can_mods_offsets,
                                               mod_cat_weights, mod_factor_t)

            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:7.5f} {:7.5f}  {:5.2f}s ({:.2f} ksample/s {:.2f} ' +
                 'kbase/s) lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, rloss,
                         dt, total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = n_fail = 0
                for k, v in rejection_dict.items():
                    n_tot += v
                    if k != 'pass':
                        n_fail += v
                log.write("  {:.1%} chunks filtered".format(n_fail / n_tot))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

            # Uncomment the lines below to check synchronisation of models
            # between processes in multi-GPU operation
            #for p in network.parameters():
            #    v = p.data.reshape(-1)[:5].to('cpu')
            #    u = p.data.reshape(-1)[-5:].to('cpu')
            #    break
            #if args.local_rank is not None:
            #    log.write("* GPU{} params:".format(args.local_rank))
            #log.write("{}...{}\n".format(v,u))

        lr_scheduler.step()

    if is_lead_process:
        helpers.save_model(network,
                           args.outdir,
                           model_skeleton=network_save_skeleton)
Example 17
def main():
    args = parser.parse_args()

    assert args.device != 'cpu', \
        "Flipflop basecalling in taiyaki requires a GPU and cupy to be installed"
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1], layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.")
        sys.exit()
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, device, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {}s\n".format(nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
Example 18
    def test_load_model_from_file_with_no_metadata(self, filename):
        helpers.load_model(os.path.join(MODELS_DIR, filename))
Example 19
#!/usr/bin/env python3
import argparse
import json

from taiyaki.cmdargs import AutoBool, FileExists, FileAbsent
from taiyaki.helpers import load_model
from taiyaki.json import JsonEncoder

parser = argparse.ArgumentParser(description='Dump JSON representation of model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--out_file', default=None, action=FileAbsent, help='Output JSON file to this file location')
parser.add_argument('--params', default=True, action=AutoBool, help='Output parameters as well as model structure')

parser.add_argument('model', action=FileExists, help='Model file to read from')


if __name__ == "__main__":
    args = parser.parse_args()
    model = load_model(args.model)

    json_out = model.json(args.params)

    if args.out_file is not None:
        with open(args.out_file, 'w') as f:
            print("Writing to file: ", args.out_file)
            json.dump(json_out, f, indent=4, cls=JsonEncoder)
    else:
        print(json.dumps(json_out, indent=4, cls=JsonEncoder))
Example 20
    log.write('* Reading network from {}\n'.format(args.model))
    alphabet_info = alphabet.AlphabetInfo(args.alphabet, args.alphabet)

    model_kwargs = {
        'size': args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based models
        # (level, std, dwell)
        'insize': 1,
        'alphabet_info': alphabet_info
    }
    model_metadata = {'reverse': False, 'standardize': True}
    network = helpers.load_model(args.model,
                                 model_metadata=model_metadata,
                                 **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.AdamW(network.parameters(),
                                  lr=args.lr_max,
                                  betas=args.adam,
                                  eps=args.eps,
                                  weight_decay=args.weight_decay)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)
Example 21
def main():
    args = parser.parse_args()
    log, loss_log, chunk_log, device = _setup_and_logs(args)

    read_data, alphabet_info = _load_data(args, log)
    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        (args.chunk_len_min + args.chunk_len_max) // 2,
        args,
        log,
        chunk_log=chunk_log)
    log.write(("* Sampled {} chunks: median(mean_dwell)={:.2f}, " +
               "mad(mean_dwell)={:.2f}\n").format(
                   args.sample_nreads_before_filtering, *filter_parameters))

    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'insize': 1,
        'winlen': args.winlen,
        'stride': args.stride,
        'size': args.size,
        'alphabet_info': alphabet_info
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    if not isinstance(network.sublayers[-1], layers.GlobalNormFlipFlopCatMod):
        log.write(
            'ERROR: Model must end with GlobalNormCatModFlipFlop layer, ' +
            'not {}.\n'.format(str(network.sublayers[-1])))
        sys.exit(1)
    can_mods_offsets = network.sublayers[-1].can_mods_offsets
    flipflop_can_labels = network.sublayers[-1].can_labels
    flipflop_mod_labels = network.sublayers[-1].mod_labels
    flipflop_ncan_base = network.sublayers[-1].ncan_base
    log.write('* Loaded categorical modifications flip-flop model.\n')
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    if args.scale_mod_loss:
        try:
            mod_cat_weights = alphabet_info.compute_mod_inv_freq_weights(
                read_data, args.num_inv_freq_reads)
            log.write('* Modified base weights: {}\n'.format(
                str(mod_cat_weights)))
        except NotImplementedError:
            log.write(
                '* WARNING: Some mods not found when computing inverse ' +
                'frequency weights. Consider raising ' +
                '[--num_inv_freq_reads].\n')
            mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)
    else:
        mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_chunks = 0
    total_samples = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    score_smoothed = helpers.WindowedExpSmoother()
    t0 = time.time()
    log.write('* Training\n')
    for i in range(args.niteration):
        lr_scheduler.step()
        mod_factor_t = torch.tensor(args.mod_factor, dtype=torch.float32)

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including those
        # rejected, so pass the log
        # object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
        seqs, mod_cats, seqlens = [], [], []
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            chunk_seq = flipflop_code(
                np.ascontiguousarray(flipflop_can_labels[chunk_labels]),
                flipflop_ncan_base)
            chunk_mod_cats = np.ascontiguousarray(
                flipflop_mod_labels[chunk_labels])
            seqs.append(chunk_seq)
            mod_cats.append(chunk_mod_cats)
        seqs, mod_cats = np.concatenate(seqs), np.concatenate(mod_cats)
        seqs = torch.tensor(seqs, dtype=torch.float32, device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        mod_cats = torch.tensor(mod_cats, dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.cat_mod_flipflop_loss(outputs, seqs, seqlens,
                                               mod_cats, can_mods_offsets,
                                               mod_cat_weights, mod_factor_t,
                                               args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we
        # log everything
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        del indata, seqs, mod_cats, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        loss_log.write('{}\t{:.10f}\t{:.10f}\n'.format(i, fval,
                                                       score_smoothed.value))
        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            log.write((' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s ' +
                       '{:.2f} kbase/s) lr={:.2e}').format(
                           (i + 1) // 50, score_smoothed.value, dt,
                           total_samples / 1000 / dt, total_bases / 1000 / dt,
                           learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)

    return
Example 22
parser = argparse.ArgumentParser(
    description='Predict squiggle from sequence',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--version',
                    nargs=0,
                    action=display_version_and_exit,
                    metavar=__version__,
                    help='Display version information')
parser.add_argument('model', action=FileExists, help='Model file')
parser.add_argument('input', action=FileExists, help='Fasta file')

if __name__ == '__main__':
    args = parser.parse_args()

    predict_squiggle = helpers.load_model(args.model)

    for seq in SeqIO.parse(args.input, 'fasta'):
        seqstr = str(seq.seq).encode('ascii')
        embedded_seq_numpy = np.expand_dims(
            squiggle_match.embed_sequence(seqstr), axis=1)
        embedded_seq_torch = torch.tensor(embedded_seq_numpy,
                                          dtype=torch.float32)

        with torch.no_grad():
            squiggle = np.squeeze(
                predict_squiggle(embedded_seq_torch).cpu().numpy(), axis=1)

        print('base', 'current', 'sd', 'dwell', sep='\t')
        for base, (mean, logsd, dwell) in zip(seq.seq, squiggle):
            print(base, mean, np.exp(logsd), np.exp(-dwell), sep='\t')
Example 23
                    nargs=2,
                    type=NonNegative(int),
                    metavar=('beginning', 'end'),
                    help='Number of samples to trim off start and end')
parser.add_argument('model', action=FileExists, help='Model file')
parser.add_argument('references', action=FileExists, help='Fasta file')
parser.add_argument('read_dir',
                    action=FileExists,
                    help='Directory for fast5 reads')

if __name__ == '__main__':
    args = parser.parse_args()

    worker_kwarg_names = ['back_prob', 'localpen', 'minscore', 'trim']

    model = helpers.load_model(args.model)

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.read_dir, limit=args.limit, strand_list=args.input_strand_list)

    for res in imap_mp(squiggle_match.worker,
                       fast5_reads,
                       threads=args.jobs,
                       fix_kwargs=helpers.get_kwargs(args, worker_kwarg_names),
                       unordered=True,
                       init=squiggle_match.init_worker,
                       initargs=[model, args.references]):
        if res is None:
            continue
        read_id, sig, score, path, squiggle, bases = res
        bases = bases.decode('ascii')