Example 1
    @classmethod  # unittest requires setUpClass to be a classmethod
    def setUpClass(self):
        self.nbases = 4
        self.batchsize = 1
        self.seqlen = 3
        self.nblocks = 4
        self.sharpen = 1.0
        self.nflipflop_transitions = flipflopfings.nstate_flipflop(
            self.nbases)  # 40 for ACGT
        self.dx_size = 0.001  # Size of small changes for gradient check
        self.grad_dp = 5  # Number of decimal places for gradient check

        # Sequence ACC. Flip-flop coded ACc or 015
        # 'sequences' rather than 'sequence' because
        # we could have more than one sequence packed together
        self.sequences = {
            '015': torch.tensor([0, 1, 5]),
            '237': torch.tensor([2, 3, 7]),
            '510': torch.tensor([5, 1, 0])  # prob 0 according to outputs
        }
        self.seqlens = torch.tensor([self.seqlen])
        # Network outputs are weights for flip-flop transitions
        # Define some example paths and assign weights to them.
        # These will be used to define the example output matrix
        #
        paths = {}
        weights = {}

        paths['015'] = [0, 0, 1, 5, 5]
        weights['015'] = [1.0, 1.0, 0.5, 1.0]

        paths['237'] = [2, 2, 3, 7, 7]
        weights['237'] = [1.0, 0.5, 1.0, 1.0]

        weights['510'] = [0.0]  # No weight for this sequence/path

        self.path_probabilities = {k: np.prod(v) for k, v in weights.items()}

        # Normalise path probabilities
        psum = sum(self.path_probabilities.values())
        self.path_probabilities = {
            k: v / psum
            for k, v in self.path_probabilities.items()
        }

        # Make output (transition weight) matrix with these path probs
        self.outputs = torch.zeros(self.nblocks,
                                   self.batchsize,
                                   self.nflipflop_transitions,
                                   dtype=torch.float)
        for k in paths.keys():
            for block in range(self.nblocks):
                transcode = flipflop_transitioncode(paths[k][block],
                                                    paths[k][block + 1],
                                                    self.nbases)
                self.outputs[block, 0, transcode] = weights[k][block]

        # Log and normalise output (transition weight) matrix
        self.outputs = torch.log(self.outputs + SMALL_VAL)
        self.outputs = layers.global_norm_flipflop(self.outputs)
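
For reference, the state count used above has a simple closed form: with nbase flip states and nbase flop states, a transition may go from any state to any flip state, or to the flop state of the same base. A minimal sketch under that reading (these are illustrative re-derivations, not taiyaki's flipflopfings implementations):

# Illustrative sketch, assuming the flip/flop labelling used in the test
# above (flip states 0..nbase-1, flop states nbase..2*nbase-1).

def nstate_flipflop_sketch(nbase):
    # From each of the 2*nbase states a transition may go to any of the
    # nbase flip states or to the flop state of the same base: nbase + 1
    # destinations, hence 2 * nbase * (nbase + 1) weights (40 for ACGT).
    return 2 * nbase * (nbase + 1)

def flipflop_code_sketch(bases, nbase=4):
    # A repeated base toggles between flip and flop, so ACC -> [0, 1, 5]
    # as in self.sequences above.
    coded = []
    for b in bases:
        if coded and coded[-1] % nbase == b and coded[-1] < nbase:
            coded.append(b + nbase)  # repeat of a flip state -> flop
        else:
            coded.append(b)
    return coded

assert nstate_flipflop_sketch(4) == 40
assert flipflop_code_sketch([0, 1, 1]) == [0, 1, 5]  # ACC
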
Example 2
    def __init__(self, insize, nbase, has_bias=True, _never_use_cupy=False):
        super().__init__()
        self.insize = insize
        self.nbase = nbase
        self.size = flipflopfings.nstate_flipflop(nbase)
        self.has_bias = has_bias
        self.linear = nn.Linear(insize, self.size, bias=has_bias)
        self.reset_parameters()
        self._never_use_cupy = _never_use_cupy
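
A minimal usage sketch, assuming this is the constructor of taiyaki's layers.GlobalNormFlipFlop and that its forward pass maps (timesteps, batch, insize) features to one weight per flip-flop transition (the forward behaviour is an assumption from context, not shown above):

import torch

# Hypothetical usage; class name and forward semantics assumed from context.
layer = GlobalNormFlipFlop(insize=256, nbase=4)
x = torch.randn(100, 8, 256)  # (timesteps, batch, features)
out = layer(x)
assert out.shape == (100, 8, 40)  # 40 == nstate_flipflop(4)
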
Example 3
def worker_init(device, modelname, chunk_size, overlap,
                read_params, alphabet, max_concurrent_chunks,
                fastq, qscore_scale, qscore_offset, beam, posterior,
                temperature):
    global all_read_params
    global process_read_partial

    all_read_params = read_params
    device = helpers.set_torch_device(device)
    model = load_model(modelname).to(device)
    stride = guess_model_stride(model)
    chunk_size = chunk_size * stride
    overlap = overlap * stride

    n_can_base = len(alphabet)
    n_can_state = nstate_flipflop(n_can_base)

    def process_read_partial(read_filename, read_id, read_params):
        res = process_read(read_filename, read_id,
                           model, chunk_size, overlap, read_params,
                           n_can_state, stride, alphabet,
                           max_concurrent_chunks, fastq, qscore_scale,
                           qscore_offset, beam, posterior, temperature)
        return (read_id, *res)
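
The module-level globals set here follow the standard multiprocessing pattern: each worker process runs worker_init once, after which process_read_partial is available inside that worker. A sketch of how the pool might be wired up (all argument values are illustrative placeholders, not from the source):

from multiprocessing import Pool

def run_one(job):
    # Executes in a worker; process_read_partial and all_read_params were
    # installed as globals by worker_init.
    read_filename, read_id = job
    return process_read_partial(read_filename, read_id,
                                all_read_params.get(read_id))

# Hypothetical pool setup; initargs must match worker_init's signature.
with Pool(processes=4, initializer=worker_init,
          initargs=('cuda:0', 'model.checkpoint', 1000, 100, {}, 'ACGT',
                    128, False, 1.0, 0.0, None, False, 1.0)) as pool:
    results = pool.map(run_one, reads)  # reads: list of (filename, read_id)
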
Example 4
def main():
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(args.input_folder,
                                       limit=args.limit,
                                       strand_list=args.input_strand_list,
                                       recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write("* Loading read scaling parameters from {}.\n".format(
            args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids
        ]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset('mod_long_names',
                               data=np.array(mod_long_names, dtype='S'),
                               dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                read_params = all_read_params.get(read_id)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
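
The FASTA/FASTQ writing in the loop above interleaves header, sequence and quality lines; factored out, the per-read record formatting might read as follows (a sketch, not a helper that exists in the source):

def format_record(read_id, basecall, qstring, fastq=False, reverse=False):
    # Hypothetical helper mirroring the write logic in main() above.
    seq = basecall[::-1] if reverse else basecall
    if not fastq:
        return ">{}\n{}\n".format(read_id, seq)
    quals = qstring[::-1] if reverse else qstring
    return "@{}\n{}\n+\n{}\n".format(read_id, seq, quals)
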
Example 5
def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.\n')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(('Error: Output directory {} exists but --overwrite '
                          'is false\n').format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not a directory\n'
                         .format(args.output))
        sys.exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple: (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        sampling_chunk_len,
        args,
        log,
        chunk_log=chunk_log)

    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the chunk_log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]),
                            device=device,
                            dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long,
                               device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if
        # we're poisoned. If args.chunk_logging_threshold is set to zero
        # then we log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (
                ' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
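
To make the batch-sizing arithmetic in the loop concrete, here is a worked example with assumed argument values (stride=5, chunk_len_min=2000, chunk_len_max=4000, min_batch_size=64):

# Worked example of the batch sizing above; all values are assumed.
stride, chunk_len_min, chunk_len_max, min_batch_size = 5, 2000, 4000, 64
draw = 2503  # one np.random.randint(chunk_len_min, chunk_len_max + 1) draw
batch_chunk_len = (draw // stride) * stride  # 2500, a multiple of the stride
target_batch_size = int(min_batch_size * chunk_len_max / batch_chunk_len + 0.5)
assert (batch_chunk_len, target_batch_size) == (2500, 102)
# 102 chunks of 2500 samples ~= 64 chunks of 4000: batch cost stays constant.
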
Example 6
        log.write('* Loaded references from {}.\n'.format(args.reference))
        # Write pickle for future use
        pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
        with open(pickle_name, 'wb') as fh:
            pickle.dump(seq_dict, fh)
        log.write(('* Written pickle of processed references to {} '
                   'for future use.\n').format(pickle_name))

    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(args.alphabet)
    model_kwargs = {
        'size': args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(sum([p.nelement()
                                                           for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam, eps=args.eps)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
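
The pickle written above acts as a cache of the processed references; a later run could load it instead of re-parsing (a sketch of the complementary load path, which is not shown in this excerpt):

import os
import pickle

# Hypothetical load side of the reference cache written above.
pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
if os.path.exists(pickle_name):
    with open(pickle_name, 'rb') as fh:
        seq_dict = pickle.load(fh)
    log.write('* Loaded processed references from {}.\n'.format(pickle_name))
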
Example 7
def main():
    args = parser.parse_args()

    assert args.device != 'cpu', (
        "Flipflop basecalling in taiyaki requires a GPU and the cupy "
        "package to be installed")
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1], layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, device, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {}s\n".format(nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
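
Note the contrast with Example 4, which multiplies args.chunk_size and args.overlap by the stride; here basecall_helpers.round_chunk_values adjusts them instead. A plausible reading is that it snaps sample counts to multiples of the model stride (a sketch of that assumption, not the library code):

def round_chunk_values_sketch(chunk_size, overlap, stride):
    # Assumed behaviour: snap both values down to multiples of the stride so
    # chunk boundaries align with the network's output blocks.
    return (chunk_size // stride) * stride, (overlap // stride) * stride

assert round_chunk_values_sketch(1000, 100, 6) == (996, 96)
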