def from_mdir(cls, model_dir):
    from basepair.seqmodel import SeqModel
    # TODO - figure out also the fasta_file if present (from dataspec)
    from basepair.cli.schemas import DataSpec
    ds_path = os.path.join(model_dir, "dataspec.yaml")
    if os.path.exists(ds_path):
        ds = DataSpec.load(ds_path)
        fasta_file = ds.fasta_file
    else:
        fasta_file = None
    return cls(SeqModel.from_mdir(model_dir), fasta_file=fasta_file)
def load_data(model_dir, cache_data=False, dataspec=None, hparams=None,
              data=None, preprocessor=None):
    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if data is not None:
        data_path = data
    else:
        data_path = os.path.join(model_dir, 'data.pkl')
    if os.path.exists(data_path):
        train, valid, test = read_pkl(data_path)
    else:
        train, valid, test = chip_exo_nexus(ds,
                                            peak_width=hp.data.peak_width,
                                            shuffle=hp.data.shuffle,
                                            valid_chr=hp.data.valid_chr,
                                            test_chr=hp.data.test_chr)

    # Pre-process the data
    logger.info("Pre-processing the data")
    if preprocessor is not None:
        preproc_path = preprocessor
    else:
        preproc_path = os.path.join(model_dir, "preprocessor.pkl")
    preproc = read_pkl(preproc_path)
    train[1] = preproc.transform(train[1])
    valid[1] = preproc.transform(valid[1])
    try:
        test[1] = preproc.transform(test[1])
    except Exception:
        logger.warning("Test set couldn't be processed")
        test = None
    return train, valid, test
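# Usage sketch (hypothetical path): load the cached train/valid/test splits of a trained
# run directory that contains dataspec.yaml, hparams.yaml, data.pkl and preprocessor.pkl.
def _example_load_data():
    train, valid, test = load_data("models/example_run")
    x_train, y_train = train[0], train[1]  # (inputs, preprocessed targets)
    return x_train, y_train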
def get_StrandedProfile_datasets(dataspec,
                                 peak_width=200,
                                 seq_width=None,
                                 shuffle=True,
                                 target_transformer=AppendCounts(),
                                 valid_chr=['chr2', 'chr3', 'chr4'],
                                 test_chr=['chr1', 'chr8', 'chr9'],
                                 all_chr=all_chr,
                                 exclude_chr=[],
                                 vmtouch=True):
    # valid and test chromosomes must not be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    if isinstance(dataspec, str):
        dataspec = DataSpec.load(dataspec)

    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    return (StrandedProfile(dataspec,
                            peak_width,
                            seq_width=seq_width,
                            # Only include chromosomes from `all_chr`
                            incl_chromosomes=[c for c in all_chr
                                              if c not in valid_chr + test_chr + exclude_chr],
                            excl_chromosomes=valid_chr + test_chr + exclude_chr,
                            shuffle=shuffle,
                            target_transformer=target_transformer),
            StrandedProfile(dataspec,
                            peak_width,
                            seq_width=seq_width,
                            incl_chromosomes=valid_chr,
                            shuffle=shuffle,
                            target_transformer=target_transformer),
            StrandedProfile(dataspec,
                            peak_width,
                            seq_width=seq_width,
                            incl_chromosomes=test_chr,
                            shuffle=shuffle,
                            target_transformer=target_transformer))
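# Usage sketch (hypothetical path): build the train/valid/test StrandedProfile datasets
# from a dataspec file using the default chromosome split.
def _example_get_StrandedProfile_datasets():
    train, valid, test = get_StrandedProfile_datasets("dataspec.yaml",
                                                      peak_width=200,
                                                      vmtouch=False)
    return train, valid, test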
def dataspec():
    return DataSpec(bigwigs={
        "task1": TaskSpec(task='task1',
                          pos_counts=f"{ddir}/pos.bw",
                          neg_counts=f"{ddir}/neg.bw"),
        "task2": TaskSpec(task='task2',
                          pos_counts=f"{ddir}/pos.bw",
                          neg_counts=f"{ddir}/neg.bw"),
    },
        fasta_file=f"{ddir}/ref.fa",
        peaks=f"{ddir}/peaks.bed")
def from_mdir(cls, model_dir):
    """
    Args:
      model_dir (str): path to the model directory
    """
    import os
    from basepair.cli.schemas import DataSpec
    from keras.models import load_model
    from basepair.utils import read_pkl
    ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    model = load_model(os.path.join(model_dir, "model.h5"))
    preproc_file = os.path.join(model_dir, "preprocessor.pkl")
    if os.path.exists(preproc_file):
        preproc = read_pkl(preproc_file)
    else:
        preproc = None
    return cls(model=model,
               fasta_file=ds.fasta_file,
               tasks=list(ds.task_specs),
               preproc=preproc)
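# Usage sketch (hypothetical path): this classmethod is invoked elsewhere in the codebase
# as `BPNet.from_mdir(model_dir)` (see imp_score below); whether this particular definition
# belongs to that class is an assumption. The run directory is expected to contain
# dataspec.yaml, model.h5 and, optionally, preprocessor.pkl.
def _example_from_mdir():
    bpnet = BPNet.from_mdir("models/example_run")
    return bpnet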
def __init__(self, ds,
             peak_width=200,
             seq_width=None,
             incl_chromosomes=None,
             excl_chromosomes=None,
             intervals_file=None,
             bcolz=False,
             in_memory=False,
             include_metadata=True,
             taskname_first=False,
             tasks=None,
             include_classes=False,
             only_classes=False,
             shuffle=True,
             interval_transformer=None,
             target_transformer=None,
             profile_bias_pool_size=None):
    """Dataset for loading the bigwigs and fastas

    Args:
      ds (basepair.src.schemas.DataSpec): data specification containing the fasta file,
        bed files and bigWig file paths
      incl_chromosomes (list of str): chromosomes to include
      excl_chromosomes (list of str): chromosomes to exclude
      peak_width: resize the bed-file intervals to this width
      intervals_file: if specified, use these regions to train the model.
        If not specified, the regions are inferred from the dataspec.
      only_classes: if True, load only the classification targets
      bcolz: if True, the bigwig/fasta files are in the genomelake bcolz format
      in_memory: if True, load the whole bcolz array into memory.
        Only applicable when bcolz=True
      shuffle: if True, shuffle the regions
      target_transformer: trained target transformer object implementing a
        `.transform` method
    """
    if isinstance(ds, str):
        self.ds = DataSpec.load(ds)
    else:
        self.ds = ds
    self.peak_width = peak_width
    if seq_width is None:
        self.seq_width = peak_width
    else:
        self.seq_width = seq_width
    self.shuffle = shuffle
    self.intervals_file = intervals_file
    self.incl_chromosomes = incl_chromosomes
    self.excl_chromosomes = excl_chromosomes
    self.target_transformer = target_transformer
    self.include_classes = include_classes
    self.only_classes = only_classes
    self.taskname_first = taskname_first
    if self.only_classes:
        assert self.include_classes
    self.profile_bias_pool_size = profile_bias_pool_size

    # not specified yet
    self.fasta_extractor = None
    self.bw_extractors = None
    self.bias_bw_extractors = None

    self.include_metadata = include_metadata
    self.interval_transformer = interval_transformer
    self.bcolz = bcolz
    self.in_memory = in_memory
    if not self.bcolz and self.in_memory:
        raise ValueError("in_memory option only applicable when bcolz=True")

    # Load chromosome lengths
    if self.bcolz:
        p = json.loads((Path(self.ds.fasta_file) / "metadata.json").read_text())
        self.chrom_lens = {c: v[0] for c, v in p['file_shapes'].items()}
    else:
        fa = FastaFile(self.ds.fasta_file)
        self.chrom_lens = {name: l for name, l in zip(fa.references, fa.lengths)}
        if len(self.chrom_lens) == 0:
            raise ValueError(
                f"no chromosomes found in fasta file: {self.ds.fasta_file}. "
                f"Make sure the file path is correct and that the fasta index file "
                f"{self.ds.fasta_file}.fai is up to date")
        del fa

    if self.intervals_file is None:
        self.dfm = load_beds(bed_files={task: task_spec.peaks
                                        for task, task_spec in self.ds.task_specs.items()
                                        if task_spec.peaks is not None},
                             chromosome_lens=self.chrom_lens,
                             excl_chromosomes=self.excl_chromosomes,
                             incl_chromosomes=self.incl_chromosomes,
                             resize_width=max(self.peak_width, self.seq_width))
        assert list(self.dfm.columns)[:4] == ["chrom", "start", "end", "task"]
        if self.shuffle:
            self.dfm = self.dfm.sample(frac=1)
        self.tsv = None
        self.dfm_tasks = None
    else:
        self.tsv = TsvReader(self.intervals_file,
                             num_chr=False,
                             label_dtype=int,
                             mask_ambigous=-1,
                             incl_chromosomes=incl_chromosomes,
                             excl_chromosomes=excl_chromosomes,
                             chromosome_lens=self.chrom_lens,
                             resize_width=max(self.peak_width, self.seq_width))
        if self.shuffle:
            self.tsv.shuffle_inplace()
        self.dfm = self.tsv.df  # use the data-frame from the tsv
        self.dfm_tasks = self.tsv.get_target_names()  # remember the tasks

    if tasks is None:
        self.tasks = list(self.ds.task_specs)
    else:
        self.tasks = tasks

    if self.bcolz and self.in_memory:
        self.fasta_extractor = ArrayExtractor(self.ds.fasta_file, in_memory=True)
        self.bw_extractors = {task: [ArrayExtractor(task_spec.pos_counts, in_memory=True),
                                     ArrayExtractor(task_spec.neg_counts, in_memory=True)]
                              for task, task_spec in self.ds.task_specs.items()
                              if task in self.tasks}
        self.bias_bw_extractors = {task: [ArrayExtractor(task_spec.pos_counts, in_memory=True),
                                          ArrayExtractor(task_spec.neg_counts, in_memory=True)]
                                   for task, task_spec in self.ds.bias_specs.items()
                                   if task in self.tasks}

    if self.include_classes:
        assert self.dfm_tasks is not None
    if self.dfm_tasks is not None:
        assert set(self.tasks).issubset(self.dfm_tasks)

    # setup bias maps per task
    self.task_bias_tracks = {task: [bias for bias, spec in self.ds.bias_specs.items()
                                    if task in spec.tasks]
                             for task in self.tasks}
def get_gw_StrandedProfile_datasets(dataspec,
                                    intervals_file=None,
                                    peak_width=200,
                                    seq_width=None,
                                    shuffle=True,
                                    target_transformer=AppendCounts(),
                                    include_metadata=False,
                                    taskname_first=False,
                                    include_classes=False,
                                    only_classes=False,
                                    tasks=None,
                                    valid_chr=['chr2', 'chr3', 'chr4'],
                                    test_chr=['chr1', 'chr8', 'chr9'],
                                    exclude_chr=[],
                                    vmtouch=True,
                                    profile_bias_pool_size=None):
    # NOTE: only chromosomes chr1-22, chrX and chrY are considered here
    # (i.e. all other chromosomes like chrUn... are omitted)
    from basepair.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman

    # valid and test chromosomes must not be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)
    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    if tasks is None:
        tasks = list(dataspec.task_specs)

    train = StrandedProfile(dataspec,
                            peak_width,
                            seq_width=seq_width,
                            intervals_file=intervals_file,
                            include_metadata=include_metadata,
                            taskname_first=taskname_first,
                            include_classes=include_classes,
                            only_classes=only_classes,
                            tasks=tasks,
                            incl_chromosomes=[c for c in all_chr
                                              if c not in valid_chr + test_chr + exclude_chr],
                            excl_chromosomes=valid_chr + test_chr + exclude_chr,
                            shuffle=shuffle,
                            target_transformer=target_transformer,
                            profile_bias_pool_size=profile_bias_pool_size)

    valid = [('train-valid-genome-wide',
              StrandedProfile(dataspec,
                              peak_width,
                              seq_width=seq_width,
                              intervals_file=intervals_file,
                              include_metadata=include_metadata,
                              include_classes=include_classes,
                              only_classes=only_classes,
                              taskname_first=taskname_first,
                              tasks=tasks,
                              incl_chromosomes=valid_chr,
                              shuffle=shuffle,
                              target_transformer=target_transformer,
                              profile_bias_pool_size=profile_bias_pool_size))]

    if include_classes:
        # Only use binary classification for genome-wide evaluation
        valid = valid + [('valid-genome-wide',
                          StrandedProfile(dataspec,
                                          peak_width,
                                          seq_width=seq_width,
                                          intervals_file=intervals_file,
                                          include_metadata=include_metadata,
                                          include_classes=True,
                                          only_classes=True,
                                          taskname_first=taskname_first,
                                          tasks=tasks,
                                          incl_chromosomes=valid_chr,
                                          shuffle=shuffle,
                                          target_transformer=target_transformer,
                                          profile_bias_pool_size=profile_bias_pool_size))]

    if not only_classes:
        # Also add the peak regions
        valid = valid + [
            ('valid-peaks',
             StrandedProfile(dataspec,
                             peak_width,
                             seq_width=seq_width,
                             intervals_file=None,
                             include_metadata=include_metadata,
                             taskname_first=taskname_first,
                             tasks=tasks,
                             include_classes=False,  # dataspec doesn't contain labels
                             only_classes=only_classes,
                             incl_chromosomes=valid_chr,
                             shuffle=shuffle,
                             target_transformer=target_transformer,
                             profile_bias_pool_size=profile_bias_pool_size)),
            ('train-peaks',
             StrandedProfile(dataspec,
                             peak_width,
                             seq_width=seq_width,
                             intervals_file=None,
                             include_metadata=include_metadata,
                             taskname_first=taskname_first,
                             tasks=tasks,
                             include_classes=False,  # dataspec doesn't contain labels
                             only_classes=only_classes,
                             incl_chromosomes=[c for c in all_chr
                                               if c not in valid_chr + test_chr + exclude_chr],
                             excl_chromosomes=valid_chr + test_chr + exclude_chr,
                             shuffle=shuffle,
                             target_transformer=target_transformer,
                             profile_bias_pool_size=profile_bias_pool_size)),
            # use the default metric for the peak sets
        ]
    return train, valid
def get_StrandedProfile_datasets2(dataspec,
                                  peak_width=200,
                                  intervals_file=None,
                                  seq_width=None,
                                  shuffle=True,
                                  target_transformer=AppendCounts(),
                                  include_metadata=False,
                                  valid_chr=['chr2', 'chr3', 'chr4'],
                                  test_chr=['chr1', 'chr8', 'chr9'],
                                  tasks=None,
                                  taskname_first=False,
                                  exclude_chr=[],
                                  augment_interval=False,
                                  interval_augmentation_shift=200,
                                  vmtouch=True,
                                  profile_bias_pool_size=None):
    from basepair.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman

    # valid and test chromosomes must not be in the excluded set
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)
    if vmtouch:
        # use vmtouch to load all files into memory
        dataspec.touch_all_files()

    if tasks is None:
        tasks = list(dataspec.task_specs)

    if augment_interval:
        interval_transformer = IntervalAugmentor(max_shift=interval_augmentation_shift,
                                                 flip_strand=True)
    else:
        interval_transformer = None

    return (StrandedProfile(dataspec,
                            peak_width,
                            intervals_file=intervals_file,
                            seq_width=seq_width,
                            include_metadata=include_metadata,
                            incl_chromosomes=[c for c in all_chr
                                              if c not in valid_chr + test_chr + exclude_chr],
                            excl_chromosomes=valid_chr + test_chr + exclude_chr,
                            tasks=tasks,
                            taskname_first=taskname_first,
                            shuffle=shuffle,
                            target_transformer=target_transformer,
                            interval_transformer=interval_transformer,
                            profile_bias_pool_size=profile_bias_pool_size),
            [('valid-peaks',
              StrandedProfile(dataspec,
                              peak_width,
                              intervals_file=intervals_file,
                              seq_width=seq_width,
                              include_metadata=include_metadata,
                              incl_chromosomes=valid_chr,
                              tasks=tasks,
                              taskname_first=taskname_first,
                              interval_transformer=interval_transformer,
                              shuffle=shuffle,
                              target_transformer=target_transformer,
                              profile_bias_pool_size=profile_bias_pool_size)),
             ('train-peaks',
              StrandedProfile(dataspec,
                              peak_width,
                              intervals_file=intervals_file,
                              seq_width=seq_width,
                              include_metadata=include_metadata,
                              incl_chromosomes=[c for c in all_chr
                                                if c not in valid_chr + test_chr + exclude_chr],
                              excl_chromosomes=valid_chr + test_chr + exclude_chr,
                              tasks=tasks,
                              taskname_first=taskname_first,
                              interval_transformer=interval_transformer,
                              shuffle=shuffle,
                              target_transformer=target_transformer,
                              profile_bias_pool_size=profile_bias_pool_size)),
             ])
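# Usage sketch (hypothetical path): construct a StrandedProfile dataset directly from a
# dataspec and iterate over batches, mirroring how imp_score below consumes it.
def _example_stranded_profile():
    dl = StrandedProfile("dataspec.yaml",  # placeholder dataspec path
                         peak_width=200,
                         incl_chromosomes=['chr2', 'chr3', 'chr4'],
                         shuffle=False,
                         target_transformer=None)
    for batch in dl.batch_iter(batch_size=64, num_workers=1):
        inputs, targets = batch['inputs'], batch['targets']
        break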
def evaluate(model_dir, output_dir=None, gpu=0,
             exclude_metrics=False,
             splits=['train', 'valid'],
             model_path=None,
             data=None,
             hparams=None,
             dataspec=None,
             preprocessor=None):
    """Evaluate a trained model

    Args:
      model_dir: path to the model directory
      splits: data splits for which to compute the evaluation metrics
      exclude_metrics: if True, skip the metrics computed with model.evaluate(...)
    """
    if gpu is not None:
        create_tf_session(gpu)

    if dataspec is not None:
        ds = DataSpec.load(dataspec)
    else:
        ds = DataSpec.load(os.path.join(model_dir, "dataspec.yaml"))
    if hparams is not None:
        hp = HParams.load(hparams)
    else:
        hp = HParams.load(os.path.join(model_dir, "hparams.yaml"))
    if model_path is not None:
        model = load_model(model_path)
    else:
        model = load_model(os.path.join(model_dir, "model.h5"))
    if output_dir is None:
        output_dir = os.path.join(model_dir, "eval")

    train, valid, test = load_data(model_dir,
                                   dataspec=dataspec,
                                   hparams=hparams,
                                   data=data,
                                   preprocessor=preprocessor)
    data = dict(train=train, valid=valid, test=test)

    metrics = {}
    profile_metrics = []
    os.makedirs(os.path.join(output_dir, "plots"), exist_ok=True)
    for split in tqdm(splits):
        y_pred = model.predict(data[split][0])
        y_true = data[split][1]
        if not exclude_metrics:
            eval_metrics_values = model.evaluate(data[split][0], data[split][1])
            eval_metrics = dict(zip(_listify(model.metrics_names),
                                    _listify(eval_metrics_values)))
            eval_metrics = {split + "/" + k.replace("_", "/"): v
                            for k, v in eval_metrics.items()}
            metrics = {**eval_metrics, **metrics}

        for task in ds.task_specs:
            # Counts
            yp = y_pred[ds.task2idx(task, "counts")].sum(axis=-1)
            yt = y_true["counts/" + task].sum(axis=-1)
            # compute the correlation
            rp = pearsonr(yt, yp)[0]
            rs = spearmanr(yt, yp)[0]
            metrics = {**metrics,
                       split + f"/counts/{task}/pearsonr": rp,
                       split + f"/counts/{task}/spearmanr": rs,
                       }
            fig = plt.figure(figsize=(5, 5))
            plt.scatter(yp, yt, alpha=0.5)
            plt.xlabel("Predicted")
            plt.ylabel("Observed")
            plt.title(f"R_pearson={rp:.2f}, R_spearman={rs:.2f}")
            plt.savefig(os.path.join(output_dir, f"plots/counts.{split}.{task}.png"))

            # Profile
            yp = softmax(y_pred[ds.task2idx(task, "profile")])
            yt = y_true["profile/" + task]
            df = eval_profile(yt, yp,
                              pos_min_threshold=hp.evaluate.pos_min_threshold,
                              neg_max_threshold=hp.evaluate.neg_max_threshold,
                              required_min_pos_counts=hp.evaluate.required_min_pos_counts,
                              binsizes=hp.evaluate.binsizes)
            df['task'] = task
            df['split'] = split
            # Evaluate for the smallest binsize
            auprc_min = df[df.binsize == min(hp.evaluate.binsizes)].iloc[0].auprc
            metrics[split + f'/profile/{task}/auprc'] = auprc_min
            profile_metrics.append(df)

    # Write the count metrics
    write_json(metrics, os.path.join(output_dir, "metrics.json"))

    # Write the profile metrics
    dfm = pd.concat(profile_metrics)
    dfm.to_csv(os.path.join(output_dir, "profile_metrics.tsv"),
               sep='\t', index=False)
    return dfm, metrics
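# Usage sketch (hypothetical path): evaluate a trained run on the train and valid splits,
# returning the per-binsize profile metrics data-frame and the scalar metrics dict.
def _example_evaluate():
    dfm, metrics = evaluate("models/example_run", gpu=None, splits=['train', 'valid'])
    return dfm, metrics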
def imp_score(model_dir, output_file, method="grad", split='all', batch_size=512,
              num_workers=10, h5_chunk_size=512, max_batches=-1,
              shuffle_seq=False, memfrac=0.45, exclude_chr='', overwrite=False, gpu=None):
    """Compute importance scores for a BPNet model

    Args:
      model_dir: path to the model directory
      output_file: output file path (HDF5 format)
      method: which importance scoring method to use ('grad', 'deeplift' or 'ism')
      split: for which dataset split to compute the importance scores
      h5_chunk_size: hdf5 chunk size
      exclude_chr: comma-separated list of chromosomes to exclude
      overwrite: if True, overwrite the output file
      gpu (int): which GPU to use locally. If None, the GPU is not used
    """
    add_file_logging(os.path.dirname(output_file), logger, 'modisco-score')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(f"File exists {output_file}. Use overwrite=True to overwrite it")

    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = []

    # load the config files
    logger.info("Loading the config files")
    model_dir = Path(model_dir)
    hp = HParams.load(model_dir / "hparams.yaml")
    ds = DataSpec.load(model_dir / "dataspec.yaml")
    tasks = list(ds.task_specs)

    # validate that the correct dataset was used
    if hp.data.name != 'get_StrandedProfile_datasets':
        logger.warning("hp.data.name != 'get_StrandedProfile_datasets'")

    if split == 'valid':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['valid_chr']
        excl_chromosomes = None
    elif split == 'test':
        assert len(exclude_chr) == 0
        incl_chromosomes = hp.data.kwargs['test_chr']
        excl_chromosomes = None
    elif split == 'train':
        assert len(exclude_chr) == 0
        incl_chromosomes = None
        excl_chromosomes = (hp.data.kwargs['valid_chr'] + hp.data.kwargs['test_chr']
                            + hp.data.kwargs.get('exclude_chr', []))
    elif split == 'all':
        incl_chromosomes = None
        excl_chromosomes = hp.data.kwargs.get('exclude_chr', []) + exclude_chr
        logger.info(f"Excluding chromosomes: {excl_chromosomes}")
    else:
        raise ValueError("split needs to be one of {train,valid,test,all}")

    logger.info("Creating the dataset")
    from basepair.datasets import StrandedProfile
    seq_len = hp.data.kwargs['peak_width']
    dl_valid = StrandedProfile(ds,
                               incl_chromosomes=incl_chromosomes,
                               excl_chromosomes=excl_chromosomes,
                               peak_width=seq_len,
                               shuffle=False,
                               target_transformer=None)

    bpnet = BPNet.from_mdir(model_dir)

    writer = HDF5BatchWriter(output_file, chunk_size=h5_chunk_size)
    for i, batch in enumerate(tqdm(dl_valid.batch_iter(batch_size=batch_size,
                                                       num_workers=num_workers))):
        if max_batches > 0:
            if i > max_batches:
                logging.info(f"max_batches: {max_batches} exceeded. Stopping the computation")
                break

        # append the bias model predictions
        # (batch['inputs'], batch['targets']) = bm((batch['inputs'], batch['targets']))

        # store the original batch containing 'inputs' and 'targets'
        wdict = batch

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            if 'seq' in batch['inputs']:
                batch['inputs']['seq'] = onehot_dinucl_shuffle(batch['inputs']['seq'])
            else:
                batch['inputs'] = onehot_dinucl_shuffle(batch['inputs'])

        # loop through all tasks, pred_summary and strands
        for task_i, task in enumerate(tasks):
            for pred_summary in ['count', 'weighted']:
                # figure out the number of channels
                nstrands = batch['targets'][f'profile/{task}'].shape[-1]
                strand_hash = ["pos", "neg"]
                for strand_i in range(nstrands):
                    hyp_imp = bpnet.imp_score(batch['inputs'],
                                              task=task,
                                              strand=strand_hash[strand_i],
                                              method=method,
                                              pred_summary=pred_summary,
                                              batch_size=None)  # don't second-batch
                    # put the importance scores into the dictionary
                    wdict[f"/hyp_imp/{task}/{pred_summary}/{strand_i}"] = hyp_imp
        writer.batch_write(wdict)
    writer.close()
def imp_score_seqmodel(model_dir, output_file, dataspec=None,
                       peak_width=1000,
                       seq_width=None,
                       intp_pattern='*',  # specifies which imp. scores to compute
                       # skip_trim=False,  # skip trimming the output
                       method="deeplift",
                       batch_size=512,
                       max_batches=-1,
                       shuffle_seq=False,
                       memfrac=0.45,
                       num_workers=10,
                       h5_chunk_size=512,
                       exclude_chr='',
                       include_chr='',
                       overwrite=False,
                       skip_bias=False,
                       gpu=None):
    """Compute importance scores for a BPNet model

    Args:
      model_dir: path to the model directory
      output_file: output file path (HDF5 format)
      method: which importance scoring method to use ('grad', 'deeplift' or 'ism')
      h5_chunk_size: hdf5 chunk size
      exclude_chr: comma-separated list of chromosomes to exclude
      include_chr: comma-separated list of chromosomes to include
      overwrite: if True, overwrite the output file
      skip_bias: if True, don't store the bias tracks in the output
      gpu (int): which GPU to use locally. If None, the GPU is not used
    """
    add_file_logging(os.path.dirname(output_file), logger, 'modisco-score-seqmodel')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(f"File exists {output_file}. Use overwrite=True to overwrite it")

    if seq_width is None:
        logger.info("Using seq_width = peak_width")
        seq_width = peak_width

    # make sure these are int's
    seq_width = int(seq_width)
    peak_width = int(peak_width)

    # Split
    intp_patterns = intp_pattern.split(",")

    # Allow chr inclusion / exclusion
    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = None
    if include_chr:
        include_chr = include_chr.split(",")
    else:
        include_chr = None

    logger.info("Loading the config files")
    model_dir = Path(model_dir)
    if dataspec is None:
        # Default to the dataspec stored in the model directory
        dataspec = model_dir / "dataspec.yaml"
    ds = DataSpec.load(dataspec)

    logger.info("Creating the dataset")
    from basepair.datasets import StrandedProfile
    dl_valid = StrandedProfile(ds,
                               incl_chromosomes=include_chr,
                               excl_chromosomes=exclude_chr,
                               peak_width=peak_width,
                               seq_width=seq_width,
                               shuffle=False,
                               taskname_first=True,  # Required to work nicely with imp-score
                               target_transformer=None)

    # Setup importance score trimming
    if seq_width > peak_width:
        # Trim: make sure we can nicely trim the peak
        logger.info("Trimming the output")
        assert (seq_width - peak_width) % 2 == 0
        trim_start = (seq_width - peak_width) // 2
        trim_end = seq_width - trim_start
        assert trim_end - trim_start == peak_width
    elif seq_width == peak_width:
        trim_start = 0
        trim_end = peak_width
    else:
        raise ValueError("seq_width < peak_width")

    seqmodel = SeqModel.from_mdir(model_dir)

    # get all possible interpretation names
    # make sure they match the specified glob
    intp_names = [name for name, _ in seqmodel.get_intp_tensors(preact_only=False)
                  if fnmatch_any(name, intp_patterns)]
    logger.info("Using the following interpretation targets:")
    for n in intp_names:
        print(n)

    writer = HDF5BatchWriter(output_file, chunk_size=h5_chunk_size)

    for i, batch in enumerate(tqdm(dl_valid.batch_iter(batch_size=batch_size,
                                                       num_workers=num_workers))):
        # store the original batch containing 'inputs' and 'targets'
        wdict = batch
        if skip_bias:
            wdict['inputs'] = {'seq': wdict['inputs']['seq']}  # ignore all other inputs

        if max_batches > 0:
            if i > max_batches:
                logging.info(f"max_batches: {max_batches} exceeded. Stopping the computation")
                break

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            batch['inputs']['seq'] = onehot_dinucl_shuffle(batch['inputs']['seq'])

        for name in intp_names:
            hyp_imp = seqmodel.imp_score(batch['inputs']['seq'],
                                         name=name,
                                         method=method,
                                         batch_size=None)  # don't second-batch
            # put the importance scores into the dictionary
            # also trim the importance scores appropriately so that
            # the output will always be w.r.t. the peak center
            wdict[f"/hyp_imp/{name}"] = hyp_imp[:, trim_start:trim_end]

        # trim the sequence as well
        if isinstance(wdict['inputs'], dict):
            wdict['inputs']['seq'] = wdict['inputs']['seq'][:, trim_start:trim_end]
        else:
            wdict['inputs'] = wdict['inputs'][:, trim_start:trim_end]

        writer.batch_write(wdict)
    writer.close()
def load_data(model_name, imp_score, exp):
    isf = ImpScoreFile(models_dir / exp / 'deeplift.imp_score.h5',
                       default_imp_score=imp_score)
    dfi_subset = pd.read_parquet(models_dir / exp / "deeplift/dfi_subset.parq",
                                 engine='fastparquet').assign(model=model_name).assign(exp=exp)
    mr = MultipleModiscoResult({t: models_dir / exp / f'deeplift/{t}/out/{imp_score}/modisco.h5'
                                for t in get_tasks(exp)})
    return isf, dfi_subset, mr


# load DataSpec
# old config
rdir = get_repo_root()
dataspec_file = rdir / "src/chipnexus/train/seqmodel/ChIP-nexus.dataspec.yml"
ds = DataSpec.load(dataspec_file)

# Load all files into the page cache
logger.info("Touch all files")
ds.touch_all_files()

# --------------------------------------------
# ### nexus/profile
model_name = 'nexus/profile'
imp_score = 'profile/wn'
exp = models[model_name]
isf, dfi_subset, mr = load_data(model_name, imp_score, exp)

ranges_profile = isf.get_ranges()
profiles = isf.get_profiles()

# TODO - fix seqlets
dfi_list.append(