Example #1
0
def main():
    """Train a SeqNN model on one or more TFRecord sequence datasets.

    Command line: %prog [options] <params_file> <data1_dir> ...

    Reads model/train hyperparameters from a JSON file, builds one
    train/valid SeqDataset pair per data directory, then fits the model
    with either the Keras fit loop (-k, single dataset only) or the
    custom gradient-tape trainer.
    """
    usage = 'usage: %prog [options] <params_file> <data1_dir> ...'
    parser = OptionParser(usage)
    parser.add_option('-k',
                      dest='keras_fit',
                      default=False,
                      action='store_true',
                      help='Train with Keras fit method [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='train_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--restore',
        dest='restore',
        help='Restore model and continue training [Default: %default]')
    parser.add_option('--trunk',
                      dest='trunk',
                      default=False,
                      action='store_true',
                      help='Restore only model trunk [Default: %default]')
    parser.add_option(
        '--tfr_train',
        dest='tfr_train_pattern',
        default=None,
        help=
        'Training TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]'
    )
    parser.add_option(
        '--tfr_eval',
        dest='tfr_eval_pattern',
        default=None,
        help=
        'Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) < 2:
        parser.error('Must provide parameters and data directory.')
    # parser.error() exits, so no else-branch is needed here.
    params_file = args[0]
    data_dirs = args[1:]

    # the Keras fit path only supports a single dataset
    if options.keras_fit and len(data_dirs) > 1:
        print('Cannot use keras fit method with multi-genome training.')
        raise SystemExit(1)

    # create output directory; makedirs handles missing parents and
    # exist_ok avoids the race between an isdir check and creation
    os.makedirs(options.out_dir, exist_ok=True)

    # snapshot parameters alongside the outputs, unless already there
    if params_file != '%s/params.json' % options.out_dir:
        shutil.copy(params_file, '%s/params.json' % options.out_dir)

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read datasets: one train/eval SeqDataset pair per data directory
    train_data = []
    eval_data = []
    for data_dir in data_dirs:
        # load train data
        train_data.append(
            dataset.SeqDataset(data_dir,
                               split_label='train',
                               batch_size=params_train['batch_size'],
                               mode='train',
                               tfr_pattern=options.tfr_train_pattern))

        # load eval data
        eval_data.append(
            dataset.SeqDataset(data_dir,
                               split_label='valid',
                               batch_size=params_train['batch_size'],
                               mode='eval',
                               tfr_pattern=options.tfr_eval_pattern))

    if params_train.get('num_gpu', 1) != 1:
        # The multi-GPU path (tf.distribute.MirroredStrategy) has been
        # disabled; fail loudly here rather than falling through to the
        # NameError on seqnn_model/seqnn_trainer that the silent
        # fall-through previously produced.
        raise NotImplementedError(
            'num_gpu > 1 training is not supported by this script.')

    ########################################
    # one GPU

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)

    # restore weights (optionally only the model trunk)
    if options.restore:
        seqnn_model.restore(options.restore, trunk=options.trunk)

    # initialize trainer
    seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data,
                                    options.out_dir)

    # compile model
    seqnn_trainer.compile(seqnn_model)

    # train model
    if options.keras_fit:
        seqnn_trainer.fit_keras(seqnn_model)
    elif len(data_dirs) == 1:
        seqnn_trainer.fit_tape(seqnn_model)
    else:
        seqnn_trainer.fit2(seqnn_model)
Example #2
0
# Per-checkpoint evaluation setup: load the training log, then for every
# third saved epoch restore the model and read the Hi-C target metadata.
# The log is space-separated with a ragged number of fields, so columns
# are named 0..23 explicitly.
train_out = pd.read_csv(
    "/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aalb_2048bp_repeat/train_out_test2/model2.out",
    sep=" ",
    names=range(24))

# reference genome FASTA for the evaluated species
fasta_file = "/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/input/genomes/AalbS2_V4.fa"
# NOTE(review): model_dir is assumed to be defined earlier in this file — confirm.
params_file = model_dir + 'params.json'
# iterate over checkpoints saved at epochs 0, 3, ..., 45
for i in range(0, 48, 3):
    model_file = model_dir + 'model_check_epoch' + str(i) + '.h5'
    # model_file  = model_dir+'model_best.h5'
    # re-read model/train parameters (loop-invariant; could be hoisted)
    with open(params_file) as params_open:
        params = json.load(params_open)
        params_model = params['model']
        params_train = params['train']

    # build the network from the model parameters
    seqnn_model = seqnn.SeqNN(params_model)

    ### restore model ###
    seqnn_model.restore(model_file)
    print('successfully loaded')

    ### names of targets ###
    data_dir = '/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aalb_2048'
    # data_dir ='/mnt/scratch/ws/psbelokopytova/202103211631polina/nn_anopheles/dataset_like_Akita/data/Aste_2048_globaloe'
    # targets.txt columns used below: 'index', 'identifier', 'file';
    # build index->file, identifier->file, and index->identifier lookups
    hic_targets = pd.read_csv(data_dir + '/targets.txt', sep='\t')
    hic_file_dict_num = dict(
        zip(hic_targets['index'].values, hic_targets['file'].values))
    hic_file_dict = dict(
        zip(hic_targets['identifier'].values, hic_targets['file'].values))
    hic_num_to_name_dict = dict(
        zip(hic_targets['index'].values, hic_targets['identifier'].values))
Example #3
0
def main():
  """Evaluate a trained SeqNN model on a held-out dataset split.

  Command line: %prog [options] <params_file> <model_file> <data_dir>

  Computes test loss and per-target accuracy metrics (AUROC/AUPRC for
  BCE loss, PearsonR/R2 otherwise), writes them to <out_dir>/acc.txt,
  and optionally saves predictions, runs peak-call accuracy, and draws
  scatter/violin/ROC/PR plots for selected target indexes.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
  parser = OptionParser(usage)
  parser.add_option('--ai', dest='accuracy_indexes',
      help='Comma-separated list of target indexes to make accuracy scatter plots.')
  parser.add_option('--head', dest='head_i',
      default=0, type='int',
      help='Parameters head to test [Default: %default]')
  parser.add_option('--mc', dest='mc_n',
      default=0, type='int',
      help='Monte carlo test iterations [Default: %default]')
  parser.add_option('--peak','--peaks', dest='peaks',
      default=False, action='store_true',
      help='Compute expensive peak accuracy [Default: %default]')
  parser.add_option('-o', dest='out_dir',
      default='test_out',
      help='Output directory for test statistics [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average the fwd and rc predictions [Default: %default]')
  parser.add_option('--save', dest='save',
      default=False, action='store_true',
      help='Save targets and predictions numpy arrays [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('-t', dest='targets_file',
      default=None, type='str',
      help='File specifying target indexes and labels in table format')
  parser.add_option('--split', dest='split_label',
      default='test',
      help='Dataset split label for eg TFR pattern [Default: %default]')
  parser.add_option('--tfr', dest='tfr_pattern',
      default=None,
      help='TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters, model, and test data HDF5')
  else:
    params_file = args[0]
    model_file = args[1]
    data_dir = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  # parse shifts to integers
  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #######################################################
  # inputs

  # read targets; default to the targets table shipped with the dataset
  if options.targets_file is None:
    options.targets_file = '%s/targets.txt' % data_dir
  targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')

  # read model parameters
  with open(params_file) as params_open:
    params = json.load(params_open)
  params_model = params['model']
  params_train = params['train']

  # construct eval data
  eval_data = dataset.SeqDataset(data_dir,
    split_label=options.split_label,
    batch_size=params_train['batch_size'],
    mode='eval',
    tfr_pattern=options.tfr_pattern)

  # initialize model, restore the chosen head, and build the
  # fwd/rc + shift prediction ensemble
  seqnn_model = seqnn.SeqNN(params_model)
  seqnn_model.restore(model_file, options.head_i)
  seqnn_model.build_ensemble(options.rc, options.shifts)

  #######################################################
  # evaluate

  # reconstruct the training loss so test loss is comparable
  loss_label = params_train.get('loss', 'poisson').lower()
  spec_weight = params_train.get('spec_weight', 1)
  loss_fn = trainer.parse_loss(loss_label, spec_weight=spec_weight)

  # evaluate
  test_loss, test_metric1, test_metric2 = seqnn_model.evaluate(eval_data, loss=loss_fn)

  # print summary statistics
  print('\nTest Loss:         %7.5f' % test_loss)

  # metric1/metric2 are per-target arrays: AUROC/AUPRC under BCE loss,
  # PearsonR/R2 otherwise
  if loss_label == 'bce':
    print('Test AUROC:        %7.5f' % test_metric1.mean())
    print('Test AUPRC:        %7.5f' % test_metric2.mean())

    # write target-level statistics
    targets_acc_df = pd.DataFrame({
      'index': targets_df.index,
      'auroc': test_metric1,
      'auprc': test_metric2,
      'identifier': targets_df.identifier,
      'description': targets_df.description
      })

  else:
    print('Test PearsonR:     %7.5f' % test_metric1.mean())
    print('Test R2:           %7.5f' % test_metric2.mean())

    # write target-level statistics
    targets_acc_df = pd.DataFrame({
      'index': targets_df.index,
      'pearsonr': test_metric1,
      'r2': test_metric2,
      'identifier': targets_df.identifier,
      'description': targets_df.description
      })

  targets_acc_df.to_csv('%s/acc.txt'%options.out_dir, sep='\t',
                        index=False, float_format='%.5f')

  #######################################################
  # predict?

  # predictions are needed by --save, --peak, and --ai; float16 halves
  # the memory footprint of the prediction array
  if options.save or options.peaks or options.accuracy_indexes is not None:
    # compute predictions
    test_preds = seqnn_model.predict(eval_data).astype('float16')

    # read targets
    test_targets = eval_data.numpy(return_inputs=False)

  if options.save:
    preds_h5 = h5py.File('%s/preds.h5' % options.out_dir, 'w')
    preds_h5.create_dataset('preds', data=test_preds)
    preds_h5.close()
    targets_h5 = h5py.File('%s/targets.h5' % options.out_dir, 'w')
    targets_h5.create_dataset('targets', data=test_targets)
    targets_h5.close()


  #######################################################
  # peak call accuracy

  if options.peaks:
    peaks_out_file = '%s/peaks.txt' % options.out_dir
    # NOTE(review): test_peaks is assumed to be defined elsewhere in this file.
    test_peaks(test_preds, test_targets, peaks_out_file)


  #######################################################
  # accuracy plots

  if options.accuracy_indexes is not None:
    accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')]

    if not os.path.isdir('%s/scatter' % options.out_dir):
      os.mkdir('%s/scatter' % options.out_dir)

    if not os.path.isdir('%s/violin' % options.out_dir):
      os.mkdir('%s/violin' % options.out_dir)

    if not os.path.isdir('%s/roc' % options.out_dir):
      os.mkdir('%s/roc' % options.out_dir)

    if not os.path.isdir('%s/pr' % options.out_dir):
      os.mkdir('%s/pr' % options.out_dir)

    for ti in accuracy_indexes:
      # slice one target track; assumes test_targets is
      # (sequences, bins, targets) — TODO confirm
      test_targets_ti = test_targets[:, :, ti]

      ############################################
      # scatter

      # sample every few bins (adjust to plot the # points I want)
      ds_indexes = np.arange(0, test_preds.shape[1], 8)

      # subset and flatten; cast back to float32 since preds were
      # stored as float16
      test_targets_ti_flat = test_targets_ti[:, ds_indexes].flatten(
      ).astype('float32')
      test_preds_ti_flat = test_preds[:, ds_indexes, ti].flatten().astype(
          'float32')

      # take log2
      test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
      test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

      # plot log2
      sns.set(font_scale=1.2, style='ticks')
      out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
      plots.regplot(
          test_targets_ti_log,
          test_preds_ti_log,
          out_pdf,
          poly_order=1,
          alpha=0.3,
          sample=500,
          figsize=(6, 6),
          x_label='log2 Experiment',
          y_label='log2 Prediction',
          table=True)

      ############################################
      # violin

      # call peaks: Poisson test against the track mean, then
      # Benjamini-Hochberg correction (ben_hoch, defined elsewhere),
      # peak = q-value < 0.01
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01
      test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                        'Background')

      # violin plot
      sns.set(font_scale=1.3, style='ticks')
      plt.figure()
      df = pd.DataFrame({
          'log2 Prediction': np.log2(test_preds_ti_flat + 1),
          'Experimental coverage status': test_targets_peaks_str
      })
      ax = sns.violinplot(
          x='Experimental coverage status', y='log2 Prediction', data=df)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # ROC: predictions as a classifier score for the called peaks
      plt.figure()
      fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
      auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
      plt.plot(
          [0, 1], [0, 1], c='black', linewidth=1, linestyle='--', alpha=0.7)
      plt.plot(fpr, tpr, c='black')
      ax = plt.gca()
      ax.set_xlabel('False positive rate')
      ax.set_ylabel('True positive rate')
      ax.text(
          0.99, 0.02, 'AUROC %.3f' % auroc,
          horizontalalignment='right')  # , fontsize=14)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # PR: baseline is the peak rate (precision of a random classifier)
      plt.figure()
      prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                               test_preds_ti_flat)
      auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat)
      plt.axhline(
          y=test_targets_peaks.mean(),
          c='black',
          linewidth=1,
          linestyle='--',
          alpha=0.7)
      plt.plot(recall, prec, c='black')
      ax = plt.gca()
      ax.set_xlabel('Recall')
      ax.set_ylabel('Precision')
      ax.text(
          0.99, 0.95, 'AUPRC %.3f' % auprc,
          horizontalalignment='right')  # , fontsize=14)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
      plt.close()