Example #1
def load_data(model_dir, cache_data=False,
              dataspec=None, hparams=None, data=None, preprocessor=None):
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if data is not None:
        data_path = data
    else:
        data_path = os.path.join(model_dir, 'data.pkl')
    if os.path.exists(data_path):
        train, valid, test = read_pkl(data_path)
    else:
        train, valid, test = chip_exo_nexus(ds,
                                            peak_width=hp.data.peak_width,
                                            shuffle=hp.data.shuffle,
                                            valid_chr=hp.data.valid_chr,
                                            test_chr=hp.data.test_chr)

        # Pre-process the data
        logger.info("Pre-processing the data")
        # Use the explicitly provided preprocessor path if given,
        # otherwise fall back to the one bundled in the model directory
        if preprocessor is not None:
            preproc_path = preprocessor
        else:
            preproc_path = os.path.join(model_dir, "preprocessor.pkl")
        preproc = read_pkl(preproc_path)
        train[1] = preproc.transform(train[1])
        valid[1] = preproc.transform(valid[1])
        try:
            test[1] = preproc.transform(test[1])
        except Exception:
            logger.warning("Test set couldn't be processed")
            test = None
    return train, valid, test
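
A minimal usage sketch (the "runs/model1" path is hypothetical; with no overrides, every artifact is resolved from the model directory):

# Rely on the dataspec.yaml, hparams.yaml, data.pkl bundled in model_dir
train, valid, test = load_data("runs/model1")

# Or override individual artifacts instead of using the bundled defaults
train, valid, test = load_data("runs/model1",
                               dataspec="configs/dataspec.yaml",
                               preprocessor="runs/model1/preprocessor.pkl")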
Example #2
def imp_score(model_dir,
              output_file,
              method="grad",
              split='all',
              batch_size=512,
              num_workers=10,
              h5_chunk_size=512,
              max_batches=-1,
              shuffle_seq=False,
              memfrac=0.45,
              exclude_chr='',
              overwrite=False,
              gpu=None):
    """Run importance scores for a BPNet model

    Args:
      model_dir: path to the model directory
      output_file: output file path (HDF5 format)
      method: which importance scoring method to use ('grad', 'deeplift' or 'ism')
      split: for which dataset split to compute the importance scores
      batch_size: batch size for the data loader
      num_workers: number of parallel workers for loading the data
      h5_chunk_size: hdf5 chunk size
      max_batches: stop after this many batches (-1 processes all batches)
      shuffle_seq: if True, di-nucleotide shuffle the sequences before scoring
      memfrac: fraction of the GPU memory to allocate
      exclude_chr: comma-separated list of chromosomes to exclude
      overwrite: if True, overwrite the existing output file
      gpu (int): which GPU to use locally. If None, GPU is not used
    """
    add_file_logging(os.path.dirname(output_file), logger, 'modisco-score')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(f"File exists {output_file}. Use overwrite=True to overwrite it")

    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = []
    # load the config files
    logger.info("Loading the config files")
    model_dir = Path(model_dir)
    hp = HParams.load(model_dir / "hparams.yaml")
    ds = DataSpec.load(model_dir / "dataspec.yaml")
    tasks = list(ds.task_specs)
    # validate that the correct dataset was used
    if hp.data.name != 'get_StrandedProfile_datasets':
        logger.warning("hp.data.name != 'get_StrandedProfile_datasets'")

    if split == 'valid':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['valid_chr']
        excl_chromosomes = None
    elif split == 'test':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['test_chr']
        excl_chromosomes = None
    elif split == 'train':
        assert len(exclude_chr) == 0
        incl_chromosomes = None
        excl_chromosomes = hp.data.kwargs['valid_chr'] + hp.data.kwargs['test_chr'] + hp.data.kwargs.get('exclude_chr', [])
    elif split == 'all':
        incl_chromosomes = None
        excl_chromosomes = hp.data.kwargs.get('exclude_chr', []) + exclude_chr
        logger.info("Excluding chromosomes: {excl_chromosomes}")
    else:
        raise ValueError("split needs to be from {train,valid,test,all}")

    logger.info("Creating the dataset")
    from basepair.datasets import StrandedProfile
    seq_len = hp.data.kwargs['peak_width']
    dl = StrandedProfile(ds,
                         incl_chromosomes=incl_chromosomes,
                         excl_chromosomes=excl_chromosomes,
                         peak_width=seq_len,
                         shuffle=False,
                         target_transformer=None)

    bpnet = BPNet.from_mdir(model_dir)

    writer = HDF5BatchWriter(output_file, chunk_size=h5_chunk_size)
    for i, batch in enumerate(tqdm(dl.batch_iter(batch_size=batch_size, num_workers=num_workers))):
        if max_batches > 0 and i >= max_batches:
            logger.info(f"max_batches: {max_batches} reached. Stopping the computation")
            break
        # append the bias model predictions
        # (batch['inputs'], batch['targets']) = bm((batch['inputs'], batch['targets']))

        # write out the batch ('inputs' and 'targets') together with the scores
        wdict = batch

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            if 'seq' in batch['inputs']:
                batch['inputs']['seq'] = onehot_dinucl_shuffle(batch['inputs']['seq'])
            else:
                batch['inputs'] = onehot_dinucl_shuffle(batch['inputs'])

        # loop through all tasks, pred_summary and strands
        for task_i, task in enumerate(tasks):
            for pred_summary in ['count', 'weighted']:
                # figure out the number of channels
                nstrands = batch['targets'][f'profile/{task}'].shape[-1]
                strand_names = ["pos", "neg"]

                for strand_i in range(nstrands):
                    hyp_imp = bpnet.imp_score(batch['inputs'],
                                              task=task,
                                              strand=strand_names[strand_i],
                                              method=method,
                                              pred_summary=pred_summary,
                                              batch_size=None)  # don't second-batch
                    # put importance scores to the dictionary
                    wdict[f"/hyp_imp/{task}/{pred_summary}/{strand_i}"] = hyp_imp
        writer.batch_write(wdict)
    writer.close()
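
A sketch of a typical call (the paths and GPU id are assumptions, not from the source): compute gradient-based importance scores for the validation split and store them in HDF5:

imp_score("runs/model1",
          output_file="runs/model1/imp_scores.valid.h5",
          method="grad",
          split="valid",
          gpu=0,           # pass None to run on the CPU
          overwrite=True)  # replace a previous output file if present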
Example #3
def evaluate(model_dir,
             output_dir=None,
             gpu=0,
             exclude_metrics=False,
             splits=['train', 'valid'],
             model_path=None,
             data=None,
             hparams=None,
             dataspec=None,
             preprocessor=None):
    """
    Args:
      model_dir: path to the model directory
      splits: For which data splits to compute the evaluation metrics
      model_metrics: if True, metrics computed using mode.evaluate(..)
    """
    if gpu is not None:
        create_tf_session(gpu)
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if model_path is not None:
        model = load_model(model_path)
    else:
        model = load_model(os.path.join(model_dir, "model.h5"))
    if output_dir is None:
        output_dir = os.path.join(model_dir, "eval")
    train, valid, test = load_data(model_dir,
                                   dataspec=dataspec,
                                   hparams=hparams,
                                   data=data,
                                   preprocessor=preprocessor)
    data = dict(train=train, valid=valid, test=test)

    metrics = {}
    profile_metrics = []
    os.makedirs(os.path.join(output_dir, "plots"),
                exist_ok=True)
    for split in tqdm(splits):
        y_pred = model.predict(data[split][0])
        y_true = data[split][1]
        if not exclude_metrics:
            eval_metrics_values = model.evaluate(data[split][0],
                                                 data[split][1])
            eval_metrics = dict(zip(_listify(model.metrics_names),
                                    _listify(eval_metrics_values)))
            eval_metrics = {split + "/" + k.replace("_", "/"): v
                            for k, v in eval_metrics.items()}
            metrics = {**eval_metrics, **metrics}
        for task in ds.task_specs:
            # Counts: compare total predicted vs. observed counts per region
            yp = y_pred[ds.task2idx(task, "counts")].sum(axis=-1)
            yt = y_true["counts/" + task].sum(axis=-1)
            # compute the correlation
            rp = pearsonr(yt, yp)[0]
            rs = spearmanr(yt, yp)[0]
            metrics = {**metrics,
                       split + f"/counts/{task}/pearsonr": rp,
                       split + f"/counts/{task}/spearmanr": rs,
                       }

            fig = plt.figure(figsize=(5, 5))
            plt.scatter(yp, yt, alpha=0.5)
            plt.xlabel("Predicted")
            plt.ylabel("Observed")
            plt.title(f"R_pearson={rp:.2f}, R_spearman={rs:.2f}")
            plt.savefig(os.path.join(output_dir, f"plots/counts.{split}.{task}.png"))
            plt.close(fig)  # free the figure so they don't accumulate across the loop

            # Profile
            yp = softmax(y_pred[ds.task2idx(task, "profile")])
            yt = y_true["profile/" + task]
            df = eval_profile(yt, yp,
                              pos_min_threshold=hp.evaluate.pos_min_threshold,
                              neg_max_threshold=hp.evaluate.neg_max_threshold,
                              required_min_pos_counts=hp.evaluate.required_min_pos_counts,
                              binsizes=hp.evaluate.binsizes)
            df['task'] = task
            df['split'] = split
            # Evaluate for the smallest binsize
            auprc_min = df[df.binsize == min(hp.evaluate.binsizes)].iloc[0].auprc
            metrics[split + f'/profile/{task}/auprc'] = auprc_min
            profile_metrics.append(df)

    # Write the scalar metrics (correlations, auprc, model.evaluate outputs)
    write_json(metrics, os.path.join(output_dir, "metrics.json"))

    # Write the per-binsize profile metrics
    dfm = pd.concat(profile_metrics)
    dfm.to_csv(os.path.join(output_dir, "profile_metrics.tsv"),
               sep='\t', index=False)
    return dfm, metrics
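
A hypothetical invocation (paths are assumptions); the function returns the per-binsize profile metrics as a DataFrame together with the scalar metrics dictionary, and also writes both under output_dir:

dfm, metrics = evaluate("runs/model1",
                        output_dir="runs/model1/eval",
                        gpu=None,  # run on the CPU
                        splits=['train', 'valid'])
# metrics contains keys such as '<split>/counts/<task>/pearsonr'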
Example #4
def hparams():
    return HParams(train=TrainHParams(epochs=2, batch_size=2),
                   data=DataHParams(valid_chr=['chr1'],
                                    test_chr=[],
                                    peak_width=10))
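
The helper above reads like a pytest fixture body: a small hyper-parameter set (2 epochs, tiny batches, chr1 held out for validation, 10 bp peak width) for fast tests. A minimal sketch of how it might be consumed, assuming it is registered as a fixture (the registration and the test below are illustrative, not from the source):

import pytest

# Register the helper as a fixture (equivalent to decorating it
# with @pytest.fixture at its definition site)
hparams = pytest.fixture(hparams)

def test_hparams(hparams):
    # chr1 is held out for validation; no test chromosomes
    assert hparams.data.valid_chr == ['chr1']
    assert hparams.data.peak_width == 10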