Пример #1
0
 def __init__(self, zarr_file, inputs_fields, batch_size=64):
     """Open ``zarr_file`` and expose the requested fields as one batched dataset.

     Args:
       zarr_file: path handed to ``ZarrReader``.
       inputs_fields: iterable of field paths inside the Zarr root, e.g.
         ``'preds/logit_diff'``. A single path string is also accepted and
         treated as a one-element list.
       batch_size: batch size forwarded to every ``MatrixDataset``.
     """
     if isinstance(inputs_fields, str):
         # Iterating a bare string would look up one "field" per character.
         inputs_fields = [inputs_fields]
     self.zarr_file = zarr_file
     self.f = ZarrReader(self.zarr_file)
     self.f.open()
     try:
         self.batch_dataset = HStackBatchDataset([
             MatrixDataset(self.f.root[field], batch_size=batch_size)
             for field in inputs_fields
         ])
     except Exception:
         # Don't leak the reader opened above when a field lookup fails.
         # NOTE(review): assumes exiting ZarrReader's context manager undoes
         # open() — confirm against ZarrReader.
         self.f.__exit__(None, None, None)
         raise
Пример #2
0
def test_ZarrBatchWriter_list(dl_batch, pred_batch_list, tmpdir):
    """Round-trip a batch with list-valued predictions through a Zarr file.

    The same batch is written twice (``chunk_size=4``) and read back. Checks:
    the first batch from ``batch_iter(2)``, the metadata columns (expected
    twice, concatenated), and the leading rows of the first prediction output.
    """
    tmpfile = str(tmpdir.mkdir("example").join("out.zip.zarr"))
    batch = prepare_batch(dl_batch, pred_batch_list)
    writer = ZarrBatchWriter(tmpfile, chunk_size=4)
    try:
        writer.batch_write(batch)
        writer.batch_write(batch)
    finally:
        # Close even when a write fails so the output store is not left open.
        writer.close()

    meta = dl_batch['metadata']
    ranges = meta['ranges']

    def doubled(arr):
        # The batch was written twice, so every column is expected twice.
        return np.concatenate([arr, arr])

    with ZarrReader(tmpfile) as f:
        # Only the first batch is checked; don't materialize all of them.
        first = next(iter(f.batch_iter(2)))
        np.testing.assert_array_equal(first['metadata']['gene_id'],
                                      meta['gene_id'][:2])

        out = f.load_all()
        np.testing.assert_array_equal(out['metadata']['gene_id'],
                                      doubled(meta['gene_id']))
        np.testing.assert_array_equal(out['metadata']['ranges']['chr'],
                                      doubled(ranges['chr']))
        np.testing.assert_array_equal(out['metadata']['ranges']['start'],
                                      doubled(ranges['start']))

        # Leading rows of the first prediction output come from the first write.
        first_preds = pred_batch_list[0]
        np.testing.assert_array_equal(out['preds'][0][:len(first_preds)],
                                      first_preds)
Пример #3
0
class ZarrBatchDataset(BatchDataset):
    """Batched view over one or more fields of a Zarr file.

    Each requested field is wrapped in a ``MatrixDataset``. The wrapped fields
    are combined with ``HStackBatchDataset``, and ``len``/indexing delegate to
    that combined dataset. The underlying ``ZarrReader`` stays open for the
    lifetime of the object. Call ``close()``, or use the instance in a
    ``with`` block, to release it.
    """

    def __init__(self, zarr_file, inputs_fields, batch_size=64):
        """Open ``zarr_file`` and build the combined dataset.

        Args:
          zarr_file: path handed to ``ZarrReader``.
          inputs_fields: iterable of field paths inside the Zarr root, e.g.
            ``'preds/logit_diff'``. A single path string is also accepted and
            treated as a one-element list.
          batch_size: batch size forwarded to every ``MatrixDataset``.
        """
        if isinstance(inputs_fields, str):
            # Iterating a bare string would look up one "field" per character.
            inputs_fields = [inputs_fields]
        self.zarr_file = zarr_file
        self.f = ZarrReader(self.zarr_file)
        self.f.open()
        try:
            self.batch_dataset = HStackBatchDataset([
                MatrixDataset(self.f.root[field], batch_size=batch_size)
                for field in inputs_fields
            ])
        except Exception:
            # Don't leak the reader opened above when a field lookup fails.
            self.close()
            raise

    def __len__(self):
        """Return the length reported by the combined dataset."""
        return len(self.batch_dataset)

    def __getitem__(self, idx):
        """Return element ``idx`` of the combined dataset."""
        return self.batch_dataset[idx]

    def close(self):
        """Release the underlying ``ZarrReader``; safe to call more than once."""
        reader, self.f = self.f, None
        if reader is not None:
            # NOTE(review): ZarrReader is used as a context manager
            # (``with ZarrReader(path) as f``); exiting that context is
            # assumed to undo open() — confirm against ZarrReader.
            reader.__exit__(None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning None lets any in-flight exception propagate.