Example #1
def test_BatchDataset(data):
    # BatchDataset example:
    class MyBatchDataset(BatchDataset):

        def __init__(self, data, batch_size=3):
            self.data = data
            self.batch_size = batch_size

        def __len__(self):
            return int(np.ceil(self.data["targets"].shape[0] / self.batch_size))

        def __getitem__(self, idx):
            start = idx * self.batch_size
            end = min((idx + 1) * self.batch_size, self.data["targets"].shape[0])
            return get_dataset_item(self.data, np.arange(start, end))

    # ------------------------
    d = MyBatchDataset(data)

    compare_arrays(d.load_all(), data)
    it = d.batch_iter()
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))

    # batch_train_iter
    d = MyBatchDataset(data, batch_size=2)
    it = d.batch_train_iter()
    for i in range(6):
        x, y = next(it)
    compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
    compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
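
All of the test examples on this page rely on a pytest fixture named `data` plus small comparison helpers (`compare_arrays`, `compare_arrays_x`, `compare_arrays_y`) that the snippets do not show. A minimal sketch of what such a fixture could look like, purely for orientation (the real fixture ships with each project's test suite):

import numpy as np
import pytest

@pytest.fixture
def data():
    # nested dict of arrays sharing the same leading axis - the axis
    # that get_dataset_item indexes into
    return {"inputs": np.arange(12).reshape(6, 2),
            "targets": np.arange(6)}

def compare_arrays(a, b):
    # key-wise equality check for two dicts of numpy arrays
    for key in b:
        assert np.array_equal(np.asarray(a[key]), np.asarray(b[key]))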
Example #2
    def predict_all(self,
                    seq,
                    imp_method='grad',
                    batch_size=512,
                    pred_summaries=['weighted', 'count']):
        """Make model prediction based
        """
        if self.bias_model is not None:
            seq, = self.bias_model.predict((seq, ), batch_size)

        preds = self.predict(seq, batch_size=batch_size)

        if imp_method is not None:
            imp_scores = self.imp_score_all(seq,
                                            method=imp_method,
                                            aggregate_strand=True,
                                            batch_size=batch_size,
                                            pred_summaries=pred_summaries)
        else:
            imp_scores = dict()

        out = [
            dict(
                seq=get_dataset_item(seq, i),
                # interval=intervals[i],
                pred=get_dataset_item(preds, i),
                # TODO - shall we call it hyp_imp score or imp_score?
                imp_score=get_dataset_item(imp_scores, i),
            ) for i in range(len(seq))
        ]
        return out
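
A hypothetical call of `predict_all` might look as follows; `model` and `seq` are illustrative names, not part of the snippet:

# hypothetical usage sketch
outs = model.predict_all(seq, imp_method='grad', batch_size=256)
sample = outs[0]
sample['seq']        # (bias-corrected) input sequence of sample 0
sample['pred']       # model prediction for sample 0
sample['imp_score']  # importance scores; empty when imp_method is None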
Example #3
def test_Dataset(data):
    # Dataset example:
    class MyDataset(Dataset):

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return self.data["targets"].shape[0]

        def __getitem__(self, idx):
            return get_dataset_item(self.data, idx)

    # ------------------------

    d = MyDataset(data)

    compare_arrays(d.load_all(), data)
    it = d.batch_iter(3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))

    # test batch_train_iter
    it = d.batch_train_iter(batch_size=2)
    for i in range(6):
        x, y = next(it)
    compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
    compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
Example #4
    def score(self, input_batch):
        """
        Args:
          input_batch: Input batch that should be scored.

        Returns:
          A list with one element per sample in the input batch. Every
             element is a nested list of depth D if the model input is
             D-dimensional with identical shape. Each entry of that nested
             list holds a list of len(`scores`) scores computed on the
             model output selected by `output_sel_fn`; entries are `None`
             if the input_batch already had a `1` at that position.
        """

        ref = self.model.predict_on_batch(input_batch)
        scores = []
        for sample_i in range(
                len(get_model_input(input_batch, self.model_input))):

            # get the full set of model inputs for the selected sample
            sample_set = get_dataset_item(input_batch, sample_i)

            # get the reference output for this sample
            ref_sample_pred = get_dataset_item(ref, sample_i)

            # Apply the output selection function if defined
            if self.output_sel_fn is not None:
                ref_sample_pred = self.output_sel_fn(ref_sample_pred)

            # get the one-hot encoded reference input array
            input_sample = get_model_input(sample_set,
                                           input_id=self.model_input)

            # array in which we keep the scores - each filled entry is a
            # list (one per scoring method) of ndarrays, lists or
            # dictionaries, mirroring whatever the model returns
            score = np.empty(input_sample.shape, dtype=object)
            score[:] = None
            for alt_batch, alt_idxs in self._mutate_sample_batched(
                    input_sample):
                num_samples = len(alt_batch)
                mult_set = numpy_collate([sample_set] * num_samples)
                mult_set = set_model_input(mult_set,
                                           numpy_collate(alt_batch),
                                           input_id=self.model_input)
                alt = self.model.predict_on_batch(mult_set)
                for alt_sample_i in range(num_samples):
                    alt_sample = get_dataset_item(alt, alt_sample_i)
                    # Apply the output selection function if defined
                    if self.output_sel_fn is not None:
                        alt_sample = self.output_sel_fn(alt_sample)
                    # Apply scores across all model outputs for ref and alt
                    output_scores = [
                        apply_within(ref_sample_pred, alt_sample, scr)
                        for scr in self.scores
                    ]
                    score[alt_idxs[alt_sample_i]] = output_scores
            scores.append(score.tolist())

        return scores
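
`self.scores` is expected to hold callables that compare the reference prediction with the prediction for the mutated input; `apply_within` then broadcasts them over the (possibly nested) model output. Two illustrative scoring functions, assuming `apply_within` calls them as `scr(ref, alt)` on matching leaves:

import numpy as np

def diff(ref, alt):
    # signed change of the model output caused by the mutation
    return alt - ref

def log_ratio(ref, alt):
    # log fold-change, stabilized for near-zero outputs
    return np.log(alt + 1e-8) - np.log(ref + 1e-8)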
Example #5
def test_SampleIterator(data):
    # SampleIterator example:
    class MySampleIterator(SampleIterator):
        def __init__(self, data):
            self.data = data
            self.idx = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self.idx >= self.data["targets"].shape[0]:
                raise StopIteration
            ret = get_dataset_item(self.data, self.idx)
            self.idx += 1
            return ret
        next = __next__
    # ------------------------

    d = MySampleIterator(data)

    compare_arrays(d.load_all(), data)
    d = MySampleIterator(data)
    it = d.batch_iter(batch_size=3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))
Example #6
def test_BatchIterator(data):
    # BatchIterator example:
    class MyBatchIterator(BatchIterator):
        def __init__(self, data, batch_size):
            self.data = data
            self.batch_size = batch_size
            self.idx = 0

        def __iter__(self):
            return self

        def __next__(self):
            idx = self.idx
            start = idx * self.batch_size
            if start >= self.data["targets"].shape[0]:
                raise StopIteration
            end = min((idx + 1) * self.batch_size, self.data["targets"].shape[0])
            self.idx += 1
            return get_dataset_item(self.data, np.arange(start, end))
        next = __next__
    # ------------------------

    d = MyBatchIterator(data, 3)

    compare_arrays(d.load_all(), data)
    d = MyBatchIterator(data, 3)
    it = d.batch_iter()
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))
Example #7
 def __next__(self):
     idx = self.idx
     start = idx * self.batch_size
     if start >= self.data["targets"].shape[0]:
         raise StopIteration
     end = min((idx + 1) * self.batch_size, self.data["targets"].shape[0])
     self.idx += 1
     return get_dataset_item(self.data, np.arange(start, end))
Example #8
def test_PreloadedDataset(data):
    # PreloadedDataset example:
    def data_fn():
        return data

    # ------------------------

    d = PreloadedDataset.from_fn(data_fn)()

    compare_arrays(d.load_all(), data)
    it = d.batch_iter(3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))

    # test batch_train_iter
    it = d.batch_train_iter(batch_size=2)
    for i in range(6):
        x, y = next(it)
    compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
    compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
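
`from_fn` acts as a factory: it wraps a plain function into a dataloader class, and calling the returned class (with the function's own arguments) instantiates it. Sketched out, following the call pattern visible above:

DataLoaderCls = PreloadedDataset.from_fn(data_fn)  # function -> dataloader class
d = DataLoaderCls()                                # instantiate; args of data_fn go here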
Example #9
def test_get_item(data):
    dlen = get_dataset_lens(data)[0]
    assert dlen == 3
    assert len(set(get_dataset_lens(data))) == 1
    assert get_dataset_item(data, 1) == {
        "a": [1],
        "b": {
            "d": 1
        },
        "c": np.array([1])
    }
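
The assertion pins down the fixture this particular test uses: a nested structure whose leaves all have length 3, so that indexing position 1 yields the dict shown. A hypothetical reconstruction:

import numpy as np

# fixture implied by the assertion above (hypothetical reconstruction)
data = {
    "a": [[0], [1], [2]],             # list of lists -> item 1 is [1]
    "b": {"d": np.arange(3)},         # nested dict   -> item 1 is 1
    "c": np.arange(3).reshape(3, 1),  # 2-d array     -> item 1 is np.array([1])
}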
Example #10
def test_PreloadedDataset(data):
    # PreloadedDataset example:
    def data_fn():
        return data
    # ------------------------

    d = PreloadedDataset.from_fn(data_fn)()

    compare_arrays(d.load_all(), data)
    it = d.batch_iter(3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))
Example #11
def nested_numpy_minibatch(data, batch_size=1):
    lens = get_dataset_lens(data)
    if isinstance(lens, collections.abc.Mapping):
        ln = next(iter(lens.values()))
    elif isinstance(lens, collections.abc.Sequence):
        ln = lens[0]
    else:
        ln = lens

    for idx in BatchSampler(range(ln),
                            batch_size=batch_size,
                            drop_last=False):
        yield get_dataset_item(data, idx)
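
Assuming `BatchSampler` is torch.utils.data.BatchSampler (its import is not shown in the snippet), the helper yields fixed-size slices of any nested container, with a shorter final batch:

# hypothetical usage
import numpy as np

data = {"x": np.arange(10).reshape(5, 2), "y": np.arange(5)}
for batch in nested_numpy_minibatch(data, batch_size=2):
    print(batch["x"].shape, batch["y"].shape)
# -> (2, 2) (2,)  /  (2, 2) (2,)  /  (1, 2) (1,)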
Example #12
def test_SampleGenerator(data):
    # SampleGenerator example:
    def generator_fn(data):
        for idx in range(data["targets"].shape[0]):
            yield get_dataset_item(data, idx)

    # ------------------------

    d = SampleGenerator.from_fn(generator_fn)(data)

    compare_arrays(d.load_all(), data)
    d = SampleGenerator.from_fn(generator_fn)(data)

    it = d.batch_iter(batch_size=3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))

    d = SampleGenerator.from_fn(generator_fn)(data)
    it = d.batch_train_iter(batch_size=2)
    for i in range(6):
        x, y = next(it)
    compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
    compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
Example #13
def test_SampleGenerator(data):
    # SampleGenerator example:
    def generator_fn(data):
        for idx in range(data["targets"].shape[0]):
            yield get_dataset_item(data, idx)
    # ------------------------

    d = SampleGenerator.from_fn(generator_fn)(data)

    compare_arrays(d.load_all(), data)
    d = SampleGenerator.from_fn(generator_fn)(data)

    it = d.batch_iter(batch_size=3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))
Example #14
def test_BatchGenerator(data):
    # BatchGenerator example:
    def generator_fn(data, batch_size):
        for idx in range(int(np.ceil(data["targets"].shape[0] / batch_size))):
            start = idx * batch_size
            end = min((idx + 1) * batch_size, data["targets"].shape[0])
            yield get_dataset_item(data, np.arange(start, end))

    # ------------------------

    d = BatchGenerator.from_fn(generator_fn)(data, 3)

    compare_arrays(d.load_all(), data)
    d = BatchGenerator.from_fn(generator_fn)(data, 3)

    it = d.batch_iter()
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))

    d = BatchGenerator.from_fn(generator_fn)(data, 2)
    it = d.batch_train_iter()
    for i in range(6):
        x, y = next(it)
    compare_arrays_x(x, get_dataset_item(data, np.arange(2))['inputs'])
    compare_arrays_y(y, get_dataset_item(data, np.arange(2))['targets'])
Example #15
def test_Dataset(data):
    # Dataset example:
    class MyDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return self.data["targets"].shape[0]

        def __getitem__(self, idx):
            return get_dataset_item(self.data, idx)
    # ------------------------

    d = MyDataset(data)

    compare_arrays(d.load_all(), data)
    it = d.batch_iter(3)
    compare_arrays(next(it), get_dataset_item(data, np.arange(3)))
Example #16
def chip_exo_nexus(dataspec,
                   peak_width=200,
                   shuffle=True,
                   preprocessor=AppendTotalCounts(),
                   interval_augm=lambda x: x,
                   valid_chr=valid_chr,
                   test_chr=test_chr):
    """
    General dataloading function for ChIP-exo or ChIP-nexus data

    Args:
      dataspec: basepair.schemas.DataSpec object containing information about
        the bigwigs, the fasta_file and the peaks of every task
      peak_width: final width of the interval to extract
      shuffle: if True, the order of the peaks gets shuffled
      preprocessor: preprocessor object - needs to implement .fit() and .transform() methods
      interval_augm: interval augmentation function applied to every interval
      valid_chr: list of chromosomes in the validation split
      test_chr: list of chromosomes in the test split

    Returns:
      (train, valid, test) tuple where train consists of:
        - x: one-hot encoded sequence, sample shape: (peak_width, 4)
        - y: dictionary containing fields:
          profile/{task_id}: sample shape - (peak_width, 2), count profile
          counts/{task_id}: sample shape - (2, ), total number of counts per strand
        - metadata: pandas dataframe storing the original intervals

    """
    for v in valid_chr:
        assert v not in test_chr

    def set_attrs_name(interval, name):
        """Add a name to the interval
        """
        interval.attrs['name'] = name
        return interval

    # Load intervals for all tasks,
    #   remembering the task name in interval.attrs['name']
    def get_bt(peaks):
        if peaks is None:
            return []
        else:
            return BedTool(peaks)

    # Resize the intervals and skip those that fall outside of the genome
    from pysam import FastaFile
    fa = FastaFile(dataspec.fasta_file)
    #     intervals = len(get_bt(peaks))
    #     n_int = len(intervals)

    intervals = [
        set_attrs_name(resize_interval(interval_augm(interval), peak_width),
                       task) for task, ds in dataspec.task_specs.items()
        for i, interval in enumerate(get_bt(ds.peaks))
        if keep_interval(interval, peak_width, fa)
    ]
    #     if len(intervals) != n_int:
    #         logger.warn(f"Skipped {n_int - len(intervals)} intervals"
    #                     " outside of the genome size")

    if shuffle:
        Random(42).shuffle(intervals)

    # Setup metadata
    dfm = pd.DataFrame(
        dict(id=np.arange(len(intervals)),
             chr=[x.chrom for x in intervals],
             start=[x.start for x in intervals],
             end=[x.stop for x in intervals],
             task=[x.attrs['name'] for x in intervals]))

    logger.info("extract sequence")
    seq = FastaExtractor(dataspec.fasta_file)(intervals)

    logger.info("extract counts")
    cuts = {
        f"profile/{task}": spec.load_counts(intervals)
        for task, spec in tqdm(dataspec.task_specs.items())
    }
    # # sum across the sequence
    # for task in dataspec.task_specs:
    #     cuts[f"counts/{task}"] = cuts[f"profile/{task}"].sum(axis=1)
    assert len(seq) == len(dfm)
    assert len(seq) == len(cuts[list(cuts.keys())[0]])

    # Split by chromosomes
    is_test = dfm.chr.isin(test_chr)
    is_valid = dfm.chr.isin(valid_chr)
    is_train = (~is_test) & (~is_valid)

    train = [seq[is_train], get_dataset_item(cuts, is_train), dfm[is_train]]
    valid = [seq[is_valid], get_dataset_item(cuts, is_valid), dfm[is_valid]]
    test = [seq[is_test], get_dataset_item(cuts, is_test), dfm[is_test]]

    if preprocessor is not None:
        preprocessor.fit(train[1])
        train[1] = preprocessor.transform(train[1])
        valid[1] = preprocessor.transform(valid[1])
        test[1] = preprocessor.transform(test[1])

    train.append(preprocessor)
    return (train, valid, test)
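
A sketch of how the returned tuple unpacks; note that `train` additionally carries the (possibly None) preprocessor appended at the end:

# hypothetical usage sketch - dataspec would be a basepair.schemas.DataSpec
train, valid, test = chip_exo_nexus(dataspec, peak_width=200)
x_train, y_train, meta_train, fitted_preproc = train  # preprocessor appended
x_valid, y_valid, meta_valid = valid
# x_*: one-hot sequences, shape (n, peak_width, 4)
# y_*: dict of "profile/{task}" count arrays (post-preprocessing)
# meta_*: pandas DataFrame with the original interval coordinates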
Example #17
 def __next__(self):
     if self.idx >= self.data["targets"].shape[0]:
         raise StopIteration
     ret = get_dataset_item(self.data, self.idx)
     self.idx += 1
     return ret
Example #18
 def __getitem__(self, index):
     return get_dataset_item(self.data, index)
Example #19
 def __getitem__(self, idx):
     start = idx * self.batch_size
     end = min((idx + 1) * self.batch_size, self.data["targets"].shape[0])
     return get_dataset_item(self.data, np.arange(start, end))
Example #20
 def generator_fn(data):
     for idx in range(data["targets"].shape[0]):
         yield get_dataset_item(data, idx)
Example #21
 def generator_fn(data, batch_size):
     for idx in range(int(np.ceil(data["targets"].shape[0] / batch_size))):
         start = idx * batch_size
         end = min((idx + 1) * batch_size, data["targets"].shape[0])
         yield get_dataset_item(data, np.arange(start, end))
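
The index arithmetic shared by these batched variants splits n samples into ceil(n / batch_size) batches, the last one possibly shorter. Worked through for n = 5, batch_size = 2:

import numpy as np

n, batch_size = 5, 2
for idx in range(int(np.ceil(n / batch_size))):
    start = idx * batch_size
    end = min((idx + 1) * batch_size, n)
    print(idx, np.arange(start, end))
# 0 [0 1]
# 1 [2 3]
# 2 [4]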