def test_BatchDataset(data):
    """Exercise the BatchDataset contract: len/getitem over whole batches."""

    class _Batched(BatchDataset):
        def __init__(self, data, batch_size=3):
            self.data = data
            self.batch_size = batch_size

        def __len__(self):
            n_rows = self.data["targets"].shape[0]
            return int(np.ceil(n_rows / self.batch_size))

        def __getitem__(self, idx):
            lo = idx * self.batch_size
            hi = min((idx + 1) * self.batch_size, self.data["targets"].shape[0])
            return get_dataset_item(self.data, np.arange(lo, hi))

    # load_all / batch_iter
    ds = _Batched(data)
    compare_arrays(ds.load_all(), data)
    batches = ds.batch_iter()
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))

    # batch_train_iter
    ds = _Batched(data, batch_size=2)
    train_it = ds.batch_train_iter()
    for _ in range(6):
        x, y = next(train_it)
        compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
        compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
def predict_all(self, seq, imp_method='grad', batch_size=512, pred_summaries=['weighted', 'count']):
    """Predict on all sequences and assemble one record per sequence.

    Optionally runs the bias model on the input first, then the main
    model prediction and (unless disabled) per-sequence importance
    scoring, and finally zips everything into one dict per sequence.

    Args:
        seq: batch of input sequences accepted by `self.predict` (and
            `self.bias_model.predict` when a bias model is set).
        imp_method: importance-scoring method forwarded to
            `self.imp_score_all`; pass None to skip importance scoring.
        batch_size: prediction batch size.
        pred_summaries: prediction summaries forwarded to
            `self.imp_score_all`.

    Returns:
        list of dicts (one per element of `seq`) with keys:
            seq: the (bias-model-transformed) input for that sample,
            pred: the model prediction for that sample,
            imp_score: the importance scores for that sample (taken from
                an empty dict when `imp_method` is None).
    """
    # NOTE(review): `pred_summaries` is a mutable default argument —
    # harmless as long as no caller mutates it, but a tuple would be safer.
    if self.bias_model is not None:
        # Bias model returns a 1-tuple; unpack the corrected sequences.
        seq, = self.bias_model.predict((seq, ), batch_size)
    preds = self.predict(seq, batch_size=batch_size)
    if imp_method is not None:
        imp_scores = self.imp_score_all(seq,
                                        method=imp_method,
                                        aggregate_strand=True,
                                        batch_size=batch_size,
                                        pred_summaries=pred_summaries)
    else:
        imp_scores = dict()
    out = [dict(seq=get_dataset_item(seq, i),
                # interval=intervals[i],
                pred=get_dataset_item(preds, i),
                # TODO - shall we call it hyp_imp score or imp_score?
                imp_score=get_dataset_item(imp_scores, i),
                )
           for i in range(len(seq))]
    return out
def test_Dataset(data):
    """Exercise the Dataset contract: indexable single samples."""

    class _Indexed(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return self.data["targets"].shape[0]

        def __getitem__(self, idx):
            return get_dataset_item(self.data, idx)

    ds = _Indexed(data)
    compare_arrays(ds.load_all(), data)
    batches = ds.batch_iter(3)
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))

    # test batch_train_iter
    train_it = ds.batch_train_iter(batch_size=2)
    for _ in range(6):
        x, y = next(train_it)
        compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
        compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
def score(self, input_batch):
    """Score every single-position mutation of every sample in a batch.

    Args:
        input_batch: Input batch that should be scored.

    Returns:
        A list of length: len(`scores`). Every element of the list is a
        stacked list of depth D if the model input is D-dimensional with
        identical shape. Every entry of that list then contains the scores
        of the model output selected by `output_sel_fn`. Values are `None`
        if the input_batch already had a `1` at that position.
    """
    # Reference prediction on the unmodified batch.
    ref = self.model.predict_on_batch(input_batch)
    scores = []
    for sample_i in range(
            len(get_model_input(input_batch, self.model_input))):
        # get the full set of model inputs for the selected sample
        sample_set = get_dataset_item(input_batch, sample_i)
        # get the reference output for this sample
        ref_sample_pred = get_dataset_item(ref, sample_i)
        # Apply the output selection function if defined
        if self.output_sel_fn is not None:
            ref_sample_pred = self.output_sel_fn(ref_sample_pred)
        # get the one-hot encoded reference input array
        input_sample = get_model_input(sample_set, input_id=self.model_input)
        # where we keep the scores - scores are lists (ordered by diff
        # method of ndarrays, lists or dictionaries - whatever is returned by the model
        score = np.empty(input_sample.shape, dtype=object)
        score[:] = None
        # Iterate over batches of mutated versions of this one sample.
        for alt_batch, alt_idxs in self._mutate_sample_batched(
                input_sample):
            num_samples = len(alt_batch)
            # Replicate the full sample so only the mutated input differs.
            mult_set = numpy_collate([sample_set] * num_samples)
            mult_set = set_model_input(mult_set,
                                       numpy_collate(alt_batch),
                                       input_id=self.model_input)
            alt = self.model.predict_on_batch(mult_set)
            for alt_sample_i in range(num_samples):
                alt_sample = get_dataset_item(alt, alt_sample_i)
                # Apply the output selection function if defined
                if self.output_sel_fn is not None:
                    alt_sample = self.output_sel_fn(alt_sample)
                # Apply scores across all model outputs for ref and alt
                output_scores = [
                    apply_within(ref_sample_pred, alt_sample, scr)
                    for scr in self.scores
                ]
                # alt_idxs may be a tuple index; __setitem__ places the
                # per-score list at the mutated position.
                score.__setitem__(alt_idxs[alt_sample_i], output_scores)
        scores.append(score.tolist())
    return scores
def test_SampleIterator(data):
    """Exercise the SampleIterator contract: one sample per __next__."""

    class _SampleIter(SampleIterator):
        def __init__(self, data):
            self.data = data
            self.idx = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self.idx >= self.data["targets"].shape[0]:
                raise StopIteration
            sample = get_dataset_item(self.data, self.idx)
            self.idx += 1
            return sample

        next = __next__  # py2-style alias

    it = _SampleIter(data)
    compare_arrays(it.load_all(), data)

    # A fresh iterator is needed since the previous one is exhausted.
    it = _SampleIter(data)
    batches = it.batch_iter(batch_size=3)
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))
def test_BatchIterator(data):
    """Exercise the BatchIterator contract: one batch per __next__."""

    class _BatchIter(BatchIterator):
        def __init__(self, data, batch_size):
            self.data = data
            self.batch_size = batch_size
            self.idx = 0

        def __iter__(self):
            return self

        def __next__(self):
            n_rows = self.data["targets"].shape[0]
            lo = self.idx * self.batch_size
            if lo >= n_rows:
                raise StopIteration
            hi = min((self.idx + 1) * self.batch_size, n_rows)
            self.idx += 1
            return get_dataset_item(self.data, np.arange(lo, hi))

        next = __next__  # py2-style alias

    it = _BatchIter(data, 3)
    compare_arrays(it.load_all(), data)

    # Fresh iterator: the previous one is exhausted.
    it = _BatchIter(data, 3)
    batches = it.batch_iter()
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))
def __next__(self):
    """Return the next batch of samples, or raise StopIteration."""
    n_rows = self.data["targets"].shape[0]
    lo = self.idx * self.batch_size
    if lo >= n_rows:
        raise StopIteration
    hi = min((self.idx + 1) * self.batch_size, n_rows)
    self.idx += 1
    return get_dataset_item(self.data, np.arange(lo, hi))
def test_PreloadedDataset(data):
    """PreloadedDataset built from a zero-arg loader function."""
    def loader():
        return data

    ds = PreloadedDataset.from_fn(loader)()
    compare_arrays(ds.load_all(), data)
    batches = ds.batch_iter(3)
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))

    # test batch_train_iter
    train_it = ds.batch_train_iter(batch_size=2)
    for _ in range(6):
        x, y = next(train_it)
        compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
        compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
def test_get_item(data):
    """get_dataset_lens agrees across fields; get_dataset_item extracts one row."""
    assert get_dataset_lens(data)[0] == 3
    # all fields report the same length
    assert len(set(get_dataset_lens(data))) == 1
    expected = {
        "a": [1],
        "b": {"d": 1},
        "c": np.array([1]),
    }
    assert get_dataset_item(data, 1) == expected
def test_PreloadedDataset(data):
    """PreloadedDataset: load_all and batch_iter round-trip the data."""
    def provide():
        return data

    dataset = PreloadedDataset.from_fn(provide)()
    compare_arrays(dataset.load_all(), data)
    batch_stream = dataset.batch_iter(3)
    compare_arrays(next(batch_stream), get_dataset_item(data, np.arange(3)))
def nested_numpy_minibatch(data, batch_size=1):
    """Yield minibatches of a (possibly nested) dataset.

    Args:
        data: nested structure (dict / sequence / array) understood by
            `get_dataset_lens` and `get_dataset_item`.
        batch_size: number of samples per yielded batch; the final batch
            may be smaller (drop_last=False).

    Yields:
        `get_dataset_item(data, idx)` for consecutive index batches
        covering the whole dataset.
    """
    # The bare `collections.Mapping` / `collections.Sequence` aliases were
    # deprecated in Python 3.3 and removed in 3.10 — use collections.abc.
    from collections.abc import Mapping, Sequence

    lens = get_dataset_lens(data)
    if isinstance(lens, Mapping):
        # Any value works: all fields are assumed to share one length.
        ln = next(iter(lens.values()))
    elif isinstance(lens, Sequence):
        ln = lens[0]
    else:
        ln = lens
    for idx in BatchSampler(range(ln), batch_size=batch_size, drop_last=False):
        yield get_dataset_item(data, idx)
def test_SampleGenerator(data):
    """SampleGenerator wraps a per-sample generator function."""
    def gen(data):
        for i in range(data["targets"].shape[0]):
            yield get_dataset_item(data, i)

    ds = SampleGenerator.from_fn(gen)(data)
    compare_arrays(ds.load_all(), data)

    # generators are single-shot: rebuild for each use
    ds = SampleGenerator.from_fn(gen)(data)
    batches = ds.batch_iter(batch_size=3)
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))

    ds = SampleGenerator.from_fn(gen)(data)
    train_it = ds.batch_train_iter(batch_size=2)
    for _ in range(6):
        x, y = next(train_it)
        compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
        compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
def test_SampleGenerator(data):
    """SampleGenerator: load_all and batch_iter round-trip the data."""
    def one_by_one(data):
        n = data["targets"].shape[0]
        for row in range(n):
            yield get_dataset_item(data, row)

    gen_ds = SampleGenerator.from_fn(one_by_one)(data)
    compare_arrays(gen_ds.load_all(), data)

    # a generator-backed dataset is exhausted after load_all — rebuild
    gen_ds = SampleGenerator.from_fn(one_by_one)(data)
    stream = gen_ds.batch_iter(batch_size=3)
    compare_arrays(next(stream), get_dataset_item(data, np.arange(3)))
def test_BatchGenerator(data):
    """BatchGenerator wraps a per-batch generator function."""
    def gen(data, batch_size):
        n_rows = data["targets"].shape[0]
        for b in range(int(np.ceil(n_rows / batch_size))):
            lo = b * batch_size
            hi = min((b + 1) * batch_size, n_rows)
            yield get_dataset_item(data, np.arange(lo, hi))

    ds = BatchGenerator.from_fn(gen)(data, 3)
    compare_arrays(ds.load_all(), data)

    # generators are single-shot: rebuild for each use
    ds = BatchGenerator.from_fn(gen)(data, 3)
    batches = ds.batch_iter()
    compare_arrays(next(batches), get_dataset_item(data, np.arange(3)))

    ds = BatchGenerator.from_fn(gen)(data, 2)
    train_it = ds.batch_train_iter()
    for _ in range(6):
        x, y = next(train_it)
        compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
        compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
def test_Dataset(data):
    """Dataset: load_all and batch_iter round-trip the data."""

    class _RowDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return self.data["targets"].shape[0]

        def __getitem__(self, row):
            return get_dataset_item(self.data, row)

    dataset = _RowDataset(data)
    compare_arrays(dataset.load_all(), data)
    stream = dataset.batch_iter(3)
    compare_arrays(next(stream), get_dataset_item(data, np.arange(3)))
def chip_exo_nexus(dataspec,
                   peak_width=200,
                   shuffle=True,
                   preprocessor=AppendTotalCounts(),
                   interval_augm=lambda x: x,
                   valid_chr=valid_chr,
                   test_chr=test_chr):
    """General dataloading function for ChIP-exo or ChIP-nexus data

    Args:
        dataspec: basepair.schemas.DataSpec object containing information
            about the bigwigs, fasta_file and peaks
        peak_width: final width of the interval to extract
        shuffle: if true, the order of the peaks will get shuffled
            (deterministically, with seed 42)
        preprocessor: preprocessor object - needs to implement .fit() and
            .transform() methods (applied to the counts only)
        interval_augm: interval augmentor applied to each interval before
            resizing.
        valid_chr: list of chromosomes in the validation split
        test_chr: list of chromosomes in the test split

    Returns:
        (train, valid, test) tuple where each split is a list of:
            - x: one-hot encoded sequence, sample shape: (peak_width, 4)
            - y: dictionary of count profiles keyed "profile/{task}"
            - metadata: pandas dataframe storing the original intervals
        `train` additionally carries the fitted preprocessor when one
        was supplied.
    """
    # NOTE(review): the default `preprocessor=AppendTotalCounts()` is a
    # single instance created at def-time and shared across calls —
    # confirm that fit/transform do not retain state between calls.

    # valid and test chromosomes must be disjoint
    for v in valid_chr:
        assert v not in test_chr

    def set_attrs_name(interval, name):
        """Add a name to the interval
        """
        interval.attrs['name'] = name
        return interval

    # Load intervals for all tasks.
    # remember the task name in interval.name
    def get_bt(peaks):
        # Missing peak file -> contribute no intervals for that task.
        if peaks is None:
            return []
        else:
            return BedTool(peaks)

    # Resize and skip intervals outside of the genome
    from pysam import FastaFile
    fa = FastaFile(dataspec.fasta_file)

    # intervals = len(get_bt(peaks))
    # n_int = len(intervals)
    intervals = [
        set_attrs_name(resize_interval(interval_augm(interval), peak_width),
                       task)
        for task, ds in dataspec.task_specs.items()
        for i, interval in enumerate(get_bt(ds.peaks))
        if keep_interval(interval, peak_width, fa)
    ]
    # if len(intervals) != n_int:
    #     logger.warn(f"Skipped {n_int - len(intervals)} intervals"
    #                 " outside of the genome size")

    if shuffle:
        # fixed seed -> reproducible shuffling across runs
        Random(42).shuffle(intervals)

    # Setup metadata
    dfm = pd.DataFrame(
        dict(id=np.arange(len(intervals)),
             chr=[x.chrom for x in intervals],
             start=[x.start for x in intervals],
             end=[x.stop for x in intervals],
             task=[x.attrs['name'] for x in intervals]))

    logger.info("extract sequence")
    seq = FastaExtractor(dataspec.fasta_file)(intervals)

    logger.info("extract counts")
    cuts = {
        f"profile/{task}": spec.load_counts(intervals)
        for task, spec in tqdm(dataspec.task_specs.items())
    }
    # # sum across the sequence
    # for task in dataspec.task_specs:
    #     cuts[f"counts/{task}"] = cuts[f"profile/{task}"].sum(axis=1)

    # sequences, counts and metadata must stay aligned
    assert len(seq) == len(dfm)
    assert len(seq) == len(cuts[list(cuts.keys())[0]])

    # Split by chromosomes
    is_test = dfm.chr.isin(test_chr)
    is_valid = dfm.chr.isin(valid_chr)
    is_train = (~is_test) & (~is_valid)

    train = [seq[is_train], get_dataset_item(cuts, is_train), dfm[is_train]]
    valid = [seq[is_valid], get_dataset_item(cuts, is_valid), dfm[is_valid]]
    test = [seq[is_test], get_dataset_item(cuts, is_test), dfm[is_test]]

    if preprocessor is not None:
        # fit on train counts only, then transform every split
        preprocessor.fit(train[1])
        train[1] = preprocessor.transform(train[1])
        valid[1] = preprocessor.transform(valid[1])
        test[1] = preprocessor.transform(test[1])
        # expose the fitted preprocessor alongside the training split
        # NOTE(review): placement inside this if-block reconstructed from
        # statement order — confirm it is not appended when preprocessor
        # is None.
        train.append(preprocessor)

    return (train, valid, test)
def __next__(self):
    """Return the next single sample, or raise StopIteration."""
    if self.idx >= self.data["targets"].shape[0]:
        raise StopIteration
    sample = get_dataset_item(self.data, self.idx)
    self.idx += 1
    return sample
def __getitem__(self, index):
    # Return the `index`-th sample of the underlying data via the
    # shared extraction helper.
    return get_dataset_item(self.data, index)
def __getitem__(self, idx):
    """Return batch `idx` as a slice of the underlying data."""
    lo = idx * self.batch_size
    # clip the last batch at the dataset length
    hi = min(lo + self.batch_size, self.data["targets"].shape[0])
    return get_dataset_item(self.data, np.arange(lo, hi))
def generator_fn(data):
    """Yield the dataset one sample at a time."""
    n_rows = data["targets"].shape[0]
    row = 0
    while row < n_rows:
        yield get_dataset_item(data, row)
        row += 1
def generator_fn(data, batch_size):
    """Yield the dataset in consecutive batches (last one may be short)."""
    n_rows = data["targets"].shape[0]
    lo = 0
    while lo < n_rows:
        hi = min(lo + batch_size, n_rows)
        yield get_dataset_item(data, np.arange(lo, hi))
        lo = hi