from collections import OrderedDict
from optparse import OptionParser
import multiprocessing
import random
import sys
import time

import h5py
import joblib
import numpy as np
import tensorflow as tf

# project modules (basenji); the helper functions referenced below
# (annotate_na, batch_end, bigwig_batch, fourier_transform, latent_transform,
# limit_segments, segments_1hot, w5_batch) are defined elsewhere in the
# original script
from basenji import autoencoder, batcher, dna_io, genome, seqnn


def main():
  usage = 'usage: %prog [options] <fasta_file> <sample_wigs_file> <hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option('-b', dest='limit_bed',
                    help='Limit to segments that overlap regions in a BED file')
  parser.add_option('-c', dest='clip', default=None, type='float',
                    help='Clip target values at this minimum [Default: %default]')
  parser.add_option('-d', dest='sample_pct', default=1.0, type='float',
                    help='Down-sample the segments')
  parser.add_option('-f', dest='fourier_dim', default=None, type='int',
                    help='Fourier transform dimension [Default: %default]')
  parser.add_option('-g', dest='gaps_file',
                    help='Genome assembly gaps BED [Default: %default]')
  parser.add_option('-l', dest='seq_length', default=131072, type='int',
                    help='Sequence length [Default: %default]')
  parser.add_option('--log2', dest='log10to2', default=False,
                    action='store_true',
                    help='Transform values from log10 to log2 [Default: %default]')
  parser.add_option('-m', dest='params_file',
                    help='Dimension reduction hyper-parameters file')
  parser.add_option('--mult_cov', dest='cov_multiplier', default=1,
                    type='float',
                    help='Coverage multiplier, useful when the read extension '
                         'and pool width do not match [Default: %default]')
  parser.add_option('-n', dest='na_t', default=0.25, type='float',
                    help='Remove sequences with an NA proportion greater than '
                         'this threshold [Default: %default]')
  parser.add_option('--no_full', dest='no_full', default=False,
                    action='store_true',
                    help='Do not save full test sequence targets [Default: %default]')
  parser.add_option('-o', dest='out_bed_file',
                    help='Output the train/valid/test sequences as a BED file')
  parser.add_option('-p', dest='processes', default=1, type='int',
                    help='Number of parallel processes to load data [Default: %default]')
  parser.add_option('-s', dest='stride', default=None, type='int',
                    help='Stride to advance segments [Default: seq_length]')
  parser.add_option('--scent', dest='scent_file',
                    help='Dimension reduction model file')
  parser.add_option('-t', dest='test_pct_or_chr', type='str', default='0.05',
                    help='Proportion of the data for testing, or a test '
                         'chromosome name [Default: %default]')
  parser.add_option('-u', dest='unmap_bed',
                    help='Unmappable segments to set to NA')
  parser.add_option('-w', dest='pool_width', type='int', default=128,
                    help='Average pooling width [Default: %default]')
  parser.add_option('--w5', dest='w5', default=False, action='store_true',
                    help='Coverage files are w5 rather than BigWig [Default: %default]')
  parser.add_option('-v', dest='valid_pct_or_chr', type='str', default='0.05',
                    help='Proportion of the data for validation, or a '
                         'validation chromosome name [Default: %default]')
  parser.add_option('-z', dest='compression',
                    help='h5py compression [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide genome FASTA file, sample Wig/BigWig labels '
                 'and paths, and model output file')
  else:
    fasta_file = args[0]
    sample_wigs_file = args[1]
    hdf5_file = args[2]

  random.seed(1)

  if options.stride is None:
    options.stride = options.seq_length

  ################################################################
  # assess bigwigs
  ################################################################
  # get wig files and labels
  target_wigs = OrderedDict()
  target_strands = []
  target_labels = []
  with open(sample_wigs_file, encoding='UTF-8') as sample_wigs_in:
    for line in sample_wigs_in:
      a = line.rstrip().split('\t')
      target_wigs[a[0]] = a[1]
      if len(a) > 2:
        target_strands.append(a[2])
      else:
        target_strands.append('*')
      if len(a) > 3:
        target_labels.append(a[3])
      else:
        target_labels.append('')

  if (options.fourier_dim is not None and
      2 * options.fourier_dim >= options.seq_length / options.pool_width):
    print("Fourier transform to %d dims won't compress %d length sequences "
          'with %d pooling' %
          (options.fourier_dim, options.seq_length, options.pool_width),
          file=sys.stderr)
    sys.exit(1)

  ################################################################
  # prepare genomic segments
  ################################################################
  chrom_segments = genome.load_chromosomes(fasta_file)

  # remove gaps
  if options.gaps_file:
    chrom_segments = genome.split_contigs(chrom_segments, options.gaps_file)

  # ditch the chromosomes
  segments = []
  for chrom in chrom_segments:
    segments += [(chrom, seg_start, seg_end)
                 for seg_start, seg_end in chrom_segments[chrom]]

  # standardize order
  segments.sort()

  # filter for large enough
  segments = [cse for cse in segments if cse[2] - cse[1] >= options.seq_length]

  # down-sample
  if options.sample_pct < 1.0:
    segments = random.sample(segments, int(options.sample_pct * len(segments)))

  # limit to a BED file
  if options.limit_bed is not None:
    segments = limit_segments(segments, options.limit_bed)

  ################################################################
  # one hot code sequences
  ################################################################
  seqs_1hot, seqs_segments = segments_1hot(fasta_file, segments,
                                           options.seq_length, options.stride)
  print('%d sequences one hot coded' % seqs_1hot.shape[0])

  ################################################################
  # load model
  ################################################################
  if options.params_file:
    job = dna_io.read_job_params(options.params_file)
    job['num_targets'] = len(target_wigs)
    job['batch_size'] = 1024
    job['model'] = job.get('model', 'autoencoder')

    if job['model'] == 'autoencoder':
      model = autoencoder.AE(job)
      saver = tf.train.Saver()
    else:
      model = joblib.load(options.scent_file)

  ################################################################
  # bigwig read and process
  ################################################################
  print('Reading and pre-processing bigwigs for %d segments' % len(segments),
        flush=True)

  targets_real = []
  targets_imag = []

  # (unused in this version)
  include_indexes = []
  include_marker = 0

  targets_test = []
  test_indexes = []
  test_marker = 0

  update_i = 0
  ssi = 0

  # initialize multiprocessing pool
  pool = multiprocessing.Pool(options.processes)

  with tf.Session() as sess:
    if options.scent_file and job['model'] == 'autoencoder':
      saver.restore(sess, options.scent_file)

    # batch segment processing
    bstart = 0
    while bstart < len(segments):
      # progress update; the modulus throttles printing (1 prints every batch)
      if update_i % 1 == 0:
        print('Tiling from %s:%d-%d' % segments[bstart], flush=True)

      # determine batch end
      bend = batch_end(segments, bstart, 400000)

      # bigwig_read parameters
      bwr_params = [(wig_file, segments[bstart:bend], options.seq_length,
                     options.pool_width, options.stride, options.log10to2,
                     options.cov_multiplier)
                    for wig_file in target_wigs.values()]

      # pull the target values in parallel
      if options.w5:
        wig_targets = pool.starmap(w5_batch, bwr_params)
      else:
        wig_targets = pool.starmap(bigwig_batch, bwr_params)

      # transpose to S x L x T (making a copy?)
      targets_wig = np.transpose(np.array(wig_targets), axes=(1, 2, 0))

      # clip
      if options.clip is not None:
        targets_wig = targets_wig.clip(options.clip)

      # sample indexes from this batch
      if options.test_pct_or_chr.startswith('chr'):
        test_bindexes = [
            twi for twi in range(targets_wig.shape[0])
            if seqs_segments[ssi + twi][0] == options.test_pct_or_chr
        ]
      else:
        test_pct = float(options.test_pct_or_chr)
        test_bindexes = [
            twi for twi in range(targets_wig.shape[0])
            if random.random() < test_pct
        ]

      # capture test indexes
      test_indexes += [test_marker + tbi for tbi in test_bindexes]

      # update test marker
      test_marker += targets_wig.shape[0]

      # save the full test targets
      if not options.no_full:
        targets_test.append(targets_wig[test_bindexes])

      # map to latent space
      if options.scent_file is None:
        targets_latent = targets_wig
      else:
        targets_latent = latent_transform(sess, model, job, targets_wig)

      # compress across length
      if options.fourier_dim is None:
        targets_rfour = targets_latent
        targets_ifour = None
      else:
        targets_rfour, targets_ifour = fourier_transform(
            targets_latent, options.fourier_dim)

      # save
      targets_real.append(targets_rfour)
      targets_imag.append(targets_ifour)

      # update seqs_segments index
      ssi += targets_wig.shape[0]

      # update batch
      bstart = bend
      update_i += 1

  pool.close()

  # stack arrays
  targets_real = np.vstack(targets_real)
  if options.fourier_dim is not None:
    targets_imag = np.vstack(targets_imag)
  if not options.no_full:
    targets_test = np.vstack(targets_test)
  print('%d target sequences' % targets_real.shape[0])

  ################################################################
  # correct for unmappable regions
  ################################################################
  if options.unmap_bed is not None:
    seqs_na = annotate_na(seqs_segments, options.unmap_bed,
                          options.seq_length, options.pool_width)

    # determine mappable sequences and update test indexes
    map_indexes = []
    test_indexes_set = set(test_indexes)
    print('test_indexes %d' % len(test_indexes))
    test_indexes_na = []

    new_i = 0
    for old_i in range(seqs_na.shape[0]):
      # mappable: keep the sequence and remap its test index
      if seqs_na[old_i, :].mean(dtype='float64') < options.na_t:
        map_indexes.append(old_i)
        if old_i in test_indexes_set:
          test_indexes_na.append(new_i)
        new_i += 1
      # unmappable sequences are dropped

    # update data structures
    targets_real = targets_real[map_indexes]
    if options.fourier_dim is not None:
      targets_imag = targets_imag[map_indexes]

    seqs_1hot = seqs_1hot[map_indexes]
    seqs_segments = [seqs_segments[mi] for mi in map_indexes]
    seqs_na = seqs_na[map_indexes]

    test_indexes = test_indexes_na
    print('test_indexes %d' % len(test_indexes))

  ################################################################
  # write to train, valid, test HDF5
  ################################################################
  if options.valid_pct_or_chr.startswith('chr'):
    # sample valid chromosome
    valid_indexes = [
        si for si in range(len(seqs_segments))
        if seqs_segments[si][0] == options.valid_pct_or_chr
    ]
  else:
    # sample valid indexes (we already have test)
    valid_pct = float(options.valid_pct_or_chr)
    valid_n = int(valid_pct * targets_real.shape[0])
    nontest_indexes = set(range(targets_real.shape[0])) - set(test_indexes)
    valid_indexes = random.sample(sorted(nontest_indexes), valid_n)

  # remainder is training
  train_indexes = list(
      set(range(len(seqs_segments))) - set(valid_indexes) - set(test_indexes))

  # shuffle the index lists in place, sorting first so the seeded shuffle is
  # reproducible (random.shuffle returns None, so shuffling the temporary
  # list returned by sorted() would be a no-op)
  train_indexes.sort()
  random.shuffle(train_indexes)
  valid_indexes.sort()
  random.shuffle(valid_indexes)
  test_indexes.sort()
  random.shuffle(test_indexes)

  # write to HDF5
  hdf5_out = h5py.File(hdf5_file, 'w')

  # store pooling
  hdf5_out.create_dataset('pool_width', data=options.pool_width, dtype='int')

  # store targets
  target_ids = np.array(list(target_wigs.keys()), dtype='S')
  hdf5_out.create_dataset('target_ids', data=target_ids)

  target_labels = np.array(target_labels, dtype='S')
  hdf5_out.create_dataset('target_labels', data=target_labels)

  target_strands = np.array(target_strands, dtype='S')
  hdf5_out.create_dataset('target_strands', data=target_strands)

  # HDF5 train
  hdf5_out.create_dataset('train_in', data=seqs_1hot[train_indexes],
                          dtype='bool', compression=options.compression)
  hdf5_out.create_dataset('train_out', data=targets_real[train_indexes],
                          dtype='float16', compression=options.compression)
  if options.fourier_dim is not None:
    hdf5_out.create_dataset('train_out_imag',
                            data=targets_imag[train_indexes],
                            dtype='float16', compression=options.compression)
  if options.unmap_bed is not None:
    hdf5_out.create_dataset('train_na', data=seqs_na[train_indexes],
                            dtype='bool', compression=options.compression)

  # HDF5 valid
  hdf5_out.create_dataset('valid_in', data=seqs_1hot[valid_indexes],
                          dtype='bool', compression=options.compression)
  hdf5_out.create_dataset('valid_out', data=targets_real[valid_indexes],
                          dtype='float16', compression=options.compression)
  if options.fourier_dim is not None:
    hdf5_out.create_dataset('valid_out_imag',
                            data=targets_imag[valid_indexes],
                            dtype='float16', compression=options.compression)
  if options.unmap_bed is not None:
    hdf5_out.create_dataset('valid_na', data=seqs_na[valid_indexes],
                            dtype='bool', compression=options.compression)

  # HDF5 test
  hdf5_out.create_dataset('test_in', data=seqs_1hot[test_indexes],
                          dtype='bool', compression=options.compression)
  hdf5_out.create_dataset('test_out', data=targets_real[test_indexes],
                          dtype='float16', compression=options.compression)
  if options.fourier_dim is not None:
    hdf5_out.create_dataset('test_out_imag',
                            data=targets_imag[test_indexes],
                            dtype='float16', compression=options.compression)
  if not options.no_full:
    hdf5_out.create_dataset('test_out_full', data=targets_test,
                            dtype='float16', compression=options.compression)
  if options.unmap_bed is not None:
    hdf5_out.create_dataset('test_na', data=seqs_na[test_indexes],
                            dtype='bool', compression=options.compression)

  hdf5_out.close()

  # output BED file
  if options.out_bed_file:
    out_bed_out = open(options.out_bed_file, 'w')
    for si in train_indexes:
      print('%s\t%d\t%d\ttrain' % seqs_segments[si], file=out_bed_out)
    for si in valid_indexes:
      print('%s\t%d\t%d\tvalid' % seqs_segments[si], file=out_bed_out)
    for si in test_indexes:
      print('%s\t%d\t%d\ttest' % seqs_segments[si], file=out_bed_out)
    out_bed_out.close()
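# The sample_wigs_file parsed above is a tab-delimited table with one row per
# target: identifier, coverage file path, optional strand (default '*'), and
# optional label (default ''). The rows and the invocation below are
# hypothetical illustrations, including the script name, not taken from the
# project:
#
#   CNhs11760   data/aorta.bw    +   CAGE:aorta
#   CNhs12843   data/artery.bw   *   CAGE:artery
#
#   python basenji_hdf5.py -l 131072 -w 128 hg19.fa sample_wigs.txt data.h5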
def main(): usage = "usage: %prog [options] <fasta_file> <sample_wigs_file> <hdf5_file>" parser = OptionParser(usage) parser.add_option( "-b", dest="limit_bed", help="Limit to segments that overlap regions in a BED file", ) parser.add_option( "-c", dest="clip", default=None, type="float", help="Clip target values to have minimum [Default: %default]", ) parser.add_option( "-d", dest="sample_pct", default=1.0, type="float", help="Down-sample the segments", ) parser.add_option( "-f", dest="fourier_dim", default=None, type="int", help="Fourier transform dimension [Default: %default]", ) parser.add_option("-g", dest="gaps_file", help="Genome assembly gaps BED [Default: %default]") parser.add_option( "-l", dest="seq_length", default=131072, type="int", help="Sequence length [Default: %default]", ) parser.add_option( "--log2", dest="log10to2", default=False, action="store_true", help="Transform values from log10 to log2 [Default: %default]", ) parser.add_option("-m", dest="params_file", help="Dimension reduction hyper-parameters file") parser.add_option( "--mult_cov", dest="cov_multiplier", default=1, type="float", help= "Coverage multiplier, useful when the read extension and pool width do not match [Default: %default]", ) parser.add_option( "-n", dest="na_t", default=0.25, type="float", help= "Remove sequences with an NA% greater than this threshold [Default: %default]", ) parser.add_option( "--no_full", dest="no_full", default=False, action="store_true", help="Do not save full test sequence targets [Default: %default]", ) parser.add_option( "-o", dest="out_bed_file", help="Output the train/valid/test sequences as a BED file", ) parser.add_option( "-p", dest="processes", default=1, type="int", help="Number parallel processes to load data [Default: %default]", ) parser.add_option( "-s", dest="stride", default=None, type="int", help="Stride to advance segments [Default: seq_length]", ) parser.add_option("--scent", dest="scent_file", help="Dimension reduction model file") parser.add_option( "-t", dest="test_pct_or_chr", type="str", default=0.05, help="Proportion of the data for testing [Default: %default]", ) parser.add_option("-u", dest="unmap_bed", help="Unmappable segments to set to NA") parser.add_option( "-w", dest="pool_width", type="int", default=128, help="Average pooling width [Default: %default]", ) parser.add_option( "--w5", dest="w5", default=False, action="store_true", help="Coverage files are w5 rather than BigWig [Default: %default]", ) parser.add_option( "-v", dest="valid_pct_or_chr", type="str", default=0.05, help="Proportion of the data for validation [Default: %default]", ) parser.add_option("-z", dest="compression", help="h5py compression [Default: %default]") (options, args) = parser.parse_args() if len(args) != 3: parser.error( "Must provide genome FASTA file, sample Wig/BigWig labels and paths, " "and model output file") else: fasta_file = args[0] sample_wigs_file = args[1] hdf5_file = args[2] random.seed(1) if options.stride is None: options.stride = options.seq_length ################################################################ # assess bigwigs ################################################################ # get wig files and labels target_wigs = OrderedDict() target_strands = [] target_labels = [] for line in open(sample_wigs_file, encoding="UTF-8"): a = line.rstrip().split("\t") target_wigs[a[0]] = a[1] if len(a) > 2: target_strands.append(a[2]) else: target_strands.append("*") if len(a) > 3: target_labels.append(a[3]) else: target_labels.append("") if 
(options.fourier_dim is not None and 2 * options.fourier_dim >= options.seq_length / options.pool_width): print( "Fourier transform to %d dims won't compress %d length sequences with %d pooling" % (options.fourier_dim, options.seq_length, options.pool_width), file=sys.stderr, ) exit(1) ################################################################ # prepare genomic segments ################################################################ chrom_segments = genome.load_chromosomes(fasta_file) # remove gaps if options.gaps_file: chrom_segments = genome.split_contigs(chrom_segments, options.gaps_file) # ditch the chromosomes segments = [] for chrom in chrom_segments: segments += [(chrom, seg_start, seg_end) for seg_start, seg_end in chrom_segments[chrom]] # standardize order segments.sort() # filter for large enough segments = [ cse for cse in segments if cse[2] - cse[1] >= options.seq_length ] # down-sample if options.sample_pct < 1.0: segments = random.sample(segments, int(options.sample_pct * len(segments))) # limit to a BED file if options.limit_bed is not None: segments = limit_segments(segments, options.limit_bed) ################################################################ # one hot code sequences ################################################################ seqs_1hot, seqs_segments = segments_1hot(fasta_file, segments, options.seq_length, options.stride) print("%d sequences one hot coded" % seqs_1hot.shape[0]) ################################################################ # load model ################################################################ if options.params_file: job = dna_io.read_job_params(options.params_file) job["num_targets"] = len(target_wigs) job["batch_size"] = 1024 job["model"] = job.get("model", "autoencoder") if job["model"] == "autoencoder": model = autoencoder.AE(job) saver = tf.train.Saver() else: model = joblib.load(options.scent_file) ################################################################ # bigwig read and process ################################################################ print("Reading and pre-processing bigwigs for %d segments" % len(segments), flush=True) targets_real = [] targets_imag = [] include_indexes = [] include_marker = 0 targets_test = [] test_indexes = [] test_marker = 0 update_i = 0 ssi = 0 # initialize multiprocessing pool pool = multiprocessing.Pool(options.processes) with tf.Session() as sess: if options.scent_file and job["model"] == "autoencoder": saver.restore(sess, options.scent_file) # batch segment processing bstart = 0 while bstart < len(segments): if update_i % 1 == 0: print("Tiling from %s:%d-%d" % segments[bstart], flush=True) # determine batch end bend = batch_end(segments, bstart, 400000) # bigwig_read parameters bwr_params = [( wig_file, segments[bstart:bend], options.seq_length, options.pool_width, options.stride, options.log10to2, options.cov_multiplier, ) for wig_file in target_wigs.values()] # pull the target values in parallel if options.w5: wig_targets = pool.starmap(w5_batch, bwr_params) else: wig_targets = pool.starmap(bigwig_batch, bwr_params) # transpose to S x L x T (making a copy?) 
targets_wig = np.transpose(np.array(wig_targets), axes=(1, 2, 0)) # clip if options.clip is not None: targets_wig = targets_wig.clip(options.clip) # sample indexes from this batch if options.test_pct_or_chr.startswith("chr"): test_bindexes = [ twi for twi in range(targets_wig.shape[0]) if seqs_segments[ssi + twi][0] == options.test_pct_or_chr ] else: test_pct = float(options.test_pct_or_chr) test_bindexes = [ twi for twi in range(targets_wig.shape[0]) if random.random() < test_pct ] # capture test indexes test_indexes += [test_marker + tbi for tbi in test_bindexes] # update test marker test_marker += targets_wig.shape[0] # save the full test targets if not options.no_full: targets_test.append(targets_wig[test_bindexes]) # map to latent space if options.scent_file is None: targets_latent = targets_wig else: targets_latent = latent_transform(sess, model, job, targets_wig) # compress across length if options.fourier_dim is None: targets_rfour = targets_latent targets_ifour = None else: targets_rfour, targets_ifour = fourier_transform( targets_latent, options.fourier_dim) # save targets_real.append(targets_rfour) targets_imag.append(targets_ifour) # update seqs_segments index ssi += targets_wig.shape[0] # update batch bstart = bend update_i += 1 pool.close() # stack arrays targets_real = np.vstack(targets_real) if options.fourier_dim is not None: targets_imag = np.vstack(targets_imag) if not options.no_full: targets_test = np.vstack(targets_test) print("%d target sequences" % targets_real.shape[0]) ################################################################ # correct for unmappable regions ################################################################ if options.unmap_bed is not None: seqs_na = annotate_na(seqs_segments, options.unmap_bed, options.seq_length, options.pool_width) # determine mappable sequences and update test indexes map_indexes = [] test_indexes_set = set(test_indexes) print("test_indexes", len(test_indexes)) test_indexes_na = [] new_i = 0 for old_i in range(seqs_na.shape[0]): # mappable if seqs_na[old_i, :].mean(dtype="float64") < options.na_t: map_indexes.append(old_i) if old_i in test_indexes_set: test_indexes_na.append(new_i) new_i += 1 # unmappable else: # forget it pass # update data structures targets_real = targets_real[map_indexes] if options.fourier_dim is not None: targets_imag = targets_imag[map_indexes] seqs_1hot = seqs_1hot[map_indexes] seqs_segments = [seqs_segments[mi] for mi in map_indexes] seqs_na = seqs_na[map_indexes] test_indexes = test_indexes_na print("test_indexes", len(test_indexes)) ################################################################ # write to train, valid, test HDF5 ################################################################ if options.valid_pct_or_chr.startswith("chr"): # sample valid chromosome valid_indexes = [ si for si in range(len(seqs_segments)) if seqs_segments[si][0] == options.valid_pct_or_chr ] else: # sample valid indexes (we already have test) valid_pct = float(options.valid_pct_or_chr) valid_n = int(valid_pct * targets_real.shape[0]) nontest_indexes = set(range(targets_real.shape[0])) - set(test_indexes) valid_indexes = random.sample(nontest_indexes, valid_n) # remainder is training train_indexes = list( set(range(len(seqs_segments))) - set(valid_indexes) - set(test_indexes)) # training may requires shuffle random.shuffle(sorted(train_indexes)) random.shuffle(sorted(valid_indexes)) random.shuffle(sorted(test_indexes)) # write to HDF5 hdf5_out = h5py.File(hdf5_file, "w") # store pooling 
hdf5_out.create_dataset("pool_width", data=options.pool_width, dtype="int") # store targets target_ids = np.array(list(target_wigs.keys()), dtype="S") hdf5_out.create_dataset("target_ids", data=target_ids) target_labels = np.array(target_labels, dtype="S") hdf5_out.create_dataset("target_labels", data=target_labels) target_strands = np.array(target_strands, dtype="S") hdf5_out.create_dataset("target_strands", data=target_strands) # HDF5 train hdf5_out.create_dataset( "train_in", data=seqs_1hot[train_indexes], dtype="bool", compression=options.compression, ) hdf5_out.create_dataset( "train_out", data=targets_real[train_indexes], dtype="float16", compression=options.compression, ) if options.fourier_dim is not None: hdf5_out.create_dataset( "train_out_imag", data=targets_imag[train_indexes], dtype="float16", compression=options.compression, ) if options.unmap_bed is not None: hdf5_out.create_dataset( "train_na", data=seqs_na[train_indexes], dtype="bool", compression=options.compression, ) # HDF5 valid hdf5_out.create_dataset( "valid_in", data=seqs_1hot[valid_indexes], dtype="bool", compression=options.compression, ) hdf5_out.create_dataset( "valid_out", data=targets_real[valid_indexes], dtype="float16", compression=options.compression, ) if options.fourier_dim is not None: hdf5_out.create_dataset( "valid_out_imag", data=targets_imag[valid_indexes], dtype="float16", compression=options.compression, ) if options.unmap_bed is not None: hdf5_out.create_dataset( "valid_na", data=seqs_na[valid_indexes], dtype="bool", compression=options.compression, ) # HDF5 test hdf5_out.create_dataset( "test_in", data=seqs_1hot[test_indexes], dtype="bool", compression=options.compression, ) hdf5_out.create_dataset( "test_out", data=targets_real[test_indexes], dtype="float16", compression=options.compression, ) if options.fourier_dim is not None: hdf5_out.create_dataset( "test_out_imag", data=targets_imag[test_indexes], dtype="float16", compression=options.compression, ) if not options.no_full: hdf5_out.create_dataset( "test_out_full", data=targets_test, dtype="float16", compression=options.compression, ) if options.unmap_bed is not None: hdf5_out.create_dataset( "test_na", data=seqs_na[test_indexes], dtype="bool", compression=options.compression, ) hdf5_out.close() # output BED file if options.out_bed_file: out_bed_out = open(options.out_bed_file, "w") for si in train_indexes: print("%s\t%d\t%d\ttrain" % seqs_segments[si], file=out_bed_out) for si in valid_indexes: print("%s\t%d\t%d\tvalid" % seqs_segments[si], file=out_bed_out) for si in test_indexes: print("%s\t%d\t%d\ttest" % seqs_segments[si], file=out_bed_out) out_bed_out.close()
def run(params_file, data_file, num_train_epochs):
  shifts = [int(shift) for shift in FLAGS.shifts.split(',')]

  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(data_file, 'r')

  train_seqs = data_open['train_in']
  train_targets = data_open['train_out']
  train_na = None
  if 'train_na' in data_open:
    train_na = data_open['train_na']

  valid_seqs = data_open['valid_in']
  valid_targets = data_open['valid_out']
  valid_na = None
  if 'valid_na' in data_open:
    valid_na = data_open['valid_na']

  #######################################################
  # model parameters and placeholders
  #######################################################
  job = dna_io.read_job_params(params_file)

  job['batch_length'] = train_seqs.shape[1]
  job['seq_depth'] = train_seqs.shape[2]
  job['num_targets'] = train_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))
  job['early_stop'] = job.get('early_stop', 16)
  job['rate_drop'] = job.get('rate_drop', 3)

  t0 = time.time()
  dr = seqnn.SeqNN()
  dr.build(job)
  print('Model building time %f' % (time.time() - t0))

  # adjust for fourier
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    train_targets_imag = data_open['train_out_imag']
    valid_targets_imag = data_open['valid_out_imag']

  #######################################################
  # train
  #######################################################
  # initialize batcher
  if job['fourier']:
    batcher_train = batcher.BatcherF(train_seqs, train_targets,
                                     train_targets_imag, train_na,
                                     dr.batch_size, dr.target_pool,
                                     shuffle=True)
    batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                     valid_targets_imag, valid_na,
                                     dr.batch_size, dr.target_pool)
  else:
    batcher_train = batcher.Batcher(train_seqs, train_targets, train_na,
                                    dr.batch_size, dr.target_pool,
                                    shuffle=True)
    batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na,
                                    dr.batch_size, dr.target_pool)
  print('Batcher initialized')

  # checkpoints
  saver = tf.train.Saver()

  config = tf.ConfigProto()
  if FLAGS.log_device_placement:
    config.log_device_placement = True
  with tf.Session(config=config) as sess:
    t0 = time.time()

    # set seed
    tf.set_random_seed(FLAGS.seed)

    if FLAGS.logdir:
      train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                           sess.graph)
    else:
      train_writer = None

    if FLAGS.restart:
      # load variables into session
      saver.restore(sess, FLAGS.restart)
    else:
      # initialize variables
      print('Initializing...')
      sess.run(tf.global_variables_initializer())
      print('Initialization time %f' % (time.time() - t0))

    train_loss = None
    best_loss = None
    early_stop_i = 0
    undroppable_counter = 3
    max_drops = 8
    num_drops = 0

    for epoch in range(num_train_epochs):
      # stop once validation loss has stagnated past the early-stop patience
      # (but train at least min_epochs)
      if early_stop_i >= job['early_stop'] and epoch >= FLAGS.min_epochs:
        break

      t0 = time.time()

      # save previous
      train_loss_last = train_loss

      # alternate forward and reverse batches
      fwdrc = True
      if FLAGS.rc and epoch % 2 == 1:
        fwdrc = False

      # cycle shifts
      shift_i = epoch % len(shifts)

      # train
      train_loss = dr.train_epoch(sess, batcher_train, fwdrc,
                                  shifts[shift_i], train_writer)

      # validate
      valid_acc = dr.test(sess, batcher_valid, mc_n=FLAGS.mc_n,
                          rc=FLAGS.rc, shifts=shifts)
      valid_loss = valid_acc.loss
      valid_r2 = valid_acc.r2().mean()
      del valid_acc

      best_str = ''
      if best_loss is None or valid_loss < best_loss:
        best_loss = valid_loss
        best_str = ', best!'
        early_stop_i = 0
        saver.save(sess, '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix))
      else:
        early_stop_i += 1

      # measure time
      et = time.time() - t0
      if et < 600:
        time_str = '%3ds' % et
      elif et < 6000:
        time_str = '%3dm' % (et / 60)
      else:
        time_str = '%3.1fh' % (et / 3600)

      # print update
      print('Epoch %3d: Train loss: %7.5f, Valid loss: %7.5f, '
            'Valid R2: %7.5f, Time: %s%s' %
            (epoch + 1, train_loss, valid_loss, valid_r2, time_str, best_str),
            end='')

      # if training is stagnant, drop the learning rate
      if (FLAGS.learn_rate_drop and num_drops < max_drops and
          undroppable_counter == 0 and
          (train_loss_last - train_loss) / train_loss_last < 0.0002):
        print(', rate drop', end='')
        dr.drop_rate(2 / 3)
        undroppable_counter = 1
        num_drops += 1
      else:
        undroppable_counter = max(0, undroppable_counter - 1)

      print('')
      sys.stdout.flush()

  if FLAGS.logdir:
    train_writer.close()

  data_open.close()
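# Both run() variants reference module-level FLAGS that the original scripts
# define with tf.app.flags. A sketch of the definitions implied by their uses
# above and below; the flag names come from the code, but the defaults and
# help strings here are assumptions:
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('logdir', None, 'Output directory for checkpoints and summaries')
tf.app.flags.DEFINE_string('restart', None, 'Checkpoint to restore before training')
tf.app.flags.DEFINE_string('save_prefix', 'save', 'Prefix for saved checkpoints')
tf.app.flags.DEFINE_string('shifts', '0', 'Comma-separated sequence shift augmentations')
tf.app.flags.DEFINE_boolean('rc', False, 'Alternate forward and reverse-complement epochs')
tf.app.flags.DEFINE_boolean('learn_rate_drop', False, 'Drop the learning rate when training stagnates')
tf.app.flags.DEFINE_boolean('log_device_placement', False, 'Log TensorFlow device placement')
tf.app.flags.DEFINE_integer('mc_n', 0, 'Monte Carlo iterations for testing')
tf.app.flags.DEFINE_integer('min_epochs', 0, 'Minimum epochs before early stopping')
tf.app.flags.DEFINE_integer('seed', 1, 'Random seed')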
def run(params_file, train_file, test_file, num_train_epochs,
        batches_per_epoch, num_test_batches):
  np.random.seed(FLAGS.seed)

  shifts = [int(shift) for shift in FLAGS.shifts.split(',')]

  job = dna_io.read_job_params(params_file)
  job['early_stop'] = job.get('early_stop', 16)
  job['rate_drop'] = job.get('rate_drop', 3)

  data_ops, training_init_op, test_init_op = make_data_ops(
      job, train_file, test_file)

  dr = seqnn.SeqNN()
  dr.build_from_data_ops(job, data_ops)

  # checkpoints
  saver = tf.train.Saver()

  with tf.Session() as sess:
    if FLAGS.logdir:
      train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                           sess.graph)
    else:
      train_writer = None

    t0 = time.time()
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(coord=coord)

    if FLAGS.restart:
      # load checkpoint variables into the session, overwriting the
      # initialization above (re-running the global initializer afterward,
      # as the original did, would be redundant)
      saver.restore(sess, FLAGS.restart)
    print('Initialization time %f' % (time.time() - t0))

    train_loss = None
    best_loss = None
    early_stop_i = 0
    undroppable_counter = 3
    max_drops = 8
    num_drops = 0

    for epoch in range(num_train_epochs):
      # stop once validation loss has stagnated past the early-stop patience
      # (but train at least min_epochs)
      if early_stop_i >= job['early_stop'] and epoch >= FLAGS.min_epochs:
        break

      t0 = time.time()

      # save previous
      train_loss_last = train_loss

      # alternate forward and reverse batches
      fwdrc = True
      if FLAGS.rc and epoch % 2 == 1:
        fwdrc = False

      # cycle shifts
      shift_i = epoch % len(shifts)

      # train
      sess.run(training_init_op)
      train_loss, steps = dr.train_epoch_from_data_ops(
          sess, train_writer, batches_per_epoch)

      # validate
      sess.run(test_init_op)
      valid_acc = dr.test_from_data_ops(sess,
                                        num_test_batches=num_test_batches)
      valid_loss = valid_acc.loss
      valid_r2 = valid_acc.r2().mean()
      del valid_acc

      best_str = ''
      if best_loss is None or valid_loss < best_loss:
        best_loss = valid_loss
        best_str = ', best!'
        early_stop_i = 0
        saver.save(sess, '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix))
      else:
        early_stop_i += 1

      # measure time
      et = time.time() - t0
      if et < 600:
        time_str = '%3ds' % et
      elif et < 6000:
        time_str = '%3dm' % (et / 60)
      else:
        time_str = '%3.1fh' % (et / 3600)

      # print update
      print('Epoch: %3d, Steps: %7d, Train loss: %7.5f, Valid loss: %7.5f, '
            'Valid R2: %7.5f, Time: %s%s' %
            (epoch + 1, steps, train_loss, valid_loss, valid_r2, time_str,
             best_str),
            end='')

      # if training is stagnant, drop the learning rate
      if (FLAGS.learn_rate_drop and num_drops < max_drops and
          undroppable_counter == 0 and
          (train_loss_last - train_loss) / train_loss_last < 0.0002):
        print(', rate drop', end='')
        dr.drop_rate(2 / 3)
        undroppable_counter = 1
        num_drops += 1
      else:
        undroppable_counter = max(0, undroppable_counter - 1)

      print('')
      sys.stdout.flush()

  if FLAGS.logdir:
    train_writer.close()
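# make_data_ops, used by the second run() variant above, is defined elsewhere
# in the original script. A minimal sketch of the TF1 reinitializable-iterator
# pattern it presumably wraps; parse_fn (a TFRecord example parser) is a
# hypothetical stand-in, and this is not the script's exact implementation:
def make_data_ops_sketch(job, train_file, test_file, parse_fn):
  # one parsed, batched dataset per split
  train_dataset = tf.data.TFRecordDataset(train_file).map(parse_fn).batch(
      job['batch_size'])
  test_dataset = tf.data.TFRecordDataset(test_file).map(parse_fn).batch(
      job['batch_size'])

  # a single iterator whose init ops point it at either dataset, so the
  # training loop can alternate splits by running the matching init op
  iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                             train_dataset.output_shapes)
  data_ops = iterator.get_next()
  training_init_op = iterator.make_initializer(train_dataset)
  test_init_op = iterator.make_initializer(test_dataset)

  return data_ops, training_init_op, test_init_op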