Example #1
    def setup(self):

        f = h5py.File(self.filename, 'r')

        self.spec = ProviderSpec()
        self.ndims = None
        for (volume_type, ds) in self.datasets.items():

            if ds not in f:
                raise RuntimeError("%s not in %s" % (ds, self.filename))

            dims = f[ds].shape
            self.spec.volumes[volume_type] = Roi((0, ) * len(dims), dims)

            if self.ndims is None:
                self.ndims = len(dims)
            else:
                assert self.ndims == len(dims)

            if self.specified_resolution is None:
                if 'resolution' in f[ds].attrs:
                    self.resolutions[volume_type] = tuple(
                        f[ds].attrs['resolution'])
                else:
                    default_resolution = (1, ) * self.ndims
                    logger.warning(
                        "WARNING: your source does not contain resolution information"
                        " (no attribute 'resolution' in {} dataset). I will assume {}. "
                        "This might not be what you want.".format(
                            ds, default_resolution))
                    self.resolutions[volume_type] = default_resolution
            else:
                self.resolutions[volume_type] = self.specified_resolution

        f.close()
Example #2
    def create_output_file(self):
        os.makedirs(self.output_dir, exist_ok=True)
        _file = h5py.File(os.path.join(self.output_dir, self.output_filename),
                          'a')
        for (array_key, dataset_name) in self.dataset_names.items():

            logger.debug("Create dataset for %s", array_key)

            total_roi = self.spec[array_key].roi
            dims = total_roi.dims()

            # extents of spatial dimensions
            data_shape = (total_roi.get_shape() //
                          self.spec[array_key].voxel_size)
            logger.debug("Shape in voxels: %s", data_shape)
            # add channel dimensions (HACK: Unsure how to get channels)
            data_shape = Coordinate([3])[:] + data_shape
            logger.debug("Shape with channel dimensions: %s", data_shape)

            if array_key in self.dataset_dtypes:
                dtype = self.dataset_dtypes[array_key]
            else:
                dtype = self.spec[array_key].dtype
            dataset = _file.create_dataset(name=dataset_name,
                                           shape=data_shape,
                                           compression=self.compression_type,
                                           dtype=dtype)
            self.dataset_shapes[dataset_name] = data_shape
            dataset.attrs['offset'] = total_roi.get_offset()
            dataset.attrs['resolution'] = self.spec[array_key].voxel_size
        _file.close()
Example #3
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as f:

            for (volume_type, roi) in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                if not spec.volumes[volume_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, spec.volumes[volume_type]))

                logger.debug("Reading %s in %s...", volume_type, roi)

                # shift request roi into dataset
                dataset_roi = roi.shift(-spec.volumes[volume_type].get_offset())

                batch.volumes[volume_type] = Volume(
                        self.__read(f, self.datasets[volume_type], dataset_roi),
                        roi=roi,
                        resolution=self.resolutions[volume_type])

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #4
    def test_output_2d(self):
        path = self.path_to('test_hdf_source.hdf')

        # create a test file
        with h5py.File(path, 'w') as f:
            f['raw'] = np.zeros((100, 100), dtype=np.float32)
            f['raw_low'] = np.zeros((10, 10), dtype=np.float32)
            f['raw_low'].attrs['resolution'] = (10, 10)
            f['seg'] = np.ones((100, 100), dtype=np.uint64)

        # read arrays
        raw = ArrayKey('RAW')
        raw_low = ArrayKey('RAW_LOW')
        seg = ArrayKey('SEG')
        source = Hdf5Source(path, {raw: 'raw', raw_low: 'raw_low', seg: 'seg'})

        with build(source):

            batch = source.request_batch(
                BatchRequest({
                    raw: ArraySpec(roi=Roi((0, 0), (100, 100))),
                    raw_low: ArraySpec(roi=Roi((0, 0), (100, 100))),
                    seg: ArraySpec(roi=Roi((0, 0), (100, 100))),
                }))

            self.assertTrue(batch.arrays[raw].spec.interpolatable)
            self.assertTrue(batch.arrays[raw_low].spec.interpolatable)
            self.assertFalse(batch.arrays[seg].spec.interpolatable)
Example #5
    def process(self, batch, request):

        if self.record_snapshot:

            os.makedirs(self.output_dir, exist_ok=True)

            snapshot_name = os.path.join(
                self.output_dir,
                self.output_filename.format(
                    id=str(batch.id).zfill(8),
                    iteration=int(batch.iteration or 0)))
            logger.info('saving to %s', snapshot_name)
            with h5py.File(snapshot_name, 'w') as f:

                for (array_key, array) in batch.arrays.items():

                    if array_key not in self.dataset_names:
                        continue

                    ds_name = self.dataset_names[array_key]

                    if array_key in self.dataset_dtypes:
                        dtype = self.dataset_dtypes[array_key]
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data.astype(dtype),
                            compression=self.compression_type)
                    else:
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data,
                            compression=self.compression_type)

                    if array.spec.roi is not None:
                        dataset.attrs['offset'] = tuple(
                            float(o) for o in array.spec.roi.get_offset())
                    dataset.attrs['resolution'] = tuple(
                        float(vs) for vs in self.spec[array_key].voxel_size)

                    # if array has attributes, add them to the dataset
                    for attribute_name, attribute in array.attrs.items():
                        dataset.attrs[attribute_name] = attribute

                    logger.debug('Getting additional attributes for %s',
                                 array_key)
                    for (attribute_name,
                         value) in self.attributes_callback(array_key,
                                                            array).items():
                        dataset.attrs[attribute_name] = value

                if batch.loss is not None:
                    f['/'].attrs['loss'] = batch.loss

        self.n += 1
Example #6
    def process(self, batch, request):

        if self.record_snapshot:

            os.makedirs(self.output_dir, exist_ok=True)

            snapshot_name = os.path.join(
                self.output_dir,
                self.output_filename.format(
                    id=str(batch.id).zfill(8),
                    iteration=int(batch.iteration or 0)))
            logger.info('saving to %s', snapshot_name)
            with h5py.File(snapshot_name, 'w') as f:

                for (array_key, array) in batch.arrays.items():

                    if array_key not in self.dataset_names:
                        continue

                    ds_name = self.dataset_names[array_key]

                    offset = array.spec.roi.get_offset()
                    if array_key in self.dataset_dtypes:
                        dtype = self.dataset_dtypes[array_key]
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data.astype(dtype),
                            compression=self.compression_type)
                    else:
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data,
                            compression=self.compression_type)

                    dataset.attrs['offset'] = offset
                    dataset.attrs['resolution'] = self.spec[
                        array_key].voxel_size

                    if self.store_value_range:
                        dataset.attrs['value_range'] = (
                            array.data.min().item(), array.data.max().item())

                    # if array has attributes, add them to the dataset
                    for attribute_name, attribute in array.attrs.items():
                        dataset.attrs[attribute_name] = attribute

                if batch.loss is not None:
                    f['/'].attrs['loss'] = batch.loss

        self.n += 1
Example #7
    def provide(self, request):

        timing_process = Timing(self)
        timing_process.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre- and postsynaptic locations are requested, their
            # id -> SynapseLocation dictionaries should be created together so
            # that ids are unique and partner locations can be found

            if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
                assert self.kind == 'synapse'
                # If only PRESYN or POSTSYN requested, assume PRESYN ROI = POSTSYN ROI.
                pre_key = PointsKeys.PRESYN if PointsKeys.PRESYN in request.points_specs else PointsKeys.POSTSYN
                post_key = PointsKeys.POSTSYN if PointsKeys.POSTSYN in request.points_specs else PointsKeys.PRESYN
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[pre_key].roi,
                    post_roi=request.points_specs[post_key].roi,
                    syn_file=hdf_file)
                points = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }
            else:
                assert self.kind == 'presyn' or self.kind == 'postsyn'
                synkey = list(self.datasets.items())[0][0]  # only key of the dict
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[synkey].roi,
                    post_roi=request.points_specs[synkey].roi,
                    syn_file=hdf_file)
                points = {
                    synkey:
                    presyn_points if self.kind == 'presyn' else postsyn_points
                }

            for (points_key, request_spec) in request.points_specs.items():
                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                logger.debug("Number of points: %d",
                             len(points[points_key]))
                batch.points[points_key] = Points(data=points[points_key],
                                                  spec=points_spec)

        timing_process.stop()
        batch.profiling_stats.add(timing_process)

        return batch
Example #8
    def setup(self):

        hdf_file = h5py.File(self.filename, 'r')

        for (array_key, ds_name) in self.datasets.items():

            if ds_name not in hdf_file:
                raise RuntimeError("%s not in %s" % (ds_name, self.filename))

            spec = self.__read_spec(array_key, hdf_file, ds_name)

            self.provides(array_key, spec)

        hdf_file.close()
Example #9
    def setup(self):

        hdf_file = h5py.File(self.filename, 'r')

        for (points_key, ds_name) in self.datasets.items():

            if ds_name not in hdf_file:
                raise RuntimeError("%s not in %s" % (ds_name, self.filename))

            spec = PointsSpec()
            spec.roi = self.rois[points_key]

            self.provides(points_key, spec)

        hdf_file.close()
Example #10
    def create_output_file(self, batch):

        os.makedirs(self.output_dir, exist_ok=True)

        self.file = h5py.File(
            os.path.join(self.output_dir, self.output_filename), 'w')
        self.datasets = {}

        for (array_key, dataset_name) in self.dataset_names.items():

            logger.debug("Create dataset for %s", array_key)

            assert array_key in self.spec, (
                "Asked to store %s, but is not provided upstream." % array_key)
            assert array_key in batch.arrays, (
                "Asked to store %s, but is not part of batch." % array_key)

            batch_shape = batch.arrays[array_key].data.shape

            total_roi = self.spec[array_key].roi
            dims = total_roi.dims()

            # extents of spatial dimensions
            data_shape = (total_roi.get_shape() //
                          self.spec[array_key].voxel_size)
            logger.debug("Shape in voxels: %s", data_shape)
            # add channel dimensions (if present)
            data_shape = batch_shape[:-dims] + data_shape
            logger.debug("Shape with channel dimensions: %s", data_shape)

            if array_key in self.dataset_dtypes:
                dtype = self.dataset_dtypes[array_key]
            else:
                dtype = batch.arrays[array_key].data.dtype

            dataset = self.file.create_dataset(
                name=dataset_name,
                shape=data_shape,
                compression=self.compression_type,
                dtype=dtype)

            dataset.attrs['offset'] = total_roi.get_offset()
            dataset.attrs['resolution'] = self.spec[array_key].voxel_size

            self.datasets[array_key] = dataset
Example #11
    def process(self, batch, request):

        if self.record_snapshot:

            os.makedirs(self.output_dir, exist_ok=True)

            snapshot_name = os.path.join(
                self.output_dir,
                self.output_filename.format(id=str(batch.id).zfill(8),
                                            iteration=batch.iteration))
            logger.info("saving to " + snapshot_name)
            with h5py.File(snapshot_name, 'w') as f:

                for (volume_type, volume) in batch.volumes.items():

                    ds_name = {
                        VolumeTypes.RAW: 'volumes/raw',
                        VolumeTypes.ALPHA_MASK: 'volumes/alpha_mask',
                        VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                        VolumeTypes.GT_AFFINITIES: 'volumes/labels/affs',
                        VolumeTypes.GT_MASK: 'volumes/labels/mask',
                        VolumeTypes.GT_IGNORE: 'volumes/labels/ignore',
                        VolumeTypes.PRED_AFFINITIES: 'volumes/predicted_affs',
                        VolumeTypes.LOSS_SCALE: 'volumes/loss_scale',
                        VolumeTypes.LOSS_GRADIENT: 'volumes/predicted_affs_loss_gradient',
                    }[volume_type]

                    offset = volume.roi.get_offset()
                    offset *= volume.resolution
                    dataset = f.create_dataset(name=ds_name, data=volume.data)
                    dataset.attrs['offset'] = offset
                    dataset.attrs['resolution'] = volume.resolution

                if batch.loss is not None:
                    f['/'].attrs['loss'] = batch.loss
Example #12
    def test_output(self):
        path = self.path_to('hdf5_write_test.hdf')

        source = Hdf5WriteTestSource()

        chunk_request = BatchRequest()
        chunk_request.add(ArrayKeys.RAW, (400, 30, 34))
        chunk_request.add(ArrayKeys.GT_LABELS, (200, 10, 14))

        pipeline = (
            source +
            Hdf5Write({
                ArrayKeys.RAW: 'arrays/raw'
            },
            output_filename=path) +
            Scan(chunk_request))

        with build(pipeline):

            raw_spec    = pipeline.spec[ArrayKeys.RAW]
            labels_spec = pipeline.spec[ArrayKeys.GT_LABELS]

            full_request = BatchRequest({
                    ArrayKeys.RAW: raw_spec,
                    ArrayKeys.GT_LABELS: labels_spec
                }
            )

            batch = pipeline.request_batch(full_request)

        # assert that stored HDF dataset equals batch array

        with h5py.File(path, 'r') as f:

            ds = f['arrays/raw']

            batch_raw = batch.arrays[ArrayKeys.RAW]
            stored_raw = np.array(ds)

            self.assertEqual(
                stored_raw.shape[-3:],
                batch_raw.spec.roi.get_shape()//batch_raw.spec.voxel_size)
            self.assertEqual(tuple(ds.attrs['offset']), batch_raw.spec.roi.get_offset())
            self.assertEqual(tuple(ds.attrs['resolution']), batch_raw.spec.voxel_size)
            self.assertTrue((stored_raw == batch.arrays[ArrayKeys.RAW].data).all())
Example #13
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre- and postsynaptic locations are requested, their
            # id -> SynapseLocation dictionaries should be created together so
            # that ids are unique and partner locations can be found
            if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
                assert request.points_specs[
                    PointsKeys.PRESYN].roi == request.points_specs[
                        PointsKeys.POSTSYN].roi
                # Cremi specific, ROI offset corresponds to offset present in the
                # synapse location relative to the raw data.
                dataset_offset = self.spec[PointsKeys.PRESYN].roi.get_offset()
                presyn_points, postsyn_points = self.__get_syn_points(
                    roi=request.points_specs[PointsKeys.PRESYN].roi,
                    syn_file=hdf_file,
                    dataset_offset=dataset_offset)

            for (points_key, request_spec) in request.points_specs.items():

                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                id_to_point = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }[points_key]

                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                batch.points[points_key] = Points(data=id_to_point,
                                                  spec=points_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #14
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as f:

            for (volume_type, roi) in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                if not spec.volumes[volume_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, spec.volumes[volume_type]))

                interpolate = {
                    VolumeType.RAW: True,
                    VolumeType.GT_LABELS: False,
                    VolumeType.GT_MASK: False,
                    VolumeType.ALPHA_MASK: True,
                }[volume_type]

                logger.debug("Reading %s in %s..." % (volume_type, roi))
                batch.volumes[volume_type] = Volume(
                    self.__read(f, self.datasets[volume_type], roi),
                    roi=roi,
                    resolution=self.resolutions[volume_type],
                    interpolate=interpolate)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #15
def BG_Write_Thread_func(file_name, request_queue, stopEvent):
    logger.info("Thread started with fn %s" % file_name)
    _file = h5py.File(file_name, 'a')

    # keep looping until we are supposed to stop
    while not stopEvent.isSet():
        if not request_queue.empty():
            logger.info("Emptying Queue")
            while not request_queue.empty():
                req = request_queue.get()
                _file[req.dset][req.position] = req.data
    # event was set
    # make sure there was no race condition and no stragglers
    while not request_queue.empty():
        logger.info("A few left in Queue")
        req = request_queue.get()
        _file[req.dset][req.position] = req.data

    _file.close()
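
A minimal usage sketch for the writer thread above; the WriteRequest tuple, the file name 'out.hdf', and the dataset 'volumes/raw' are illustrative assumptions, not part of the example. It pre-creates the target dataset, starts the thread, enqueues one write request, then sets the stop event and joins.

# Usage sketch for BG_Write_Thread_func (assumed names: WriteRequest,
# 'out.hdf', 'volumes/raw')
import queue
import threading
from collections import namedtuple

import h5py
import numpy as np

WriteRequest = namedtuple('WriteRequest', ['dset', 'position', 'data'])

# create the target dataset up front; the writer thread only fills it in
with h5py.File('out.hdf', 'w') as f:
    f.create_dataset('volumes/raw', shape=(100, 100), dtype=np.float32)

request_queue = queue.Queue()
stop_event = threading.Event()
writer = threading.Thread(
    target=BG_Write_Thread_func,
    args=('out.hdf', request_queue, stop_event))
writer.start()

# enqueue one block for rows 0:10, then tell the thread to drain and stop
request_queue.put(WriteRequest(
    dset='volumes/raw',
    position=np.s_[0:10, :],
    data=np.ones((10, 100), dtype=np.float32)))
stop_event.set()
writer.join()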
Example #16
    def setup(self):
        hdf_file = h5py.File(self.filename, 'r')

        for (volume_type, ds_name) in self.datasets.items():

            if ds_name not in hdf_file:
                raise RuntimeError("%s not in %s" % (ds_name, self.filename))

            spec = self.__read_spec(volume_type, hdf_file, ds_name)
            self.provides(volume_type, spec)

        if self.points_types is not None:

            for points_type in self.points_types:
                spec = PointsSpec()
                spec.roi = Roi(self.points_rois[points_type][0],
                               self.points_rois[points_type][1])

                self.provides(points_type, spec)

        hdf_file.close()
Example #17
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            for (array_key, request_spec) in request.array_specs.items():

                logger.debug("Reading %s in %s...", array_key,
                             request_spec.roi)

                voxel_size = self.spec[array_key].voxel_size

                # scale request roi to voxel units
                dataset_roi = request_spec.roi / voxel_size

                # shift request roi into dataset
                dataset_roi = dataset_roi - self.spec[
                    array_key].roi.get_offset() / voxel_size

                # create array spec
                array_spec = self.spec[array_key].copy()
                array_spec.roi = request_spec.roi

                # add array to batch
                batch.arrays[array_key] = Array(
                    self.__read(hdf_file, self.datasets[array_key],
                                dataset_roi), array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #18
    def _open_file(self, filename):
        return h5py.File(filename, 'r')

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as f:

            for (volume_type, roi) in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                if not spec.volumes[volume_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, spec.volumes[volume_type]))

                interpolate = {
                    VolumeTypes.RAW: True,
                    VolumeTypes.GT_LABELS: False,
                    VolumeTypes.GT_MASK: False,
                    VolumeTypes.ALPHA_MASK: True,
                }[volume_type]

                if volume_type in self.volume_phys_offset:
                    offset_shift = np.array(
                        self.volume_phys_offset[volume_type]) / np.array(
                            self.resolutions[volume_type])
                    roi_offset = roi.shift(tuple(-offset_shift))
                else:
                    roi_offset = roi
                logger.debug("Reading %s in %s..." % (volume_type, roi_offset))
                batch.volumes[volume_type] = Volume(
                    self.__read(f, self.datasets[volume_type], roi_offset),
                    roi=roi,
                    resolution=self.resolutions[volume_type])
                # interpolate=interpolate)

            # if pre- and postsynaptic locations are requested, their
            # id -> SynapseLocation dictionaries should be created together so
            # that ids are unique and partner locations can be found
            if PointsTypes.PRESYN in request.points or PointsTypes.POSTSYN in request.points:
                # assert request.points[PointsTypes.PRESYN] == request.points[PointsTypes.POSTSYN]
                # Cremi specific, ROI offset corresponds to offset present in the
                # synapse location relative to the raw data.
                # TODO: Make this generic and in the same style as done for volume_phys_offset.
                dataset_offset = self.get_spec().points[
                    PointsTypes.PRESYN].get_offset()
                presyn_points, postsyn_points = self.__get_syn_points(
                    roi=request.points[PointsTypes.PRESYN],
                    syn_file=f,
                    dataset_offset=dataset_offset)

            for (points_type, roi) in request.points.items():

                if points_type not in spec.points:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        points_type)

                if not spec.points[points_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (points_type, roi, spec.points[points_type]))

                logger.debug("Reading %s in %s..." % (points_type, roi))
                id_to_point = {
                    PointsTypes.PRESYN: presyn_points,
                    PointsTypes.POSTSYN: postsyn_points
                }[points_type]
                # TODO: so far assumed that all points have resolution of raw volume
                batch.points[points_type] = Points(
                    data=id_to_point,
                    roi=roi,
                    resolution=self.resolutions[VolumeTypes.RAW])

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #20
    def _open_file(self, filename):
        if os.path.exists(filename):
            return h5py.File(filename, 'r+')
        else:
            return h5py.File(filename, 'w')
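
For comparison, h5py's append mode gives the same open-or-create behaviour in a single call; a minimal alternative sketch, not taken from the example:

    def _open_file(self, filename):
        # 'a' opens the file read/write if it exists and creates it otherwise,
        # which matches the os.path.exists() check above
        return h5py.File(filename, 'a')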
Example #21
    def process(self, batch, request):

        if self.record_snapshot:

            os.makedirs(self.output_dir, exist_ok=True)

            snapshot_name = os.path.join(
                self.output_dir,
                self.output_filename.format(
                    id=str(batch.id).zfill(8),
                    iteration=int(batch.iteration or 0)),
            )
            logger.info("saving to %s", snapshot_name)
            with h5py.File(snapshot_name, "w") as f:

                for (array_key, array) in batch.arrays.items():

                    if array_key not in self.dataset_names:
                        continue

                    ds_name = self.dataset_names[array_key]

                    if array_key in self.dataset_dtypes:
                        dtype = self.dataset_dtypes[array_key]
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data.astype(dtype),
                            compression=self.compression_type,
                        )

                    else:
                        dataset = f.create_dataset(
                            name=ds_name,
                            data=array.data,
                            compression=self.compression_type,
                        )

                    if not array.spec.nonspatial:
                        if array.spec.roi is not None:
                            dataset.attrs[
                                "offset"] = array.spec.roi.get_offset()
                        dataset.attrs["resolution"] = self.spec[
                            array_key].voxel_size

                    if self.store_value_range:
                        dataset.attrs["value_range"] = (
                            array.data.min().item(),
                            array.data.max().item(),
                        )

                    # if array has attributes, add them to the dataset
                    for attribute_name, attribute in array.attrs.items():
                        dataset.attrs[attribute_name] = attribute

                for (graph_key, graph) in batch.graphs.items():
                    if graph_key not in self.dataset_names:
                        continue

                    ds_name = self.dataset_names[graph_key]

                    node_ids = []
                    locations = []
                    edges = []
                    for node in graph.nodes:
                        node_ids.append(node.id)
                        locations.append(node.location)
                    for edge in graph.edges:
                        edges.append((edge.u, edge.v))

                    f.create_dataset(
                        name=f"{ds_name}-ids",
                        data=np.array(node_ids, dtype=int),
                        compression=self.compression_type,
                    )
                    f.create_dataset(
                        name=f"{ds_name}-locations",
                        data=np.array(locations),
                        compression=self.compression_type,
                    )
                    f.create_dataset(
                        name=f"{ds_name}-edges",
                        data=np.array(edges),
                        compression=self.compression_type,
                    )

                if batch.loss is not None:
                    f["/"].attrs["loss"] = batch.loss

        self.n += 1
Example #22
    def _open_writable_file(self, path):
        return h5py.File(path, 'w')
Example #23
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            for (volume_type, request_spec) in request.volume_specs.items():
                logger.debug("Reading %s in %s...", volume_type,
                             request_spec.roi)

                voxel_size = self.spec[volume_type].voxel_size

                # scale request roi to voxel units
                dataset_roi = request_spec.roi / voxel_size

                # shift request roi into dataset
                dataset_roi = dataset_roi - self.spec[
                    volume_type].roi.get_offset() / voxel_size

                # create volume spec
                volume_spec = self.spec[volume_type].copy()
                volume_spec.roi = request_spec.roi

                # add volume to batch
                batch.volumes[volume_type] = Volume(
                    self.__read(hdf_file, self.datasets[volume_type],
                                dataset_roi), volume_spec)

            # if pre- and postsynaptic locations are requested, their
            # id -> SynapseLocation dictionaries should be created together so
            # that ids are unique and partner locations can be found
            if PointsTypes.PRESYN in request.points_specs or PointsTypes.POSTSYN in request.points_specs:
                # assert request.points_specs[PointsTypes.PRESYN].roi == request.points_specs[PointsTypes.POSTSYN].roi
                # Cremi specific, ROI offset corresponds to offset present in the
                # synapse location relative to the raw data.
                assert self.spec[PointsTypes.PRESYN].roi.get_offset() == self.spec[PointsTypes.POSTSYN].roi.get_offset(),\
                    "Pre and Post synaptic offsets are not the same"
                # pdb.set_trace()
                # assert request[PointsTypes.PRESYN].roi == request[PointsTypes.POSTSYN].roi,\
                #     "Pre and Post synaptic roi requests are not the same"

                dataset_offset = self.spec[PointsTypes.PRESYN].roi.get_offset()
                presyn_points, postsyn_points = self.__get_syn_points(
                    roi=request.points_specs[PointsTypes.PRESYN].roi,
                    syn_file=hdf_file,
                    dataset_offset=dataset_offset)

            for (points_type, request_spec) in request.points_specs.items():

                logger.debug("Reading %s in %s...", points_type,
                             request_spec.roi)
                id_to_point = {
                    PointsTypes.PRESYN: presyn_points,
                    PointsTypes.POSTSYN: postsyn_points
                }[points_type]
                # TODO: so far assumed that all points have resolution of raw volume

                points_spec = self.spec[points_type].copy()
                points_spec.roi = request_spec.roi
                batch.points[points_type] = Points(data=id_to_point,
                                                   spec=points_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch