def run_model(
        normed_signal, model, chunk_size=_DEFAULT_CHUNK_SIZE,
        overlap=_DEFAULT_OVERLAP, max_concur_chunks=None, return_numpy=True,
        return_tensor_on_device=True):
    """ Hook for megalodon to run network via taiyaki

    `chunk_size` and `overlap` are given in multiples of the model stride and
    are multiplied by it before the signal is chunked.

    Args:
        normed_signal: normalised read signal, passed to `chunk_read`.
        model: taiyaki model; the device of its first parameter is used to
            evaluate the chunks.
        chunk_size (int): chunk length, in units of the model stride.
        overlap (int): overlap between chunks, in units of the model stride.
        max_concur_chunks (int, optional): if given, chunks are evaluated in
            groups of at most this many (split along dimension 1 of the
            chunk tensor); if None, all chunks are evaluated in one call.
        return_numpy (bool): return an ndarray; takes precedence over
            `return_tensor_on_device`.
        return_tensor_on_device (bool): when not returning numpy, move the
            stitched tensor back to the model's device; otherwise return it
            on the CPU.

    Returns:
        The stitched network output as an ndarray, a tensor on the model's
        device, or a CPU tensor, depending on the flags above.
    """
    device = next(model.parameters()).device
    stride = guess_model_stride(model)
    chunk_size *= stride
    overlap *= stride

    chunks, chunk_starts, chunk_ends = chunk_read(
        normed_signal, chunk_size, overlap)
    chunks = torch.tensor(chunks)
    with torch.no_grad():
        if max_concur_chunks is None:
            out = model(chunks.to(device)).cpu()
        else:
            # Evaluate a limited number of chunks per call and gather the
            # outputs on the host before concatenating them back together.
            out = []
            for some_chunks in torch.split(chunks, max_concur_chunks, 1):
                out.append(model(some_chunks.to(device)).cpu())
            out = torch.cat(out, 1)
        stitched_chunks = stitch_chunks(
            out, chunk_starts, chunk_ends, stride)
    if return_numpy:
        return stitched_chunks.numpy()
    if return_tensor_on_device:
        return stitched_chunks.to(device)
    return stitched_chunks
Exemple #2
0
def run_model(normed_signal,
              model,
              chunk_size=_DEFAULT_CHUNK_SIZE,
              overlap=_DEFAULT_OVERLAP,
              max_concur_chunks=None,
              return_numpy=True):
    """ Hook for megalodon to run network via taiyaki

    Args:
        normed_signal: normalised read signal, passed to `chunk_read`.
        model: taiyaki model; the device of its first parameter is used to
            evaluate the chunks.
        chunk_size (int): chunk size, passed with `overlap` and the model
            stride through `round_chunk_values` before chunking.
        overlap (int): overlap between chunks (see `chunk_size`).
        max_concur_chunks (int, optional): if given, evaluate at most this
            many chunks per model call (slicing axis 1 of the chunk array);
            if None, all chunks are evaluated in one call.
        return_numpy (bool): if True, return an ndarray in host memory;
            otherwise return the stitched tensor as produced (on the model's
            device).

    Returns:
        The stitched network output, as an ndarray or a tensor.
    """
    device = next(model.parameters()).device
    stride = guess_model_stride(model)
    chunk_size, overlap = round_chunk_values(chunk_size, overlap, stride)

    chunks, chunk_starts, chunk_ends = chunk_read(normed_signal, chunk_size,
                                                  overlap)
    with torch.no_grad():
        if max_concur_chunks is None:
            out = model(torch.tensor(chunks, device=device))
        else:
            out = []
            # Walk axis 1 in steps of max_concur_chunks; the final slice may
            # be shorter. Slices are made contiguous before conversion.
            for sc_start in range(0, chunks.shape[1], max_concur_chunks):
                sc_end = sc_start + max_concur_chunks
                sc_chunks = np.ascontiguousarray(chunks[:, sc_start:sc_end])
                out.append(model(torch.tensor(sc_chunks, device=device)))
            out = torch.cat(out, 1)
        stitched_chunks = stitch_chunks(out, chunk_starts, chunk_ends, stride)
    if return_numpy:
        return stitched_chunks.cpu().numpy()
    return stitched_chunks
Exemple #3
0
def worker_init(device, modelname, chunk_size, overlap,
                read_params, alphabet, max_concurrent_chunks,
                fastq, qscore_scale, qscore_offset, beam, posterior,
                temperature):
    """Set up per-worker state for basecalling.

    Loads the model onto the requested device and publishes two module
    globals: ``all_read_params`` (set to `read_params`) and
    ``process_read_partial``, a closure that calls ``process_read`` with the
    model and settings captured here and returns ``(read_id, *result)``.
    Presumably used as a worker-pool initialiser — confirm against caller.

    `chunk_size` and `overlap` are given in multiples of the model stride.
    """
    global all_read_params
    global process_read_partial

    all_read_params = read_params
    torch_device = helpers.set_torch_device(device)
    net = load_model(modelname).to(torch_device)
    stride = guess_model_stride(net)
    # Convert stride units into sample counts.
    chunk_samples = chunk_size * stride
    overlap_samples = overlap * stride

    n_can_state = nstate_flipflop(len(alphabet))

    def process_read_partial(read_filename, read_id, read_params):
        result = process_read(
            read_filename, read_id, net, chunk_samples, overlap_samples,
            read_params, n_can_state, stride, alphabet,
            max_concurrent_chunks, fastq, qscore_scale, qscore_offset,
            beam, posterior, temperature)
        return (read_id,) + tuple(result)
Exemple #4
0
def main():
    """Basecall fast5 reads with a flip-flop model.

    Options come from the module-level ``parser``. Calls are written as
    FASTA (or FASTQ with ``--fastq``) records to ``args.output`` via
    ``open_file_or_stdout``. If ``--modified_base_output`` is given,
    modified-base data are written to that HDF5 file. Progress and a
    throughput summary are written to stderr.

    Exits with status 1 if modified-base output is requested for a model
    whose final layer is not a ``GlobalNormFlipFlopCatMod``.
    """
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        # Non-zero status so that calling scripts can detect the failure.
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
    # Chunk size and overlap are given in multiples of the model stride.
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(args.input_folder,
                                       limit=args.limit,
                                       strand_list=args.input_strand_list,
                                       recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write("* Loading read scaling parameters from {}.\n".format(
            args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        # Only reads with scaling information are called.
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids
        ]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        # Open explicitly for writing: h5py >= 3.0 opens files read-only by
        # default, which would make create_group below fail.
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset('mod_long_names',
                               data=np.array(mod_long_names, dtype='S'),
                               dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                # None when no scaling parameters are available for the read.
                read_params = all_read_params.get(read_id)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    # Keep the fractional seconds: the format already requests 2 decimals.
    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, total_time))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
def oneread_remap(read_tuple, references, model, device, per_read_params_dict,
                  alphabet_info):
    """ Worker function for remapping reads using flip-flop model on raw signal
    :param read_tuple                 : read, identified by a tuple (filepath, read_id)
    :param references                 : dict mapping read ids to reference sequences
    :param model                      : pytorch model (the torch data structure, not a filename)
    :param device                     : integer specifying which GPU to use for remapping, or 'cpu' to use CPU
    :param per_read_params_dict       : dictionary where keys are read ids (UUIDs), values are dicts
                                         containing keys trim_start trim_end shift scale
    :param alphabet_info              : AlphabetInfo object for basecalling

    :returns: tuple (read dictionary or None, message string). On success the
              dictionary comes from ``Mapping.get_read_dictionary`` and the
              message is REMAP_SUCCESS_TEXT; on failure the dictionary is None
              and the message is one of the *_ERR_TEXT constants.
    """
    filename, read_id = read_tuple
    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            read = f5file.get_read(read_id)
            sig = signal.Signal(read)
    except Exception:
        # One bad read must not disrupt the rest of the batch.
        return None, READ_ID_INFO_NOT_FOUND_ERR_TEXT

    if read_id in references:
        read_ref = references[read_id]
    else:
        return None, NO_REF_FOUND_ERR_TEXT

    try:
        read_params_dict = per_read_params_dict[read_id]
    except KeyError:
        return None, NO_PARAMS_ERR_TEXT

    sig.set_trim_absolute(read_params_dict['trim_start'],
                          read_params_dict['trim_end'])

    try:
        # Prevents torch doing its own parallelisation on top of our imap_map
        torch.set_num_threads(1)
        # Standardise (i.e. shift/scale so that approximately mean =0, std=1)
        signalArray = (sig.current -
                       read_params_dict['shift']) / read_params_dict['scale']
        # Make signal into 3D tensor with shape [siglength,1,1] and move to
        # appropriate device (GPU number or CPU)
        signalTensor = torch.tensor(
            signalArray[:, np.newaxis, np.newaxis].astype(taiyaki_dtype),
            device=device)
        # The model must live on the same device
        modelOnDevice = model.to(device)
        # Apply the network to the signal, generating transition weight
        # matrix, and put it back into a numpy array
        with torch.no_grad():
            transweights = modelOnDevice(signalTensor).cpu().numpy()
    except Exception:
        return None, REMAP_ERR_TEXT

    # Remap against the canonical-base version of the reference. The extra
    # dimensions introduced by np.newaxis above are removed by np.squeeze.
    can_read_ref = alphabet_info.collapse_sequence(read_ref)
    # localpen=0.0 does local alignment; the alignment score is not used.
    _, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights),
        can_read_ref,
        alphabet=alphabet_info.can_bases,
        localpen=0.0)

    # flipflop_remap() establishes a mapping between the network outputs and
    # the reference. What we need is a mapping between the signal and the
    # reference. To resolve this we need to know the stride of the model
    # (how many samples for each network output).
    model_stride = helpers.guess_model_stride(model)
    remapping = mapping.Mapping.from_remapping_path(sig, path, read_ref,
                                                    model_stride)
    remapping.add_integer_reference(alphabet_info.alphabet)

    return remapping.get_read_dictionary(read_params_dict['shift'],
                                         read_params_dict['scale'],
                                         read_id), REMAP_SUCCESS_TEXT
Exemple #6
0
def load_network(args, alphabet_info, res_info, log):
    """Load the training network and build its optimiser and LR schedule.

    On the lead process the model is loaded from ``args.model``, checked for
    compatibility with ``alphabet_info``, and saved as checkpoint 0 in
    ``args.outdir``. With multiple GPUs every process waits at a barrier,
    reloads that saved checkpoint, and wraps it in DistributedDataParallel.
    Otherwise the loaded clone itself is moved to ``res_info.device`` and
    used directly.

    Args:
        args: parsed command-line arguments.
        alphabet_info: alphabet of the mapped signal files.
        res_info: resource info (``is_lead_process``, ``is_multi_gpu``,
            ``device``).
        log: object with a ``write`` method used for progress messages.

    Returns:
        tuple: ``(NETWORK_INFO(...), OPTIM_INFO(...))``.

    Exits the process with status 1 on an incompatible model/alphabet, or
    when ``--warmup_batches`` is not smaller than ``--niteration``.
    """
    log.write('* Reading network from {}\n'.format(args.model))
    if res_info.is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we
        # need a clone of the start network to use as a template for saving
        # checkpoints. Necessary because DistributedParallel makes the class
        # structure different.
        model_kwargs = {
            'stride': args.stride,
            'winlen': args.winlen,
            'insize': 1,
            'size': args.size,
            'alphabet_info': alphabet_info
        }
        model_metadata = {
            'reverse': args.reverse,
            'standardize': args.standardize
        }
        net_clone = helpers.load_model(args.model,
                                       model_metadata=model_metadata,
                                       **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum(p.nelement() for p in net_clone.parameters())))

        if not alphabet_info.is_compatible_model(net_clone):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain ' +
                'incompatible alphabet definitions (including modified ' +
                'bases).\n')
            sys.exit(1)
        if layers.is_cat_mod_model(net_clone):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.\n')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.\n')
                sys.exit(1)
        # model_metadata is a plain dict: index it (attribute access would
        # raise AttributeError whenever a delta model is loaded).
        if (layers.is_delta_model(net_clone) and
                model_metadata['standardize']):
            log.write('*' * 60 + '\n* WARNING: Delta-scaling models trained ' +
                      'with --standardize are not compatible with Guppy.\n' +
                      '*' * 60 + '\n')
        log.write('* Dumping initial model\n')
        helpers.save_model(net_clone, args.outdir, 0)
    else:
        net_clone = None

    if res_info.is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has
        # saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        network = helpers.load_model(saved_startmodel_path).to(res_info.device)
        network_metadata = parse_network_metadata(network)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            network,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        # NOTE(review): relies on a single-GPU run always being the lead
        # process, otherwise net_clone is None here.
        log.write('* Loading model onto device\n')
        network = net_clone.to(res_info.device)
        network_metadata = parse_network_metadata(network)
        net_clone = None

    log.write('* Estimating filter parameters from training data\n')
    stride = guess_model_stride(network)
    optimiser = torch.optim.AdamW(network.parameters(),
                                  lr=args.lr_max,
                                  betas=args.adam,
                                  weight_decay=args.weight_decay,
                                  eps=args.eps)

    lr_warmup = args.lr_min if args.lr_warmup is None else args.lr_warmup
    adam_beta1, _ = args.adam
    # Pytorch OneCycleLR crashes if pct_start==1 (i.e. warmup_fraction==1),
    # so the warmup must be strictly shorter than the whole run.
    if args.warmup_batches >= args.niteration:
        sys.stderr.write('* Error: --warmup_batches must be < --niteration\n')
        sys.exit(1)
    warmup_fraction = args.warmup_batches / args.niteration
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimiser,
        args.lr_max,
        total_steps=args.niteration,
        # pct_start is really fractional, not percent
        pct_start=warmup_fraction,
        div_factor=args.lr_max / lr_warmup,
        final_div_factor=lr_warmup / args.lr_min,
        cycle_momentum=(args.min_momentum is not None),
        base_momentum=(adam_beta1 if args.min_momentum is None
                       else args.min_momentum),
        max_momentum=adam_beta1
    )
    log.write(
        ('* Learning rate increases from {:.2e} to {:.2e} over {} ' +
         'iterations using cosine schedule.\n').format(lr_warmup, args.lr_max,
                                                       args.warmup_batches))
    log.write(('* Then learning rate decreases from {:.2e} to {:.2e} over ' +
               '{} iterations using cosine schedule.\n').format(
                   args.lr_max, args.lr_min,
                   args.niteration - args.warmup_batches))

    if args.gradient_clip_num_mads is None:
        log.write('* No gradient clipping\n')
        rolling_mads = None
    else:
        nparams = len([p for p in network.parameters() if p.requires_grad])
        if nparams == 0:
            rolling_mads = None
            log.write('* No gradient clipping due to missing parameters\n')
        else:
            rolling_mads = maths.RollingMAD(nparams,
                                            args.gradient_clip_num_mads)
            log.write((
                '* Gradients will be clipped (by value) at {:3.2f} MADs ' +
                'above the median of the last {} gradient maximums.\n').format(
                    rolling_mads.n_mads, rolling_mads.window))

    net_info = NETWORK_INFO(net=network,
                            net_clone=net_clone,
                            metadata=network_metadata,
                            stride=stride)
    optim_info = OPTIM_INFO(optimiser=optimiser,
                            lr_warmup=lr_warmup,
                            lr_scheduler=lr_scheduler,
                            rolling_mads=rolling_mads)

    return net_info, optim_info
Exemple #7
0
def oneread_remap(read_tuple, references, model, device, per_read_params_dict,
                  alphabet, collapse_alphabet):
    """ Worker function for remapping reads using flip-flop model on raw signal
    :param read_tuple                 : read, identified by a tuple (filepath, read_id)
    :param references                 : dict mapping read ids to reference sequences (bytes;
                                         decoded as UTF-8 here)
    :param model                      : pytorch model (the torch data structure, not a filename)
    :param device                     : integer specifying which GPU to use for remapping, or 'cpu' to use CPU
    :param per_read_params_dict       : dictionary where keys are read ids (UUIDs), values are dicts
                                         containing keys trim_start trim_end shift scale
    :param alphabet                   : alphabet for basecalling (passed on to mapped-read file)
    :param collapse_alphabet          : collapsed alphabet for basecalling (passed on to mapped-read file)

    :returns: dictionary as specified in mapped_signal_files.Read class, or
              None if the read could not be remapped (a reason is written to
              stderr).
    """
    filename, read_id = read_tuple
    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            read = f5file.get_read(read_id)
            sig = signal.Signal(read)
    except Exception as e:
        # We want any single failure in the batch of reads to not disrupt other reads being processed.
        sys.stderr.write(
            'No read information on read {} found in file {}.\n{}\n'.format(
                read_id, filename, repr(e)))
        return None

    if read_id in references:
        # read_ref comes out as a bytes object, so we need to convert to str
        read_ref = references[read_id].decode("utf-8")
    else:
        sys.stderr.write('No fasta reference found for {}.\n'.format(read_id))
        return None

    if read_id in per_read_params_dict:
        read_params_dict = per_read_params_dict[read_id]
    else:
        # Report the skip like every other failure path instead of
        # dropping the read silently.
        sys.stderr.write(
            'No per-read parameters found for {}.\n'.format(read_id))
        return None

    sig.set_trim_absolute(read_params_dict['trim_start'],
                          read_params_dict['trim_end'])

    try:
        # Prevents torch doing its own parallelisation on top of our imap_map
        torch.set_num_threads(1)
        # Standardise (i.e. shift/scale so that approximately mean =0, std=1)
        signalArray = (sig.current -
                       read_params_dict['shift']) / read_params_dict['scale']
        # Make signal into 3D tensor with shape [siglength,1,1] and move to
        # appropriate device (GPU number or CPU)
        signalTensor = torch.tensor(
            signalArray[:, np.newaxis, np.newaxis].astype(taiyaki_dtype),
            device=device)
        # The model must live on the same device
        modelOnDevice = model.to(device)
        # Apply the network to the signal, generating transition weight
        # matrix, and put it back into a numpy array
        with torch.no_grad():
            transweights = modelOnDevice(signalTensor).cpu().numpy()
    except Exception as e:
        sys.stderr.write(
            "Failure applying basecall network to remap read {}.\n{}\n".format(
                sig.read_id, repr(e)))
        return None

    # Extra dimensions introduced by np.newaxis above removed by np.squeeze.
    # localpen=0.0 does local alignment; the alignment score is not used.
    _, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights), read_ref, localpen=0.0)

    # flipflop_remap() establishes a mapping between the network outputs and
    # the reference. What we need is a mapping between the signal and the
    # reference. To resolve this we need to know the stride of the model
    # (how many samples for each network output).
    model_stride = helpers.guess_model_stride(model, device=device)
    remapping = mapping.Mapping.from_remapping_path(sig, path, read_ref,
                                                    model_stride)

    return remapping.get_read_dictionary(read_params_dict['shift'],
                                         read_params_dict['scale'],
                                         read_id,
                                         alphabet=alphabet,
                                         collapse_alphabet=collapse_alphabet)
Exemple #8
0
def main():
    """Basecall fast5 reads with a flip-flop model on a GPU.

    Options come from the module-level ``parser``. Calls are written as FASTA
    records to ``args.output`` via ``open_file_or_stdout``. If
    ``--modified_base_output`` is given, modified-base data are written to
    that HDF5 file. Progress and a throughput summary are written to stderr.

    Exits with status 1 if the CPU device is requested, or if modified-base
    output is requested for a model whose final layer is not a
    ``GlobalNormFlipFlopCatMod``.
    """
    args = parser.parse_args()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.device == 'cpu':
        sys.stderr.write(
            "Flipflop basecalling in taiyaki requires a GPU and for cupy to "
            "be installed\n")
        sys.exit(1)
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1], layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from canonical base only model.\n")
        # Non-zero status so that calling scripts can detect the failure.
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        # Open explicitly for writing: h5py >= 3.0 opens files read-only by
        # default, which would make create_group below fail.
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, device, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {}s\n".format(nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
Exemple #9
0
def run_model(normed_signal,
              model,
              chunk_size=_DEFAULT_CHUNK_SIZE,
              overlap=_DEFAULT_OVERLAP,
              max_concur_chunks=None,
              return_numpy=True,
              return_tensor_on_device=True):
    """ Hook for megalodon to run network via taiyaki

    Note:
        `chunk_size` and `overlap` are expressed as multiples of the stride
        of `model` rather than as numbers of samples, consistent with the
        parameterisation used by Guppy.

    Args:
        normed_signal (:class:`ndarray`): Signal of the read; it is split
            into chunks, each chunk is run through `model`, and the outputs
            are stitched back together.
        model (:class:`layers.Serial`): A Taiyaki model, implicitly assumed
            to have a :class:`layers.Serial` as its outermost layer and for
            the first wrapped layer to have parameters.
        chunk_size (int, optional): Length of the chunks into which
            `normed_signal` is split, in units of the model stride.
        overlap (int, optional): Overlap between consecutive chunks, in units
            of the model stride.
        max_concur_chunks (int, optional): Evaluate chunks in batches of at
            most `max_concur_chunks`; if None, all chunks are evaluated at
            once.
        return_numpy (bool, optional): Convert the result to an
            :class:`ndarray` (default).
        return_tensor_on_device (bool, optional): Move the result back to the
            model's device (default). Ignored when `return_numpy` is True.

    Returns:
        :class:`Tensor`: Stitched output of the network over all chunks.
            On the model's device if `return_tensor_on_device` is True,
            otherwise in host memory ("cpu" device).

        If `return_numpy` is True, an :class:`ndarray` in host memory is
        returned instead.
    """
    model_device = get_model_device(model)
    model_stride = guess_model_stride(model)

    # Convert stride units into sample counts before chunking.
    raw_chunks, starts, ends = chunk_read(
        normed_signal, chunk_size * model_stride, overlap * model_stride)
    chunk_tensor = torch.tensor(raw_chunks)

    with torch.no_grad():
        if max_concur_chunks is not None:
            batches = torch.split(chunk_tensor, max_concur_chunks, 1)
            network_out = torch.cat(
                [model(batch.to(model_device)).cpu() for batch in batches], 1)
        else:
            network_out = model(chunk_tensor.to(model_device)).cpu()
        stitched = stitch_chunks(network_out, starts, ends, model_stride)

    if return_numpy:
        return stitched.numpy()
    return stitched.to(model_device) if return_tensor_on_device else stitched
Exemple #10
0
    def _load_taiyaki_model(self):
        """Initialise the taiyaki basecalling backend on this object.

        Assigns a parsed device to each of ``self.num_proc`` processes by
        cycling through ``self.params.taiyaki.devices`` (default ``['cpu']``).
        Imports taiyaki and torch, exiting with status 1 if either import
        fails. Loads the model from ``self.params.taiyaki.taiyaki_model_fn``
        and records its stride, output size and alphabet attributes.

        Raises:
            NotImplementedError: for a non-categorical model whose output
                size does not correspond to 4 bases.
        """
        LOGGER.info('Loading taiyaki basecalling backend.')
        self.model_type = TAI_NAME

        requested = self.params.taiyaki.devices
        if requested is None:
            requested = ['cpu', ]
        n_requested = len(requested)
        # Repeat the device indices enough times to cover every process,
        # then keep only the first num_proc assignments.
        device_idxs = np.tile(np.arange(n_requested),
                              (self.num_proc // n_requested) + 1)
        parsed = [parse_device(requested[idx]) for idx in device_idxs]
        self.process_devices = parsed[:self.num_proc]

        try:
            from taiyaki.helpers import (
                load_model as load_taiyaki_model, guess_model_stride)
            from taiyaki.basecall_helpers import run_model as tai_run_model
            from taiyaki.layers import GlobalNormFlipFlopCatMod
        except ImportError:
            LOGGER.error(
                'Failed to import taiyaki. Ensure working '
                'installations to run megalodon')
            sys.exit(1)
        try:
            import torch
        except ImportError:
            LOGGER.error(
                'Failed to import pytorch. Ensure working '
                'installations to run megalodon')
            sys.exit(1)

        # Keep handles to the imported modules/functions for later use.
        self.load_taiyaki_model = load_taiyaki_model
        self.tai_run_model = tai_run_model
        self.torch = torch

        loaded = self.load_taiyaki_model(self.params.taiyaki.taiyaki_model_fn)
        out_layer = loaded.sublayers[-1]
        self.is_cat_mod = (GlobalNormFlipFlopCatMod is not None and
                           isinstance(out_layer, GlobalNormFlipFlopCatMod))
        self.stride = guess_model_stride(loaded)
        self.output_size = out_layer.size

        if not self.is_cat_mod:
            if mh.nstate_to_nbase(self.output_size) != 4:
                raise NotImplementedError(
                    'Naive modified base flip-flop models are not supported.')
            self.output_alphabet = mh.ALPHABET
            self.can_alphabet = mh.ALPHABET
            self.mod_long_names = []
            self.str_to_int_mod_labels = {}
            self.can_nmods = None
        else:
            # A taiyaki modified-base model is described by three fields:
            # can_nmods, output_alphabet and ordered_mod_long_names.
            self.output_alphabet = out_layer.output_alphabet
            self.can_nmods = out_layer.can_nmods
            self.ordered_mod_long_names = out_layer.ordered_mod_long_names
            self.compute_mod_alphabet_attrs()
        self.n_mods = len(self.mod_long_names)
Exemple #11
0
def oneread_remap(read_tuple,
                  model,
                  per_read_params_dict,
                  alphabet_info,
                  max_read_length,
                  device='cpu',
                  localpen=0.0):
    """ Worker function for remapping reads using flip-flop model on raw signal

    Args:
        read_tuple (tuple) : read, identified by a tuple
                                  (filepath, read_id, read reference)
        model (pytorch Module): pytorch model
        per_read_params_dict (dict) : dictionary where keys are read ids
                                      (UUIDs), values are dicts containing
                                      keys trim_start trim_end shift scale
        alphabet_info (AlphabetInfo object):  for basecalling
        max_read_length (int) : Don't attempt to remap reads with references
                                longer than this; None disables the check
        device (int or str): integer specifying which GPU to use for
                             remapping, or 'cpu' to use CPU
        localpen (float): Penalty for local mapping, passed to
                          flipflop_remap

    Returns:
        tuple :(dict,str) containing
        1. dictionary as specified in
            signal_mapping.SignalMapping.get_read_dictionary, or None on
            failure
        2. a RemapResult value, or the message of a TaiyakiSigMapError
            raised while building the read dictionary
    """
    filename, read_id, read_ref = read_tuple

    if read_ref is None:
        return None, RemapResult.NO_REF_FOUND

    if max_read_length is not None and len(read_ref) > max_read_length:
        return None, RemapResult.REF_TOO_LONG

    try:
        read_params_dict = per_read_params_dict[read_id]
    except KeyError:
        return None, RemapResult.NO_PARAMS

    try:
        with fast5_interface.get_fast5_file(filename, 'r') as f5file:
            read = f5file.get_read(read_id)
            sig = signal.Signal(read, read_params=read_params_dict)
    except Exception:
        # One bad read must not disrupt the rest of the batch.
        return None, RemapResult.READ_ID_INFO_NOT_FOUND

    try:
        # Prevents torch doing its own parallelisation on top of our imap_map
        torch.set_num_threads(1)
        # Make signal into 3D tensor with shape [siglength,1,1] and move to
        # appropriate device (GPU number or CPU)
        signalTensor = torch.tensor(
            sig.standardized_current[:, np.newaxis,
                                     np.newaxis].astype(np.float32),
            device=device)
        # The model must live on the same device
        modelOnDevice = model.to(device)
        # Apply the network to the signal, generating transition weight matrix,
        # and put it back into a numpy array
        with torch.no_grad():
            transweights = modelOnDevice(signalTensor).cpu().numpy()
    except Exception:
        return None, RemapResult.NETWORK_ERROR

    # Remap against the canonical-base version of the reference. The extra
    # dimensions introduced by np.newaxis above are removed by np.squeeze;
    # the alignment score is not used.
    can_read_ref = alphabet_info.collapse_sequence(read_ref)
    _, path = flipflop_remap.flipflop_remap(
        np.squeeze(transweights),
        can_read_ref,
        alphabet=alphabet_info.can_bases,
        localpen=localpen)

    # flipflop_remap() establishes a mapping between the network outputs and
    # the reference.
    # What we need is a mapping between the signal and the reference.
    # To resolve this we need to know the stride of the model (how many samples
    # for each network output)
    model_stride = helpers.guess_model_stride(model)
    int_ref = signal_mapping.SignalMapping.get_integer_reference(
        read_ref, alphabet_info.alphabet)
    sig_mapping = signal_mapping.SignalMapping.from_remapping_path(
        path, int_ref, model_stride, sig)
    try:
        sig_mapping_dict = sig_mapping.get_read_dictionary()
    except signal_mapping.TaiyakiSigMapError as e:
        return None, str(e)
    return sig_mapping_dict, RemapResult.SUCCESS