def main():
    """Train a SeqNN model on one or more preprocessed data directories.

    Command-line entry point: parses options, copies/loads the JSON params
    file, builds train/eval SeqDatasets for each data directory, then trains
    with either the Keras fit method (-k) or the custom gradient-tape loop.

    Side effects: creates options.out_dir, copies params.json into it, and
    writes training checkpoints/statistics there via trainer.Trainer.
    """
    usage = 'usage: %prog [options] <params_file> <data1_dir> ...'
    parser = OptionParser(usage)
    parser.add_option('-k', dest='keras_fit',
        default=False, action='store_true',
        help='Train with Keras fit method [Default: %default]')
    parser.add_option('-o', dest='out_dir',
        default='train_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option('--restore', dest='restore',
        help='Restore model and continue training [Default: %default]')
    parser.add_option('--trunk', dest='trunk',
        default=False, action='store_true',
        help='Restore only model trunk [Default: %default]')
    parser.add_option('--tfr_train', dest='tfr_train_pattern',
        default=None,
        help='Training TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    parser.add_option('--tfr_eval', dest='tfr_eval_pattern',
        default=None,
        help='Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) < 2:
        parser.error('Must provide parameters and data directory.')
    else:
        params_file = args[0]
        data_dirs = args[1:]

    # Keras fit cannot interleave batches from multiple genome datasets;
    # that requires the fit2 tape loop below.
    if options.keras_fit and len(data_dirs) > 1:
        print('Cannot use keras fit method with multi-genome training.')
        exit(1)

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    # keep a copy of the parameters next to the training outputs,
    # unless the params file already lives there
    if params_file != '%s/params.json' % options.out_dir:
        shutil.copy(params_file, '%s/params.json' % options.out_dir)

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read datasets
    train_data = []
    eval_data = []
    for data_dir in data_dirs:
        # load train data
        train_data.append(dataset.SeqDataset(data_dir,
            split_label='train',
            batch_size=params_train['batch_size'],
            mode='train',
            tfr_pattern=options.tfr_train_pattern))

        # load eval data
        eval_data.append(dataset.SeqDataset(data_dir,
            split_label='valid',
            batch_size=params_train['batch_size'],
            mode='eval',
            tfr_pattern=options.tfr_eval_pattern))

    # BUG FIX: the multi-GPU (MirroredStrategy) branch of this script was
    # commented out, so a config with train.num_gpu != 1 previously fell
    # through without ever defining seqnn_trainer and crashed later with a
    # confusing NameError. Fail fast with an explicit message instead.
    if params_train.get('num_gpu', 1) != 1:
        raise ValueError(
            'Multi-GPU training is disabled in this script; set train.num_gpu to 1.')

    ########################################
    # one GPU

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)

    # restore
    if options.restore:
        seqnn_model.restore(options.restore, trunk=options.trunk)

    # initialize trainer
    seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data, options.out_dir)

    # compile model
    seqnn_trainer.compile(seqnn_model)

    # train model
    if options.keras_fit:
        seqnn_trainer.fit_keras(seqnn_model)
    else:
        if len(data_dirs) == 1:
            seqnn_trainer.fit_tape(seqnn_model)
        else:
            seqnn_trainer.fit2(seqnn_model)
# ---------------------------------------------------------------------------
# Top-level analysis script: load per-epoch training statistics, restore
# model checkpoints, and read the Hi-C target tables for later evaluation.
#
# NOTE(review): this section was reconstructed from whitespace-mangled
# source; the extent of the `for` loop body below is a best guess — confirm
# against the original notebook/script.
# ---------------------------------------------------------------------------

# training statistics written by the trainer (whitespace-separated;
# names=range(24) pads ragged rows out to 24 columns)
train_out = pd.read_csv(
    "/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aalb_2048bp_repeat/train_out_test2/model2.out",
    sep=" ", names=range(24))

# reference genome FASTA for Anopheles albimanus
fasta_file = "/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/input/genomes/AalbS2_V4.fa"

# NOTE(review): `model_dir` is not defined in this chunk — presumably set
# earlier in the file; verify before running this section standalone.
params_file = model_dir + 'params.json'

# restore a checkpoint for every 3rd epoch (0, 3, ..., 45)
for i in range(0, 48, 3):
    model_file = model_dir + 'model_check_epoch' + str(i) + '.h5'
    # model_file = model_dir+'model_best.h5'
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']
    seqnn_model = seqnn.SeqNN(params_model)

    ### restore model ###
    seqnn_model.restore(model_file)
    print('successfully loaded')

### names of targets ###
data_dir = '/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aalb_2048'
# data_dir ='/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aste_2048_globaloe'

# targets.txt maps target indexes to identifiers and source Hi-C files;
# build lookup dicts in every direction used downstream
hic_targets = pd.read_csv(data_dir + '/targets.txt', sep='\t')
# target index -> Hi-C file path
hic_file_dict_num = dict(
    zip(hic_targets['index'].values, hic_targets['file'].values))
# target identifier -> Hi-C file path
hic_file_dict = dict(
    zip(hic_targets['identifier'].values, hic_targets['file'].values))
# target index -> target identifier
hic_num_to_name_dict = dict(
    zip(hic_targets['index'].values, hic_targets['identifier'].values))
def main():
    """Evaluate a trained SeqNN model on a held-out dataset split.

    Command-line entry point: restores the model from <model_file>, computes
    test loss and per-target accuracy metrics on <data_dir>'s chosen split,
    writes acc.txt, and optionally saves predictions/targets HDF5 and draws
    scatter / violin / ROC / PR accuracy plots per requested target index.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option('--ai', dest='accuracy_indexes',
        help='Comma-separated list of target indexes to make accuracy scatter plots.')
    parser.add_option('--head', dest='head_i',
        default=0, type='int',
        help='Parameters head to test [Default: %default]')
    parser.add_option('--mc', dest='mc_n',
        default=0, type='int',
        help='Monte carlo test iterations [Default: %default]')
    parser.add_option('--peak','--peaks', dest='peaks',
        default=False, action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option('-o', dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option('--rc', dest='rc',
        default=False, action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('--save', dest='save',
        default=False, action='store_true',
        help='Save targets and predictions numpy arrays [Default: %default]')
    parser.add_option('--shifts', dest='shifts',
        default='0',
        help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('-t', dest='targets_file',
        default=None, type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option('--split', dest='split_label',
        default='test',
        help='Dataset split label for eg TFR pattern [Default: %default]')
    parser.add_option('--tfr', dest='tfr_pattern',
        default=None,
        help='TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    else:
        params_file = args[0]
        model_file = args[1]
        data_dir = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # inputs

    # read targets; default to the data directory's targets table
    if options.targets_file is None:
        options.targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # construct eval data
    eval_data = dataset.SeqDataset(data_dir,
        split_label=options.split_label,
        batch_size=params_train['batch_size'],
        mode='eval',
        tfr_pattern=options.tfr_pattern)

    # initialize model, restore the requested head, and build the
    # fwd/rc + shift ensemble for prediction averaging
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file, options.head_i)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #######################################################
    # evaluate

    # reconstruct the training loss so the reported test loss is comparable
    loss_label = params_train.get('loss', 'poisson').lower()
    spec_weight = params_train.get('spec_weight', 1)
    loss_fn = trainer.parse_loss(loss_label, spec_weight=spec_weight)

    # evaluate (metric pair depends on the loss: AUROC/AUPRC for bce,
    # PearsonR/R2 otherwise)
    test_loss, test_metric1, test_metric2 = seqnn_model.evaluate(eval_data, loss=loss_fn)

    # print summary statistics
    print('\nTest Loss: %7.5f' % test_loss)
    if loss_label == 'bce':
        print('Test AUROC: %7.5f' % test_metric1.mean())
        print('Test AUPRC: %7.5f' % test_metric2.mean())

        # write target-level statistics
        targets_acc_df = pd.DataFrame({
            'index': targets_df.index,
            'auroc': test_metric1,
            'auprc': test_metric2,
            'identifier': targets_df.identifier,
            'description': targets_df.description
            })
    else:
        print('Test PearsonR: %7.5f' % test_metric1.mean())
        print('Test R2: %7.5f' % test_metric2.mean())

        # write target-level statistics
        targets_acc_df = pd.DataFrame({
            'index': targets_df.index,
            'pearsonr': test_metric1,
            'r2': test_metric2,
            'identifier': targets_df.identifier,
            'description': targets_df.description
            })

    targets_acc_df.to_csv('%s/acc.txt'%options.out_dir, sep='\t',
        index=False, float_format='%.5f')

    #######################################################
    # predict?

    # predictions are only needed for saving, peak accuracy, or plots;
    # float16 halves memory for the full prediction tensor
    if options.save or options.peaks or options.accuracy_indexes is not None:
        # compute predictions
        test_preds = seqnn_model.predict(eval_data).astype('float16')

        # read targets
        test_targets = eval_data.numpy(return_inputs=False)

    if options.save:
        preds_h5 = h5py.File('%s/preds.h5' % options.out_dir, 'w')
        preds_h5.create_dataset('preds', data=test_preds)
        preds_h5.close()
        targets_h5 = h5py.File('%s/targets.h5' % options.out_dir, 'w')
        targets_h5.create_dataset('targets', data=test_targets)
        targets_h5.close()

    #######################################################
    # peak call accuracy

    if options.peaks:
        peaks_out_file = '%s/peaks.txt' % options.out_dir
        # NOTE(review): test_peaks is defined elsewhere in this file/module
        test_peaks(test_preds, test_targets, peaks_out_file)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')]

        if not os.path.isdir('%s/scatter' % options.out_dir):
            os.mkdir('%s/scatter' % options.out_dir)

        if not os.path.isdir('%s/violin' % options.out_dir):
            os.mkdir('%s/violin' % options.out_dir)

        if not os.path.isdir('%s/roc' % options.out_dir):
            os.mkdir('%s/roc' % options.out_dir)

        if not os.path.isdir('%s/pr' % options.out_dir):
            os.mkdir('%s/pr' % options.out_dir)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes = np.arange(0, test_preds.shape[1], 8)

            # subset and flatten (upcast back to float32 for plotting math)
            test_targets_ti_flat = test_targets_ti[:, ds_indexes].flatten(
            ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes, ti].flatten().astype(
                'float32')

            # take log2 (+1 pseudocount avoids log2(0))
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(
                test_targets_ti_log,
                test_preds_ti_log,
                out_pdf,
                poly_order=1,
                alpha=0.3,
                sample=500,
                figsize=(6, 6),
                x_label='log2 Experiment',
                y_label='log2 Prediction',
                table=True)

            ############################################
            # violin

            # call peaks: Poisson null with lambda = mean coverage, then
            # Benjamini-Hochberg correction; q < 0.01 labels a bin a peak
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction': np.log2(test_preds_ti_flat + 1),
                'Experimental coverage status': test_targets_peaks_str
            })
            ax = sns.violinplot(
                x='Experimental coverage status', y='log2 Prediction', data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot(
                [0, 1], [0, 1], c='black', linewidth=1, linestyle='--', alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(
                0.99, 0.02, 'AUROC %.3f' % auroc,
                horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(
                y=test_targets_peaks.mean(),
                c='black',
                linewidth=1,
                linestyle='--',
                alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(
                0.99, 0.95, 'AUPRC %.3f' % auprc,
                horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()