Exemple #1
0
def _load_data(args, log):
    """Load training reads and the alphabet description from a mapped-signal file.

    Args:
        args: parsed command-line arguments. Reads ``args.input`` (path of the
            mapped-signal HDF5 file), ``args.input_strand_list`` (file of read
            ids to keep, or None for all reads) and ``args.limit`` (maximum
            number of reads to load, or None).
        log: logger object exposing ``write``; used for progress messages.

    Returns:
        tuple: ``(read_data, alphabet_info)`` where ``read_data`` is what
        ``HDF5Reader.get_multiple_reads`` returns and ``alphabet_info`` is an
        ``alphabet.AlphabetInfo`` built from the file's alphabet fields
        (without reordering).
    """
    if args.input_strand_list is not None:
        # De-duplicate the requested ids before handing them to the reader.
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        # The concatenation is parenthesised so .format() applies to the whole
        # message. Previously only the second literal was formatted, and the
        # '{}' placeholder was logged verbatim.
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Will train from all strands\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        (bases_alphabet, collapse_alphabet,
         mod_long_names) = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    alphabet_info = alphabet.AlphabetInfo(bases_alphabet,
                                          collapse_alphabet,
                                          mod_long_names,
                                          do_reorder=False)
    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    return read_data, alphabet_info
    def test_mod_prepare_remap(self):
        """Run the remapping script with two modifications and validate its output.

        ``self.script`` is run with ``--mod Z C 5mC`` and ``--mod Y A 6mA``.
        The script must exit with status 0. The mapped-signal file it writes
        must pass ``HDF5Reader.check()``. A 1000-sample chunk taken at a
        fixed offset from the first read must have a mean dwell (signal
        samples per sequence position) strictly between 7 and 13.
        """
        print("Current directory is", os.getcwd())
        print("Taiyaki dir is", self.taiyakidir)
        print("Data dir is ", self.datadir)
        cmd = [
            self.script, self.read_dir, self.per_read_params,
            self.output_mapped_signal_file, self.remapping_model,
            self.mod_per_read_refs, "--mod", "Z", "C", "5mC", "--mod", "Y",
            "A", "6mA", "--overwrite"
        ]
        r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print("Result of running make command in shell:")
        print("Stdout=", r.stdout.decode('utf-8'))
        print("Stderr=", r.stderr.decode('utf-8'))
        # Fail here, after the script's output has been printed, rather than
        # later with an obscure error from opening a missing or partial file.
        self.assertEqual(r.returncode, 0,
                         "Remapping script exited with non-zero status")

        # Open mapped read file and run checks to see if it complies with file format
        # Also get a chunk and check that speed is within reasonable bounds
        with mapped_signal_files.HDF5Reader(
                self.output_mapped_signal_file) as f:
            testreport = f.check()
            print("Test report from checking mapped read file:")
            print(testreport)
            self.assertEqual(testreport, "pass")
            read0 = f.get_multiple_reads("all")[0]
            chunk = read0.get_chunk_with_sample_length(1000, start_sample=10)
            # Defined start_sample to make it reproducible - otherwise randomly
            # located chunk is returned.
            chunk_meandwell = len(
                chunk['current']) / (len(chunk['sequence']) + 0.0001)
            print("chunk mean dwell time in samples = ", chunk_meandwell)
            # Use a unittest assertion instead of a bare assert. It still runs
            # under `python -O` and is reported like the other checks.
            self.assertTrue(
                7 < chunk_meandwell < 13,
                "Chunk mean dwell time outside allowed range 7 to 13")
Exemple #3
0
def check_map_sig_alphabet(model_info, ms_fn):
    """Check that a model and a mapped-signal file describe the same alphabet.

    Args:
        model_info: object exposing ``output_alphabet``, ``can_alphabet`` and
            ``ordered_mod_long_names`` for the model.
        ms_fn: path of the mapped-signal HDF5 file to compare against.

    Raises:
        mh.MegaError: if the full alphabets differ, if the canonical alphabets
            differ as sets, or if the ordered modified-base long names differ.
    """
    # The context manager releases the reader even if reading the alphabet
    # raises; the previous explicit close() was skipped in that case.
    with mapped_signal_files.HDF5Reader(ms_fn) as msf:
        # Accessed below through .alphabet, .collapse_alphabet and
        # .mod_long_names attributes.
        tai_alph_info = msf.get_alphabet_information()

    if model_info.output_alphabet != tai_alph_info.alphabet:
        raise mh.MegaError(
            ("Different alphabets specified in model ({}) and mapped "
             "signal file ({})").format(
                 model_info.output_alphabet, tai_alph_info.alphabet))
    # Canonical bases are compared as sets, so order does not matter here.
    if set(model_info.can_alphabet) != set(tai_alph_info.collapse_alphabet):
        raise mh.MegaError(
            ("Different canonical alphabets specified in model ({}) and "
             "mapped signal file ({})").format(
                 model_info.can_alphabet, tai_alph_info.collapse_alphabet))
    # Modified-base names are compared in order.
    if model_info.ordered_mod_long_names != tai_alph_info.mod_long_names:
        raise mh.MegaError(
            ("Different modified base long names specified in model ({}) and "
             "mapped signal file ({})").format(
                 ", ".join(model_info.ordered_mod_long_names),
                 ", ".join(tai_alph_info.mod_long_names)))
    def test_check_HDF5_mapped_read_file(self):
        """A Read built from non-conforming data must fail its consistency checks.

        The 'Reference' entry of the test read is replaced by a string. Both
        ``Read.check()`` and, after the read has been written to
        ``self.testfilepath``, ``HDF5Reader.check()`` must report something
        other than "pass". The stored read id must still round-trip
        unchanged.
        """
        print("Creating flawed Read object from test data")
        flawed_fields = construct_mapped_read()
        # Deliberately the wrong type (the literal itself says a numpy array
        # is what belongs here).
        flawed_fields['Reference'] = "I'm not a numpy array!"
        flawed_read = mapped_signal_files.Read(flawed_fields)

        print("Checking contents")
        read_verdict = flawed_read.check()
        print("Check result on read object: should fail")
        print(read_verdict)
        self.assertNotEqual(read_verdict, "pass")

        print("Writing to file")
        default_alph = alphabet.AlphabetInfo(DEFAULT_ALPHABET, DEFAULT_ALPHABET)
        with mapped_signal_files.HDF5Writer(self.testfilepath,
                                            default_alph) as writer:
            writer.write_read(flawed_read)

        print("Current dir = ", os.getcwd())
        print("File written to ", self.testfilepath)

        print("\nOpening file for reading")
        with mapped_signal_files.HDF5Reader(self.testfilepath) as reader:
            stored_ids = reader.get_read_ids()
            first_id = stored_ids[0]
            print("Read ids=", first_id)
            print("Version number = ", reader.version)
            self.assertEqual(first_id, flawed_fields['read_id'])

            file_verdict = reader.check()
            print("Test report (should fail):", file_verdict)
            self.assertNotEqual(file_verdict, "pass")
Exemple #5
0
 def count_reads(self, mapped_signal_file, print_readlist=True):
     """Return the number of reads stored in a mapped signal file.

     When ``print_readlist`` is true, the ids are also printed one per line
     beneath a "Read list:" header.
     """
     with mapped_signal_files.HDF5Reader(mapped_signal_file) as reader:
         ids_in_file = reader.get_read_ids()
         if not print_readlist:
             pass
         else:
             print("Read list:")
             listing = '\n'.join(ids_in_file)
             print(listing)
     total = len(ids_in_file)
     return total
def fill_reads_queue(read_q, read_filler_conn, ms_fn, num_reads_limit,
                     num_proc):
    """Feed reads from a mapped-signal file into a queue for worker processes.

    Args:
        read_q: queue-like object. Each read is ``put`` onto it, followed by
            one ``None`` stop sentinel per consumer.
        read_filler_conn: connection-like object. The number of reads queued
            is sent through it with ``send``.
        ms_fn: path of the mapped-signal HDF5 file to iterate over.
        num_reads_limit: stop after this many reads, or None for all reads.
        num_proc: number of consumer processes; one sentinel is queued each.
    """
    num_reads = 0
    # The context manager releases the file even if iteration raises.
    with mapped_signal_files.HDF5Reader(ms_fn) as msf:
        for read in msf:
            read_q.put(read)
            num_reads += 1
            if num_reads_limit is not None and num_reads >= num_reads_limit:
                break
    read_filler_conn.send(num_reads)
    # num_proc is a count, so it must go through range(). Iterating the int
    # directly raised TypeError, and no worker received its stop sentinel.
    for _ in range(num_proc):
        read_q.put(None)
def main():
    """Summarise selected reads from mapped-signal files and optionally plot them.

    Options come from the module-level ``parser``. Each file in
    ``args.mapped_read_files`` is processed in turn. The reads named in
    ``args.read_ids`` are used, or, if that is empty, the first
    ``args.nreads`` reads of the file. A one-line summary of each read is
    written to stdout.

    When ``args.output`` is set, each read's ``Ref_to_signal`` mapping
    (non-negative entries only, optionally clipped to
    ``[args.xmin, args.xmax]``) is plotted and the figure is saved to
    ``args.output``. At most ``args.maxlegendsize`` reads, counted across
    all files, receive a legend label.
    """
    args = parser.parse_args()
    if args.output is not None:
        plt.figure(figsize=(12, 10))
    # Reads handled so far across all files; used to cap legend entries.
    reads_sofar = 0
    for nfile, mapped_read_file in enumerate(args.mapped_read_files):
        with mapped_signal_files.HDF5Reader(mapped_read_file) as h5:
            all_read_ids = h5.get_read_ids()
            if len(args.read_ids) > 0:
                read_ids = args.read_ids
            else:
                read_ids = all_read_ids[:args.nreads]
                sys.stderr.write(
                    "Reading first {} read ids in file {}\n".format(
                        args.nreads, mapped_read_file))
            for nread, read_id in enumerate(read_ids):
                r = h5.get_read(read_id)
                mapping = r['Ref_to_signal']
                # Keep only non-negative entries (negative values presumably
                # mark positions without a signal location; TODO confirm).
                f = mapping >= 0
                maplen = len(mapping)
                read_info_text = (
                    'file {} read {}:{} reflen:{}, daclen:{}').format(
                        nfile, nread, read_id, maplen - 1, len(r['Dacs']))
                sys.stdout.write(read_info_text + '\n')

                if args.output is not None:
                    # Only the first maxlegendsize reads get a legend label.
                    label = (read_info_text
                             if reads_sofar < args.maxlegendsize else None)
                    x, y = np.arange(maplen)[f], mapping[f]
                    if args.xmin is not None:
                        xf = x >= args.xmin
                        x, y = x[xf], y[xf]
                    if args.xmax is not None:
                        xf = x <= args.xmax
                        x, y = x[xf], y[xf]
                    plt.plot(x,
                             y,
                             label=label,
                             linestyle='dashed' if nfile == 1 else 'solid')
                # Previously never incremented, so the legend cap had no effect.
                reads_sofar += 1

    if args.output is not None:
        plt.grid()
        plt.xlabel('Reference location')
        plt.ylabel('Signal location')
        plt.legend(loc='upper left', framealpha=0.3)
        plt.tight_layout()
        sys.stderr.write("Saving plot to {}\n".format(args.output))
        plt.savefig(args.output)
def main():
    """Plot reference-to-signal mappings for reads from mapped-signal files.

    Options come from the module-level ``parser``. Each file in
    ``args.mapped_read_files`` is processed in turn. The first ten read ids
    of the file are listed on stderr. Then the reads in ``args.read_ids``
    are plotted, or, if that is empty, the first ``args.nreads`` reads.

    Only non-negative ``Ref_to_signal`` entries are drawn, optionally
    clipped to ``[args.xmin, args.xmax]``. The file with index 1 is drawn
    dashed and all others solid. The figure is saved to ``args.output``. A
    legend is added only when fewer than 15 reads were plotted in total.
    """
    args = parser.parse_args()
    plt.figure(figsize=(12, 10))
    # The legend threshold applies to the total number of plotted reads.
    # Previously it used the read_ids of the last file only, and raised
    # NameError when no files were given.
    max_legend_reads = 15
    nplotted = 0
    for nfile, mapped_read_file in enumerate(args.mapped_read_files):
        sys.stderr.write("Opening {}\n".format(mapped_read_file))
        with mapped_signal_files.HDF5Reader(mapped_read_file) as h5:
            all_read_ids = h5.get_read_ids()
            sys.stderr.write("First ten read_ids in file:\n")
            for read_id in all_read_ids[:10]:
                sys.stderr.write("    {}\n".format(read_id))
            if len(args.read_ids) > 0:
                read_ids = args.read_ids
            else:
                read_ids = all_read_ids[:args.nreads]
                sys.stderr.write("Plotting first {} read ids in file\n".format(
                    args.nreads))
            for nread, read_id in enumerate(read_ids):
                sys.stderr.write("Opening read id {}\n".format(read_id))
                r = h5.get_read(read_id)
                mapping = r['Ref_to_signal']
                f = mapping >= 0
                maplen = len(mapping)
                label = 'file {} read {}:{} reflen:{}, daclen:{}'.format(
                    nfile, nread, read_id, maplen - 1, len(r['Dacs']))
                x, y = np.arange(maplen)[f], mapping[f]
                if args.xmin is not None:
                    xf = (x >= args.xmin)
                    x, y = x[xf], y[xf]
                if args.xmax is not None:
                    xf = (x <= args.xmax)
                    x, y = x[xf], y[xf]
                plt.plot(x,
                         y,
                         label=label,
                         linestyle='dashed' if nfile == 1 else 'solid')
                nplotted += 1

    plt.grid()
    plt.xlabel('Reference location')
    plt.ylabel('Signal location')
    if nplotted < max_legend_reads:
        plt.legend(loc='upper left', framealpha=0.3)
    plt.tight_layout()
    sys.stderr.write("Saving plot to {}\n".format(args.output))
    plt.savefig(args.output)
def main():
    """Train a flip-flop basecalling network on reads from a mapped-signal file.

    Options come from the module-level ``parser``. The function:

    * prepares ``args.output`` (creates it, or refuses to reuse it unless
      ``--overwrite`` is given) and copies ``args.model`` into it;
    * loads the reads, optionally filtered by ``args.input_strand_list`` and
      capped by ``args.limit``;
    * samples chunk-filtering parameters and builds the network;
    * runs ``args.niteration`` Adam steps on the CRF flip-flop loss.

    A checkpoint is saved every ``args.save_every`` iterations and a final
    model is saved at the end. Exits with status 1 when a CUDA device is
    requested without CUDA support, when the output location is unusable, or
    when no reads are loaded.
    """
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        try:
            torch.cuda.set_device(device)
        except AttributeError:
            # The trailing newline stops the message running into whatever
            # is printed next.
            sys.stderr.write('ERROR: Torch not compiled with CUDA enabled ' +
                             'and GPU device set.\n')
            sys.exit(1)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        # The concatenation is parenthesised so .format() fills in the
        # directory. Previously it applied only to the second literal, and
        # '{}' was printed verbatim.
        sys.stderr.write(('Error: Output directory {} exists but --overwrite ' +
                          'is false\n').format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    copyfile(args.model, os.path.join(args.output, 'model.py'))

    # Create a logging file to save details of chunks.
    # If args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    # Choose a chunk length in the middle of the range for this
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        sampling_chunk_len,
        args,
        log,
        chunk_log=chunk_log)

    medmd, madmd = filter_parameters

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))
    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(alphabet)
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum([p.nelement() for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                args.lr_cosine_iters)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    helpers.save_model(network, args.output, 0)

    total_bases = 0
    total_samples = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    # If the logging threshold is 0 then we log all chunks, including those
    # rejected, so the chunk log is passed into assemble_batch. The choice
    # never changes between iterations, so it is made once here.
    if args.chunk_logging_threshold == 0:
        log_rejected_chunks = chunk_log
    else:
        log_rejected_chunks = None

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # NOTE(review): the scheduler is stepped before optimizer.step().
        # PyTorch >= 1.1 expects the opposite order and warns about it. Kept
        # as-is so the learning-rate schedule is unchanged.
        lr_scheduler.step()
        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride
        # We choose the batch size so that the size of the data in the batch
        # is about the same as args.min_batch_size chunks of length
        # args.chunk_len_max
        target_batch_size = int(args.min_batch_size * args.chunk_len_max /
                                batch_chunk_len + 0.5)
        # ...but it can't be more than the number of reads.
        batch_size = min(target_batch_size, len(read_data))

        # Chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            batch_size,
            batch_chunk_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input tensor must be:
        #     (timesteps) x (batch size) x (input channels)
        # in this case:
        #     batch_chunk_len x batch_size x 1
        stacked_current = np.vstack([d['current'] for d in chunk_batch]).T
        indata = torch.tensor(stacked_current,
                              device=device,
                              dtype=torch.float32).unsqueeze(2)
        # Sequence input tensor is just a 1D vector, and so is seqlens
        seqs = torch.tensor(np.concatenate([
            flipflopfings.flipflop_code(d['sequence'], nbase)
            for d in chunk_batch
        ]),
                            device=device,
                            dtype=torch.long)
        seqlens = torch.tensor([len(d['sequence']) for d in chunk_batch],
                               dtype=torch.long,
                               device=device)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens,
                                           args.sharpen)
        # Average over chunks that have a non-empty sequence.
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        # Check for a poison chunk and save losses and chunk locations if we
        # are poisoned. If args.chunk_logging_threshold is set to zero then
        # we log everything.
        if fval / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, lossvector)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            helpers.save_model(network, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (
                ' {:5d} {:5.3f}  {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) ' +
                'lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt,
                         total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

    helpers.save_model(network, args.output)
def main():
    """Train a convolutional squiggle-prediction network on mapped reads.

    Options come from the module-level ``parser``. The function:

    * prepares ``args.output`` (creates it, or refuses to reuse it unless
      ``--overwrite`` is given);
    * loads reads from ``args.input``, optionally filtered by
      ``args.input_strand_list`` and capped by ``args.limit``;
    * samples chunk-filtering parameters and builds the network with
      ``create_convolution``;
    * runs ``args.niteration`` Adam steps on ``squiggle_match_loss``.

    A checkpoint is saved every ``args.save_every`` iterations and a final
    model is saved at the end. Exits with status 1 when the output location
    is unusable, when the file's alphabet does not have exactly 4 letters,
    or when no reads are loaded.
    """
    args = parser.parse_args()
    np.random.seed(args.seed)

    if not os.path.exists(args.output):
        os.mkdir(args.output)
    elif not args.overwrite:
        sys.stderr.write(
            'Error: Output directory {} exists but --overwrite is false\n'.
            format(args.output))
        sys.exit(1)
    if not os.path.isdir(args.output):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.output))
        sys.exit(1)

    log = helpers.Logger(os.path.join(args.output, 'model.log'), args.quiet)
    log.write('# Taiyaki version {}\n'.format(__version__))
    log.write('# Command line\n')
    log.write(' '.join(sys.argv) + '\n')

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet, _, _ = per_read_file.get_alphabet_information()
        # An alphabet longer than 4 means modified-base training data, which
        # this trainer cannot handle. This is an explicit check instead of an
        # assert so it is not stripped under `python -O`.
        if len(alphabet) != 4:
            sys.stderr.write(
                'Error: Squiggle prediction with modified base training data '
                'is not currently supported.\n')
            sys.exit(1)
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads (each an instance of the
        # Read class defined in mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Create a logging file to save details of chunks. If
    # args.chunk_logging_threshold is set to 0 then we log all chunks
    # including those rejected.
    chunk_log = chunk_selection.ChunkLog(args.output)

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data,
        args.sample_nreads_before_filtering,
        args.target_len,
        args,
        log,
        chunk_log=chunk_log)

    medmd, madmd = filter_parameters
    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering, medmd, madmd))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum(p.nelement() for p in conv_net.parameters())
    log.write('# Created network.  {} parameters\n'.format(nparam))
    log.write('# Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('# Window width {}\n'.format(args.winlen))
    log.write('# Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    device = torch.device(args.device)
    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay)

    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()

    # If the logging threshold is 0 then we log all chunks, including those
    # rejected, so the chunk log is passed into assemble_batch. The choice
    # never changes between iterations, so it is made once here.
    if args.chunk_logging_threshold == 0:
        log_rejected_chunks = chunk_log
    else:
        log_rejected_chunks = None

    for i in range(args.niteration):
        lr_scheduler.step()
        # chunk_batch is a list of dicts.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            args.batch_size,
            args.target_len,
            filter_parameters,
            args,
            log,
            chunk_log=log_rejected_chunks,
            chunk_len_means_sequence_len=True)

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [
            embed_sequence(d['sequence'], alphabet=None) for d in chunk_batch
        ]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()

        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        # Total loss divided by the total number of signal samples in batch.
        fval = batch_loss.sum() / float(batch_siglen.sum())

        fval.backward()
        optimizer.step()

        # Convert to a Python float once and use it for both smoothing and
        # the poison-chunk test.
        loss_value = float(fval)
        score_smoothed.update(loss_value)

        # Check for a poison chunk and save losses and chunk locations if we
        # are poisoned. If args.chunk_logging_threshold is set to zero then
        # we log everything.
        if loss_value / score_smoothed.value >= args.chunk_logging_threshold:
            chunk_log.write_batch(i, chunk_batch, batch_loss)

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.output,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f}  {:5.2f}s'
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            for k, v in rejection_dict.items():
                log.write(" {}:{} ".format(k, v))
            log.write("\n")

    helpers.save_model(conv_net, args.output)
Exemple #11
0
def main():
    """Train a network on chunks drawn from a mapped-signal (HDF5) file.

    Arguments come from the module-level ``parser``. Reads are loaded from
    ``args.input`` (optionally restricted by ``args.input_strand_list`` and
    ``args.limit``). The network is built from ``args.model`` and trained
    for ``args.niteration`` optimiser steps.

    Checkpoints go to ``args.outdir`` every ``args.save_every`` iterations,
    and a final model is saved at the end.

    When ``args.local_rank`` is set, the function runs in multi-GPU mode
    (NCCL process group, ``DistributedDataParallel``). Only rank 0 (the
    "lead" process) prepares the output directory and writes model files.

    Exits with status 1 if no reads are loaded, or if the model and the
    mapped-signal file disagree about the alphabet / modified bases.

    Raises:
        RuntimeError: if ``local_rank`` is set but the distributed process
            group cannot be initialised.
    """
    args = parser.parse_args()
    is_multi_gpu = (args.local_rank is not None)
    is_lead_process = (not is_multi_gpu) or args.local_rank == 0

    if is_multi_gpu:
        # Use distributed parallel processing to run one process per GPU
        try:
            torch.distributed.init_process_group(backend='nccl')
        except Exception as err:
            # Narrow from a bare ``except:`` so Ctrl-C still works, and keep
            # the original cause attached for debugging.
            raise RuntimeError(
                "Unable to start multiprocessing group. " +
                "The most likely reason is that the script is running with " +
                "local_rank set but without the set-up for distributed " +
                "operation. local_rank should be used " +
                "only by torch.distributed.launch. See the README.") from err
        device = helpers.set_torch_device(args.local_rank)
        if args.seed is not None:
            # Make sure processes get different random picks of training data
            np.random.seed(args.seed + args.local_rank)
    else:
        device = helpers.set_torch_device(args.device)
        np.random.seed(args.seed)

    # Only the lead process touches the output directory.
    if is_lead_process:
        helpers.prepare_outdir(args.outdir, args.overwrite)
        if args.model.endswith('.py'):
            # Keep a copy of the model definition alongside the outputs
            copyfile(args.model, os.path.join(args.outdir, 'model.py'))
        batchlog = helpers.BatchLog(args.outdir)
        logfile = os.path.join(args.outdir, 'model.log')
    else:
        logfile = None

    log = helpers.Logger(logfile, args.quiet)
    log.write(helpers.formatted_env_info(device))

    log.write('* Loading data from {}\n'.format(args.input))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.input)))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write(('* Will train from a subset of {} strands, determined ' +
                   'by read_ids in input strand list\n').format(len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads
        # (each an instance of the Read class defined in
        # mapped_signal_files.py, based on dict)
    log.write('* Using alphabet definition: {}\n'.format(str(alphabet_info)))

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    # Choose a chunk length in the middle of the range for this
    sampling_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    filter_params = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, sampling_chunk_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write("* Sampled {} chunks".format(
        args.sample_nreads_before_filtering))
    log.write(": median(mean_dwell)={:.2f}".format(
        filter_params.median_meandwell))
    log.write(", mad(mean_dwell)={:.2f}\n".format(filter_params.mad_meandwell))
    log.write('* Reading network from {}\n'.format(args.model))
    model_kwargs = {
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'size': args.size,
        'alphabet_info': alphabet_info
    }

    if is_lead_process:
        # Under pytorch's DistributedDataParallel scheme, we
        # need a clone of the start network to use as a template for saving
        # checkpoints. Necessary because DistributedParallel makes the class
        # structure different.
        network_save_skeleton = helpers.load_model(args.model, **model_kwargs)
        log.write('* Network has {} parameters.\n'.format(
            sum([p.nelement() for p in network_save_skeleton.parameters()])))
        if not alphabet_info.is_compatible_model(network_save_skeleton):
            sys.stderr.write(
                '* ERROR: Model and mapped signal files contain incompatible '
                + 'alphabet definitions (including modified bases).\n')
            sys.exit(1)
        if is_cat_mod_model(network_save_skeleton):
            log.write('* Loaded categorical modified base model.\n')
            if not alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Modified bases model specified, but mapped ' +
                    'signal file does not contain modified bases.\n')
                sys.exit(1)
        else:
            log.write('* Loaded standard (canonical bases-only) model.\n')
            if alphabet_info.contains_modified_bases():
                sys.stderr.write(
                    '* ERROR: Standard (canonical bases only) model ' +
                    'specified, but mapped signal file does contain ' +
                    'modified bases.\n')
                sys.exit(1)
        log.write('* Dumping initial model\n')
        helpers.save_model(network_save_skeleton, args.outdir, 0)

    # ``base_network`` is always the plain (unwrapped) model. In multi-GPU
    # mode ``network`` is a DistributedDataParallel wrapper, which does not
    # expose the wrapped model's own attributes (e.g. ``sublayers``), so
    # model introspection below must go through ``base_network``.
    if is_multi_gpu:
        # so that processes 1,2,3.. don't try to load before process 0 has saved
        torch.distributed.barrier()
        log.write('* MultiGPU process {}'.format(args.local_rank))
        log.write(': loading initial model saved by process 0\n')
        saved_startmodel_path = os.path.join(
            args.outdir, 'model_checkpoint_00000.checkpoint')
        base_network = helpers.load_model(saved_startmodel_path).to(device)
        # Wrap network for training in the DistributedDataParallel structure
        network = torch.nn.parallel.DistributedDataParallel(
            base_network,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        base_network = network_save_skeleton.to(device)
        network = base_network
        # Single process: the trained network itself is saved directly.
        network_save_skeleton = None

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    if args.lr_warmup is None:
        lr_warmup = args.lr_min
    else:
        lr_warmup = args.lr_warmup

    if args.lr_frac_decay is not None:
        lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_frac_decay,
                                          args.warmup_batches, lr_warmup)
        log.write('* Learning rate schedule lr_max*k/(k+t)')
        log.write(', k={}, t=iterations.\n'.format(args.lr_frac_decay))
    else:
        lr_scheduler = optim.CosineFollowedByFlatLR(optimizer, args.lr_min,
                                                    args.lr_cosine_iters,
                                                    args.warmup_batches,
                                                    lr_warmup)
        log.write('* Learning rate goes like cosine from lr_max to lr_min ')
        log.write('over {} iterations.\n'.format(args.lr_cosine_iters))
    log.write('* At start, train for {} '.format(args.warmup_batches))
    log.write('batches at warm-up learning rate {:3.2}\n'.format(lr_warmup))

    score_smoothed = helpers.WindowedExpSmoother()

    # prepare modified base parameter tensors
    network_is_catmod = is_cat_mod_model(base_network)
    mod_factor_t = torch.tensor(args.mod_factor,
                                dtype=torch.float32).to(device)
    can_mods_offsets = (base_network.sublayers[-1].can_mods_offsets
                        if network_is_catmod else None)
    # mod cat inv freq weighting is currently disabled. Compute and set this
    # value to enable mod cat weighting
    mod_cat_weights = np.ones(alphabet_info.nbase, dtype=np.float32)

    # Generate a fixed list of batches for standard loss reporting, so the
    # reported loss is comparable across reporting intervals.
    reporting_chunk_len = (args.chunk_len_min + args.chunk_len_max) // 2
    reporting_batch_list = list(
        prepare_random_batches(device, read_data, reporting_chunk_len,
                               args.min_sub_batch_size,
                               args.reporting_sub_batches, alphabet_info,
                               filter_params, network, network_is_catmod, log))

    log.write(
        ('* Standard loss report: chunk length = {} & sub-batch size ' +
         '= {} for {} sub-batches. \n').format(reporting_chunk_len,
                                               args.min_sub_batch_size,
                                               args.reporting_sub_batches))

    # Set cap at very large value (before we have any gradient stats).
    gradient_cap = constants.LARGE_VAL
    if args.gradient_cap_fraction is None:
        log.write('* No gradient capping\n')
    else:
        rolling_quantile = maths.RollingQuantile(args.gradient_cap_fraction)
        log.write('* Gradient L2 norm cap will be upper' +
                  ' {:3.2f} quantile of the last {} norms.\n'.format(
                      args.gradient_cap_fraction, rolling_quantile.window))

    total_bases = 0
    total_samples = 0
    total_chunks = 0
    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):

        # Chunk length is chosen randomly in the range given but forced to
        # be a multiple of the stride
        batch_chunk_len = (
            np.random.randint(args.chunk_len_min, args.chunk_len_max + 1) //
            args.stride) * args.stride

        # We choose the size of a sub-batch so that the size of the data in
        # the sub-batch is about the same as args.min_sub_batch_size chunks of
        # length args.chunk_len_max
        sub_batch_size = int(args.min_sub_batch_size * args.chunk_len_max /
                             batch_chunk_len + 0.5)

        optimizer.zero_grad()

        main_batch_gen = prepare_random_batches(
            device, read_data, batch_chunk_len, sub_batch_size,
            args.sub_batches, alphabet_info, filter_params, network,
            network_is_catmod, log)

        chunk_count, fval, chunk_samples, chunk_bases, batch_rejections = \
            calculate_loss(network, network_is_catmod, main_batch_gen,
                           args.sharpen, can_mods_offsets, mod_cat_weights,
                           mod_factor_t, calc_grads=True)

        # Clip with the cap derived from previous iterations, then update
        # the cap from this iteration's (unclipped) norm.
        gradnorm_uncapped = torch.nn.utils.clip_grad_norm_(
            network.parameters(), gradient_cap)
        if args.gradient_cap_fraction is not None:
            gradient_cap = rolling_quantile.update(gradnorm_uncapped)

        optimizer.step()
        if is_lead_process:
            batchlog.record(
                fval, gradnorm_uncapped,
                None if args.gradient_cap_fraction is None else gradient_cap)

        total_chunks += chunk_count
        total_samples += chunk_samples
        total_bases += chunk_bases

        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        score_smoothed.update(fval)

        if (i + 1) % args.save_every == 0 and is_lead_process:
            helpers.save_model(network, args.outdir,
                               (i + 1) // args.save_every,
                               network_save_skeleton)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:

            _, rloss, _, _, _ = calculate_loss(network, network_is_catmod,
                                               reporting_batch_list,
                                               args.sharpen, can_mods_offsets,
                                               mod_cat_weights, mod_factor_t)

            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = (' {:5d} {:7.5f} {:7.5f}  {:5.2f}s ({:.2f} ksample/s {:.2f} ' +
                 'kbase/s) lr={:.2e}')
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, rloss,
                         dt, total_samples / 1000.0 / dt,
                         total_bases / 1000.0 / dt, learning_rate))
            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = sum(rejection_dict.values())
                # ``.get`` avoids inserting a 'pass' key into the defaultdict
                n_fail = n_tot - rejection_dict.get('pass', 0)
                # Guard against division by zero when nothing was counted yet
                frac_failed = n_fail / n_tot if n_tot > 0 else 0.0
                log.write("  {:.1%} chunks filtered".format(frac_failed))
            log.write("\n")
            total_bases = 0
            total_samples = 0
            t0 = tn

            # Uncomment the lines below to check synchronisation of models
            # between processes in multi-GPU operation
            #for p in network.parameters():
            #    v = p.data.reshape(-1)[:5].to('cpu')
            #    u = p.data.reshape(-1)[-5:].to('cpu')
            #    break
            #if args.local_rank is not None:
            #    log.write("* GPU{} params:".format(args.local_rank))
            #log.write("{}...{}\n".format(v,u))

        lr_scheduler.step()

    if is_lead_process:
        helpers.save_model(network,
                           args.outdir,
                           model_skeleton=network_save_skeleton)
Exemple #12
0
def main():
    """Train a convolutional network to predict signal from base sequence.

    Arguments come from the module-level ``parser``. Reads are loaded from
    ``args.input`` (optionally restricted by ``args.input_strand_list`` and
    ``args.limit``). On each of ``args.niteration`` iterations, a batch of
    chunks is assembled, the sequences are embedded, and the network output
    is compared with the measured current using ``squiggle_match_loss``.

    Checkpoints go to ``args.outdir`` every ``args.save_every`` iterations,
    and a final model is saved at the end.

    Exits with status 1 if the mapped-signal file does not have exactly four
    bases (modified-base data is not supported), or if no reads are loaded.
    """
    # Imported locally so this entry point does not rely on module-level
    # imports for ``sys.exit`` / ``sys.stderr``.
    import sys

    args = parser.parse_args()
    np.random.seed(args.seed)

    helpers.prepare_outdir(args.outdir, args.overwrite)

    device = helpers.set_torch_device(args.device)

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write(helpers.formatted_env_info(device))

    if args.input_strand_list is not None:
        read_ids = list(set(helpers.get_read_ids(args.input_strand_list)))
        log.write('* Will train from a subset of {} strands\n'.format(
            len(read_ids)))
    else:
        log.write('* Reads not filtered by id\n')
        read_ids = 'all'

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with mapped_signal_files.HDF5Reader(args.input) as per_read_file:
        alphabet_info = per_read_file.get_alphabet_information()
        # Explicit check rather than ``assert`` so validation still happens
        # when Python runs with optimisations (-O), which strips asserts.
        if alphabet_info.nbase != 4:
            sys.stderr.write(
                '* ERROR: Squiggle prediction with modified base training ' +
                'data is not currently supported.\n')
            sys.exit(1)
        read_data = per_read_file.get_multiple_reads(read_ids,
                                                     max_reads=args.limit)
        # read_data now contains a list of reads (each an instance of the
        # Read class defined in mapped_signal_files.py, based on dict)

    if len(read_data) == 0:
        log.write('* No reads remaining for training, exiting.\n')
        sys.exit(1)
    log.write('* Loaded {} reads.\n'.format(len(read_data)))

    # Get parameters for filtering by sampling a subset of the reads
    # Result is a tuple median mean_dwell, mad mean_dwell
    filter_parameters = chunk_selection.sample_filter_parameters(
        read_data, args.sample_nreads_before_filtering, args.target_len,
        args.filter_mean_dwell, args.filter_max_dwell)

    log.write(
        "* Sampled {} chunks: median(mean_dwell)={:.2f}, mad(mean_dwell)={:.2f}\n"
        .format(args.sample_nreads_before_filtering,
                filter_parameters.median_meandwell,
                filter_parameters.mad_meandwell))

    conv_net = create_convolution(args.size, args.depth, args.winlen)
    nparam = sum(p.nelement() for p in conv_net.parameters())
    log.write('* Created network.  {} parameters\n'.format(nparam))
    log.write('* Depth {} layers ({} residual layers)\n'.format(
        args.depth + 2, args.depth))
    log.write('* Window width {}\n'.format(args.winlen))
    log.write('* Context +/- {} bases\n'.format(
        (args.depth + 2) * (args.winlen // 2)))

    conv_net = conv_net.to(device)

    optimizer = torch.optim.Adam(conv_net.parameters(),
                                 lr=args.lr_max,
                                 betas=args.adam,
                                 weight_decay=args.weight_decay,
                                 eps=args.eps)

    lr_scheduler = optim.ReciprocalLR(optimizer, args.lr_decay)

    # To count the numbers of different sorts of chunk rejection
    rejection_dict = defaultdict(int)
    t0 = time.time()
    score_smoothed = helpers.WindowedExpSmoother()
    total_chunks = 0

    for i in range(args.niteration):
        # chunk_batch is a list of dicts; the 'sequence' and 'current'
        # entries are used below.
        chunk_batch, batch_rejections = chunk_selection.assemble_batch(
            read_data,
            args.batch_size,
            args.target_len,
            filter_parameters,
            chunk_len_means_sequence_len=True)
        if len(chunk_batch) < args.batch_size:
            log.write('* Warning: only {} chunks passed filters.\n'.format(
                len(chunk_batch)))

        total_chunks += len(chunk_batch)
        # Update counts of reasons for rejection
        for k, v in batch_rejections.items():
            rejection_dict[k] += v

        # Shape of input needs to be seqlen x batchsize x embedding_dimension
        embedded_matrix = [
            embed_sequence(d['sequence'], alphabet=None) for d in chunk_batch
        ]
        seq_embed = torch.tensor(embedded_matrix).permute(1, 0, 2).to(device)
        # Shape of labels is a flat vector
        batch_signal = torch.tensor(
            np.concatenate([d['current'] for d in chunk_batch])).to(device)
        # Shape of lens is also a flat vector
        batch_siglen = torch.tensor([len(d['current'])
                                     for d in chunk_batch]).to(device)

        optimizer.zero_grad()

        predicted_squiggle = conv_net(seq_embed)
        batch_loss = squiggle_match_loss(predicted_squiggle, batch_signal,
                                         batch_siglen, args.back_prob)
        # Normalise by the total number of signal samples in the batch
        fval = batch_loss.sum() / float(batch_siglen.sum())

        fval.backward()
        optimizer.step()

        score_smoothed.update(float(fval))

        if (i + 1) % args.save_every == 0:
            helpers.save_model(conv_net, args.outdir,
                               (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % DOTROWLENGTH == 0:
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:7.5f}  {:5.2f}s'
            log.write(
                t.format((i + 1) // DOTROWLENGTH, score_smoothed.value, dt))
            t0 = tn
            # Write summary of chunk rejection reasons
            if args.full_filter_status:
                for k, v in rejection_dict.items():
                    log.write(" {}:{} ".format(k, v))
            else:
                n_tot = sum(rejection_dict.values())
                # ``.get`` avoids inserting a 'pass' key into the defaultdict
                n_fail = n_tot - rejection_dict.get('pass', 0)
                # Guard against division by zero when nothing was counted yet
                frac_failed = n_fail / n_tot if n_tot > 0 else 0.0
                log.write("  {:.1%} chunks filtered".format(frac_failed))
            log.write("\n")

        lr_scheduler.step()

    helpers.save_model(conv_net, args.outdir)
    def test_HDF5_mapped_read_file(self):
        """Test that we can save a mapped read file, open it again and
        use some methods to get data from it.

        A diagnostic plot of the ref-to-signal mapping can be produced by
        setting ``make_diagnostic_plot`` to True below; it is off by default.
        """

        print("Creating Read object from test data")
        read_dict = construct_mapped_read()
        read_object = mapped_signal_files.Read(read_dict)
        print("Checking contents")
        check_text = read_object.check()
        print("Check result on read object:")
        print(check_text)
        self.assertEqual(check_text, "pass")

        print("Writing to file")
        alphabet_info = alphabet.AlphabetInfo(DEFAULT_ALPHABET, DEFAULT_ALPHABET)
        with mapped_signal_files.HDF5Writer(self.testfilepath, alphabet_info) as f:
            f.write_read(read_object)

        print("Current dir = ", os.getcwd())
        print("File written to ", self.testfilepath)

        print("\nOpening file for reading")
        with mapped_signal_files.HDF5Reader(self.testfilepath) as f:
            ids = f.get_read_ids()
            print("Read ids=", ids[0])
            print("Version number = ", f.version)
            self.assertEqual(ids[0], read_dict['read_id'])

            file_test_report = f.check()
            print("Test report:", file_test_report)
            self.assertEqual(file_test_report, "pass")

            read_list = f.get_multiple_reads("all")

        recovered_read = read_list[0]
        reflen = len(recovered_read['Reference'])
        siglen = len(recovered_read['Dacs'])

        # Get a chunk - note that chunkstart is relative to the start of the mapped
        # region, not relative to the start of the signal
        chunklen, chunkstart = 5, 3
        chunkdict = recovered_read.get_chunk_with_sample_length(chunklen, chunkstart)

        # Check that the extracted chunk is the right length
        self.assertEqual(len(chunkdict['current']), chunklen)

        # Check that the mapping data agrees with what we put in. Unlike
        # assertTrue(np.all(a == b)), this fails on shape mismatch instead of
        # broadcasting, and reports which elements differ.
        np.testing.assert_array_equal(recovered_read['Ref_to_signal'],
                                      read_dict['Ref_to_signal'])

        # Plot a picture showing ref_to_sig from the read object,
        # and the result of searches to find the inverse
        make_diagnostic_plot = False
        if make_diagnostic_plot:
            plt.figure()
            plt.xlabel('Signal coord')
            plt.ylabel('Ref coord')
            ix = np.array([0, -1])
            # NOTE(review): this plots current values against sequence values
            # as 'chunk limits', which looks inconsistent with the axis labels
            # — confirm intent before relying on this plot.
            plt.scatter(chunkdict['current'][ix], chunkdict['sequence'][ix],
                        s=50, label='chunk limits', marker='s', color='black')
            plt.scatter(recovered_read['Ref_to_signal'], np.arange(reflen + 1), label='reftosig (source data)',
                        color='none', edgecolor='blue', s=60)
            siglocs = np.arange(siglen, dtype=np.int32)
            sigtoref_fromsearch = recovered_read.get_reference_locations(siglocs)
            plt.scatter(siglocs, sigtoref_fromsearch, label='from search', color='red', marker='x', s=50)
            plt.legend()
            plt.grid()
            plt.savefig(self.plotfilepath)
            print("Saved plot to", self.plotfilepath)