Example #1
def convert_seq(s, alphabet):
    buf = np.array(list(s))
    for i, b in enumerate(alphabet):
        buf[buf == b] = i
    buf = buf.astype('i4')
    assert np.all(buf < len(alphabet)), "Alphabet violates assumption in convert_seq"
    return flipflopfings.flipflop_code(buf, len(alphabet))
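# A minimal usage sketch (values taken from the docstring in Example #3;
# assumes numpy is imported as np and taiyaki's flipflopfings module is in
# scope, as in the function above):
#
#     >>> convert_seq('ACCCTGGA', 'ACGT')
#     array([0, 1, 5, 1, 3, 2, 6, 0], dtype=int32)
#
# Repeated bases alternate between "flip" (code < len(alphabet)) and "flop"
# (code + len(alphabet)) states, so the run 'CCC' codes as 1, 5, 1.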
Example #2
def prepare_random_batches(device, read_data, batch_chunk_len, sub_batch_size,
                           target_sub_batches, alphabet_info, filter_params,
                           network, network_is_catmod, log):
    total_sub_batches = 0

    while total_sub_batches < target_sub_batches:

        # Chunk_batch is a list of dicts
        chunk_batch, batch_rejections = \
            chunk_selection.assemble_batch(read_data, sub_batch_size,
                                           batch_chunk_len, filter_params)
        if len(chunk_batch) < sub_batch_size:
            log.write(
                '* Warning: only {} chunks passed filters (asked for {}).\n'.
                format(len(chunk_batch), sub_batch_size))

        if not all(len(d['sequence']) > 0 for d in chunk_batch):
            raise Exception('Error: zero length sequence')

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x sub_batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)

        # Prepare seqs, seqlens and (if necessary) mod_cats
        seqs, seqlens = [], []
        mod_cats = [] if network_is_catmod else None
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            if network_is_catmod:
                chunk_mod_cats = np.ascontiguousarray(
                    network.sublayers[-1].mod_labels[chunk_labels])
                mod_cats.append(chunk_mod_cats)
                # convert chunk_labels to canonical base labels
                chunk_labels = np.ascontiguousarray(
                    network.sublayers[-1].can_labels[chunk_labels])
            chunk_seq = flipflopfings.flipflop_code(chunk_labels,
                                                    alphabet_info.ncan_base)
            seqs.append(chunk_seq)

        seqs = torch.tensor(np.concatenate(seqs),
                            dtype=torch.float32,
                            device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        if network_is_catmod:
            mod_cats = torch.tensor(np.concatenate(mod_cats),
                                    dtype=torch.long,
                                    device=device)

        total_sub_batches += 1

        yield indata, seqs, seqlens, mod_cats, sub_batch_size, batch_rejections
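# A hedged consumption sketch (all argument values below are illustrative
# placeholders, not values taken from this repository):
#
#     batches = prepare_random_batches(
#         device, read_data, batch_chunk_len=2000, sub_batch_size=100,
#         target_sub_batches=10, alphabet_info=alphabet_info,
#         filter_params=filter_params, network=network,
#         network_is_catmod=False, log=log)
#     for indata, seqs, seqlens, mod_cats, sb_size, rejections in batches:
#         outputs = network(indata)  # (timesteps) x (batch) x (output states)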
Example #3
def convert_seq(s, alphabet):
    """Convert str sequence to flip-flop integer codes

    Args:
        s (str) : base sequence (e.g. 'ACCCTGGA')
        alphabet (str): alphabet of bases for coding (e.g. 'ACGT')

    Returns:
        np i4 array : flip-flop coded sequence (e.g. [0 1 5 1 3 2 6 0])
    """
    buf = np.array(list(s))
    for i, b in enumerate(alphabet):
        buf[buf == b] = i
    buf = buf.astype('i4')
    assert np.all(
        buf < len(alphabet)), "Alphabet violates assumption in convert_seq"
    return flipflopfings.flipflop_code(buf, len(alphabet))
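# For reference, a pure-numpy sketch of the transform that
# flipflopfings.flipflop_code is assumed to perform (an illustrative
# reimplementation consistent with the docstring above, not taiyaki's code):
def _flipflop_code_sketch(labels, nbase):
    # Repeated canonical bases alternate between flip (b) and flop (b + nbase)
    out = labels.copy()
    for i in range(1, len(out)):
        if out[i] == out[i - 1] % nbase:
            out[i] = (out[i - 1] + nbase) % (2 * nbase)
    return out
# e.g. _flipflop_code_sketch(np.array([0, 1, 1, 1, 3, 2, 2, 0]), 4)
# -> [0 1 5 1 3 2 6 0], matching the example in the docstring above.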
Example #4
def main():
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.\n')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(('Error: Output directory {} exists but ' +
                          '--overwrite is false\n').format(args.output))
        exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        sampling_chunk_len,
        args,
        log,
        chunk_log=chunk_log)

    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()
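    # score_smoothed keeps an exponentially-weighted moving average of recent
    # losses; it is compared against the per-batch loss below to flag
    # anomalously bad ('poison') chunks for logging.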

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))
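        # Worked example with illustrative numbers: chunk_len_min=2000,
        # chunk_len_max=4000 and stride=5 with a draw of 2777 give
        # batch_chunk_len = (2777 // 5) * 5 = 2775, so
        # target_batch_size = int(64 * 4000 / 2775 + 0.5) = 92 when
        # args.min_batch_size is 64.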

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the chunk log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # Chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]),
                            device=device,
                            dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long,
                               device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
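        # Normalise the summed loss by the number of non-empty chunks so the
        # value is a per-chunk mean, comparable across variable batch sizes.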
        loss = lossvector.sum() / (seqlens > 0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we log
        # everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (
                ' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
Example #5
def prepare_random_batches(read_data,
                           batch_chunk_len,
                           sub_batch_size,
                           target_sub_batches,
                           alphabet_info,
                           filter_params,
                           net_info,
                           log,
                           select_strands_randomly=True,
                           first_strand_index=0,
                           pin=True):
    total_sub_batches = 0
    if net_info.metadata.reverse:
        revop = np.flip
    else:
        revop = np.array
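    # revop reverses label/current arrays for models whose metadata says the
    # reads were mapped in reverse; np.array serves as a copying no-op.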

    while total_sub_batches < target_sub_batches:

        # Chunk_batch is a list of dicts
        chunk_batch, batch_rejections = chunk_selection.sample_chunks(
            read_data,
            sub_batch_size,
            batch_chunk_len,
            filter_params,
            standardize=net_info.metadata.standardize,
            select_strands_randomly=select_strands_randomly,
            first_strand_index=first_strand_index)
        first_strand_index += sum(batch_rejections.values())
        if len(chunk_batch) < sub_batch_size:
            log.write(('* Warning: only {} chunks passed filters ' +
                       '(asked for {}).\n').format(len(chunk_batch),
                                                   sub_batch_size))

        if not all(chunk.seq_len > 0 for chunk in chunk_batch):
            raise Exception('Error: zero length sequence')

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x sub_batch_size x 1
        stacked_current = np.vstack(
            [revop(chunk.current) for chunk in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device='cpu',
                              dtype=torch.float32).unsqueeze(2)

        if pin and torch.cuda.is_available():
            indata = indata.pin_memory()

        # Prepare seqs, seqlens and (if necessary) mod_cats
        seqs, seqlens = [], []
        mod_cats = [] if net_info.metadata.is_cat_mod else None
        for chunk in chunk_batch:
            chunk_labels = revop(chunk.sequence)
            seqlens.append(len(chunk_labels))
            if net_info.metadata.is_cat_mod:
                chunk_mod_cats = np.ascontiguousarray(
                    net_info.metadata.mod_labels[chunk_labels])
                mod_cats.append(chunk_mod_cats)
                # convert chunk_labels to canonical base labels
                chunk_labels = np.ascontiguousarray(
                    net_info.metadata.can_labels[chunk_labels])
            chunk_seq = flipflopfings.flipflop_code(chunk_labels,
                                                    alphabet_info.ncan_base)
            seqs.append(chunk_seq)

        seqs = torch.tensor(np.concatenate(seqs),
                            dtype=torch.long,
                            device='cpu')
        seqlens = torch.tensor(seqlens, dtype=torch.long, device='cpu')
        if net_info.metadata.is_cat_mod:
            mod_cats = torch.tensor(np.concatenate(mod_cats),
                                    dtype=torch.long,
                                    device='cpu')

        total_sub_batches += 1

        yield indata, seqs, seqlens, mod_cats, sub_batch_size, batch_rejections
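# A hedged consumption sketch: tensors are built on the CPU (optionally in
# pinned memory), so the caller moves them to the training device itself;
# non_blocking=True only gives an asynchronous copy for pinned host memory.
# (Variable names and argument values below are placeholders.)
#
#     for (indata, seqs, seqlens, mod_cats, sb_size,
#             rejections) in prepare_random_batches(
#                 read_data, 2000, 100, 10, alphabet_info, filter_params,
#                 net_info, log, pin=True):
#         indata = indata.to(device, non_blocking=True)
#         seqs = seqs.to(device, non_blocking=True)
#         seqlens = seqlens.to(device, non_blocking=True)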
Example #6
def main():
    args = parser.parse_args()
    log, loss_log, chunk_log, device = _setup_and_logs(args)

    read_data, alphabet_info = _load_data(args, log)
    # Get parameters for filtering by sampling a subset of the reads.
    # Result is a tuple (median of mean_dwell, MAD of mean_dwell).
    # Choose a chunk length in the middle of the range for this.
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        (args.chunk_len_min + args.chunk_len_max) // 2,
        args,
        log,
        chunk_log=chunk_log)
    log.write(("* Sampled {} chunks: median(mean_dwell)={:.2f}, " +
               "mad(mean_dwell)={:.2f}\n").format(
                   args.sample_nreads_before_filtering, *filter_parameters))

    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'insize': 1,
        'winlen': args.winlen,
        'stride': args.stride,
        'size': args.size,
        'alphabet_info': alphabet_info
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    if not isinstance(network.sublayers[-1], layers.GlobalNormFlipFlopCatMod):
        log.write(
            'ERROR: Model must end with GlobalNormFlipFlopCatMod layer, ' +
            'not {}.\n'.format(str(network.sublayers[-1])))
        sys.exit(1)
    can_mods_offsets = network.sublayers[-1].can_mods_offsets
    flipflop_can_labels = network.sublayers[-1].can_labels
    flipflop_mod_labels = network.sublayers[-1].mod_labels
    flipflop_ncan_base = network.sublayers[-1].ncan_base
    log.write('* Loaded categorical modifications flip-flop model.\n')
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)
    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    if args.scale_mod_loss:
        try:
            mod_cat_weights = alphabet_info.compute_mod_inv_freq_weights(
                read_data, args.num_inv_freq_reads)
            log.write('* Modified base weights: {}\n'.format(
                str(mod_cat_weights)))
        except NotImplementedError:
            log.write(
                '* WARNING: Some mods not found when computing inverse ' +
                'frequency weights. Consider raising ' +
                '[--num_inv_freq_reads].\n')
            mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)
    else:
        mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)
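    # When args.scale_mod_loss is unset (or weighting fails), every category
    # gets unit weight, i.e. rare modified bases are not up-weighted.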

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_chunks = 0
    total_samples = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    score_smoothed = helpers.WindowedExpSmoother()
    t0 = time.time()
    log.write('* Training\n')
    for i in range(args.niteration):
        lr_scheduler.step()
        mod_factor_t = torch.tensor(args.mod_factor, dtype=torch.float32)

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # If the logging threshold is 0 then we log all chunks, including
        # those rejected, so pass the chunk log object into assemble_batch
        if args.chunk_logging_threshold == 0:
            log_rejected_chunks = chunk_log
        else:
            log_rejected_chunks = None
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)
        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
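        # Sequence labels are concatenated into a single flat 1D vector;
        # seqlens records each chunk's length so the loss can recover the
        # per-chunk boundaries.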
        seqs, mod_cats, seqlens = [], [], []
        for chunk in chunk_batch:
            chunk_labels = chunk['sequence']
            seqlens.append(len(chunk_labels))
            chunk_seq = flipflop_code(
                np.ascontiguousarray(flipflop_can_labels[chunk_labels]),
                flipflop_ncan_base)
            chunk_mod_cats = np.ascontiguousarray(
                flipflop_mod_labels[chunk_labels])
            seqs.append(chunk_seq)
            mod_cats.append(chunk_mod_cats)
        seqs, mod_cats = np.concatenate(seqs), np.concatenate(mod_cats)
        seqs = torch.tensor(seqs, dtype=torch.float32, device=device)
        seqlens = torch.tensor(seqlens, dtype=torch.long, device=device)
        mod_cats = torch.tensor(mod_cats, dtype=torch.long, device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.cat_mod_flipflop_loss(outputs, seqs, seqlens,
                                               mod_cats, can_mods_offsets,
                                               mod_cat_weights, mod_factor_t,
                                               args.sharpen)
        loss = lossvector.sum() / (seqlens > 0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for poison chunk and save losses and chunk locations if we're
        # poisoned. If args.chunk_logging_threshold is set to zero then we log
        # everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        del indata, seqs, mod_cats, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        loss_log.write('{}\t{:.10f}\t{:.10f}\n'.format(i, fval,
                                                       score_smoothed.value))
        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            log.write((' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s ' +
                       '{:.2f} kbase/s) lr={:.2e}').format(
                           (i + 1) // 50, score_smoothed.value, dt,
                           total_samples / 1000 / dt, total_bases / 1000 / dt,
                           learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)

    return