コード例 #1
0
ファイル: hdf5_source.py プロジェクト: mmorehea/gunpowder
    def provide(self, request):
        """Read every requested volume from the HDF5 file into a new batch.

        Each requested ROI is checked against this source's spec, shifted so
        that it is relative to the provided ROI's offset, and read through
        ``self.__read``.

        Raises:
            RuntimeError: if a requested volume type is not provided by this
                source, or if its ROI lies outside the provided ROI.
        """
        timing = Timing(self)
        timing.start()

        provided = self.get_spec().volumes
        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf:
            for volume_type, requested_roi in request.volumes.items():

                if volume_type not in provided:
                    raise RuntimeError("Asked for %s which this source does not provide"%volume_type)

                provided_roi = provided[volume_type]
                if not provided_roi.contains(requested_roi):
                    raise RuntimeError("%s's ROI %s outside of my ROI %s"%(volume_type,requested_roi,provided_roi))

                logger.debug("Reading %s in %s..."%(volume_type,requested_roi))

                # dataset indices start at the provided ROI's offset
                local_roi = requested_roi.shift(-provided_roi.get_offset())
                data = self.__read(hdf, self.datasets[volume_type], local_roi)

                batch.volumes[volume_type] = Volume(
                    data,
                    roi=requested_roi,
                    resolution=self.resolutions[volume_type])

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #2
0
    def __setup_batch(self, batch_spec, chunk):
        '''Allocate a zero-filled batch matching the sizes of ``batch_spec``,
        using ``chunk`` as template.

        For every requested array, the leading (non-spatial) dimensions and the
        dtype are taken from the corresponding array in ``chunk``. The trailing
        spatial dimensions are the requested ROI's shape in voxels. Requested
        point sets are allocated empty with the requested ROI.

        Args:
            batch_spec: Request describing the arrays/points and ROIs to
                allocate.
            chunk: A batch whose arrays serve as shape/dtype templates.

        Returns:
            A new ``Batch`` with zero-filled arrays and empty points.
        '''

        batch = Batch()

        for (array_key, spec) in batch_spec.array_specs.items():
            roi = spec.roi
            voxel_size = self.spec[array_key].voxel_size

            # get the 'non-spatial' shape of the chunk-batch
            # and append the shape of the request to it
            array = chunk.arrays[array_key]
            shape = array.data.shape[:-roi.dims()]
            shape += (roi.get_shape() // voxel_size)

            spec = self.spec[array_key].copy()
            spec.roi = roi
            logger.info("allocating array of shape %s for %s", shape,
                        array_key)
            # use the template chunk's dtype instead of numpy's float64
            # default, so the allocated array matches the chunk data
            batch.arrays[array_key] = Array(
                data=np.zeros(shape, dtype=array.data.dtype), spec=spec)

        for (points_key, spec) in batch_spec.points_specs.items():
            roi = spec.roi
            spec = self.spec[points_key].copy()
            spec.roi = roi
            batch.points[points_key] = Points(data={}, spec=spec)

        logger.debug("setup batch to fill %s", batch)

        return batch
コード例 #3
0
ファイル: merge_provider.py プロジェクト: yajivunev/gunpowder
    def provide(self, request):
        """Split ``request`` by upstream provider and merge the results.

        Each requested key is routed to ``self.key_to_provider[key]``. Every
        provider receives a single combined ``BatchRequest`` with its own random
        seed. Arrays, graphs and profiling statistics of the returned batches
        are merged into one batch.

        Returns:
            The merged ``Batch``.
        """

        # create upstream requests
        upstream_requests = {}
        for key, spec in request.items():

            provider = self.key_to_provider[key]
            if provider not in upstream_requests:
                # use new random seeds per upstream request.
                # seeds picked by random should be deterministic since
                # the provided request already has a random seed.
                # random.randint includes its upper bound, so cap at
                # 2**32 - 1 to stay within the 32-bit seed range.
                seed = random.randint(0, 2**32 - 1)
                upstream_requests[provider] = BatchRequest(random_seed=seed)

            upstream_requests[provider][key] = spec

        # execute requests, merge batches
        merged_batch = Batch()
        for provider, upstream_request in upstream_requests.items():

            batch = provider.request_batch(upstream_request)
            for key, array in batch.arrays.items():
                merged_batch.arrays[key] = array
            for key, graph in batch.graphs.items():
                merged_batch.graphs[key] = graph
            merged_batch.profiling_stats.merge_with(batch.profiling_stats)

        return merged_batch
コード例 #4
0
ファイル: tiffstack_source.py プロジェクト: htem/gunpowder
    def provide(self, request):
        """Read each requested array from its TIFF stack dataset.

        The request ROI (world units) is converted to voxel units and made
        relative to the start of the provided ROI before reading through
        ``self.__read``.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for array_key, request_spec in request.array_specs.items():
            requested_roi = request_spec.roi
            logger.debug("Reading %s in %s...", array_key, requested_roi)

            provided_spec = self.spec[array_key]
            voxel_size = provided_spec.voxel_size

            # world units -> voxel units, relative to the dataset origin
            offset_in_voxels = provided_spec.roi.get_offset() / voxel_size
            dataset_roi = requested_roi / voxel_size - offset_in_voxels

            array_spec = provided_spec.copy()
            array_spec.roi = requested_roi

            data = self.__read(self.datasets[array_key], dataset_roi)
            batch.arrays[array_key] = Array(data, array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #5
0
    def provide(self, request):
        """Read every requested array, with its selected channels, from file.

        The file is opened once via ``self._open_file``. For each array, the
        request ROI is converted to voxel units relative to the provided ROI's
        offset and read together with ``self.channel_ids[array_key]``.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        with self._open_file(self.filename) as data_file:
            for array_key, request_spec in request.array_specs.items():
                provided_spec = self.spec[array_key]
                voxel_size = provided_spec.voxel_size

                # world units -> voxel indices relative to the dataset start
                dataset_roi = (request_spec.roi / voxel_size
                               - provided_spec.roi.get_offset() / voxel_size)

                array_spec = provided_spec.copy()
                array_spec.roi = request_spec.roi

                data = self.__read(data_file, self.datasets[array_key],
                                   dataset_roi, self.channel_ids[array_key])
                batch.arrays[array_key] = Array(data, array_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #6
0
ファイル: graph_source.py プロジェクト: pattonw/neurolight
    def provide(self, request: BatchRequest) -> Batch:
        """Provide a graph for every requested points key this source owns.

        For each key in ``self.points`` that is also in ``request``, node ids
        inside the requested ROI are found via ``self._query_kdtree``. They are
        expanded to a subgraph (including neighbors for small queries),
        converted to a ``Graph`` and trimmed to the requested ROI. All graphs
        are returned in one batch.
        """

        timing = Timing(self, "provide")
        timing.start()

        batch = Batch()

        for points_key in self.points:

            if points_key not in request:
                continue

            roi = request[points_key].roi

            # Retrieve all points in the requested region using a kdtree for speed
            point_ids = self._query_kdtree(
                self.data.tree,
                (np.array(roi.get_begin()), np.array(roi.get_end())),
            )

            # To account for boundary crossings we must retrieve neighbors of all points
            # in the graph. This is too slow for large queries and less important
            points_subgraph = self._subgraph_points(
                point_ids,
                with_neighbors=len(point_ids) < len(self._graph.nodes) // 2)
            nodes = [
                Node(id=node, location=attrs["location"], attrs=attrs)
                for node, attrs in points_subgraph.nodes.items()
            ]
            edges = [Edge(u, v) for u, v in points_subgraph.edges]
            return_graph = Graph(nodes, edges, GraphSpec(roi=roi))

            # Handle boundary cases
            return_graph = return_graph.trim(roi)

            # store into the single batch created above (re-creating it here
            # would drop graphs of previously processed keys)
            batch.points[points_key] = return_graph

            logger.debug(
                "Graph points source provided %d points for roi: %s",
                len(list(return_graph.nodes)), roi)
            logger.debug(
                "Providing %d nodes to %s",
                len(list(points_subgraph.nodes)), points_key)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #7
0
    def provide(self, request):
        """Read the requested synapse point sets from the HDF5 file.

        If PRESYN or POSTSYN is requested, this source must be of kind
        ``'synapse'`` and both point sets are read together. Otherwise the
        source must be of kind ``'presyn'`` or ``'postsyn'``, and the single
        provided key (the first key of ``self.datasets``) is filled with the
        matching point set.

        Returns:
            A ``Batch`` with one ``Points`` entry per requested points key.
        """

        timing_process = Timing(self)
        timing_process.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre and postsynaptic locations required, their id
            # SynapseLocation dictionaries should be created together s.t. ids
            # are unique and allow to find partner locations

            if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
                assert self.kind == 'synapse'
                # If only PRESYN or POSTSYN requested, assume PRESYN ROI = POSTSYN ROI.
                pre_key = PointsKeys.PRESYN if PointsKeys.PRESYN in request.points_specs else PointsKeys.POSTSYN
                post_key = PointsKeys.POSTSYN if PointsKeys.POSTSYN in request.points_specs else PointsKeys.PRESYN
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[pre_key].roi,
                    post_roi=request.points_specs[post_key].roi,
                    syn_file=hdf_file)
                points = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }
            else:
                assert self.kind == 'presyn' or self.kind == 'postsyn'
                synkey = list(self.datasets.items())[0][0]  # only key of dic.
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[synkey].roi,
                    post_roi=request.points_specs[synkey].roi,
                    syn_file=hdf_file)
                points = {
                    synkey:
                    presyn_points if self.kind == 'presyn' else postsyn_points
                }

            for (points_key, request_spec) in request.points_specs.items():
                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                # use a real placeholder so the count actually appears in the log
                logger.debug("Number of points %d", len(points[points_key]))
                batch.points[points_key] = Points(data=points[points_key],
                                                  spec=points_spec)

        timing_process.stop()
        batch.profiling_stats.add(timing_process)

        return batch
コード例 #8
0
    def provide(self, request):
        """Scan the reference request over the requested (or provided) ROI.

        The shifts of ``self.reference`` that cover ``request`` (or, for an
        empty request, ``self.spec``) are enumerated. Each shifted reference is
        processed either through the worker queue (``self.num_workers > 1``) or
        sequentially. For non-empty requests, every resulting chunk is added to
        ``self.batch`` via ``self.__add_to_batch``.

        Returns:
            The batch assembled in ``self.batch``. This method adds nothing to
            it for an empty request.
        """

        empty_request = (len(request) == 0)
        if empty_request:
            scan_spec = self.spec
        else:
            scan_spec = request

        stride = self.__get_stride()
        shift_roi = self.__get_shift_roi(scan_spec)

        shifts = self.__enumerate_shifts(shift_roi, stride)
        num_chunks = len(shifts)

        logger.info("scanning over %d chunks", num_chunks)

        # the batch to return
        self.batch = Batch()

        if self.num_workers > 1:

            for shift in shifts:
                shifted_reference = self.__shift_request(self.reference, shift)
                self.request_queue.put(shifted_reference)

            for i in range(num_chunks):

                chunk = self.workers.get()

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                # report 1-based progress so the last message reads N/N
                logger.info("processed chunk %d/%d", i + 1, num_chunks)

        else:

            for i, shift in enumerate(shifts):

                shifted_reference = self.__shift_request(self.reference, shift)
                chunk = self.__get_chunk(shifted_reference)

                if not empty_request:
                    self.__add_to_batch(request, chunk)

                # report 1-based progress so the last message reads N/N
                logger.info("processed chunk %d/%d", i + 1, num_chunks)

        batch = self.batch
        self.batch = None

        logger.debug("returning batch %s", batch)

        return batch
コード例 #9
0
    def __setup_batch(self, request, chunk_batch):
        """Allocate a zero-filled volume for every requested volume type.

        Affinity volumes (predicted or ground truth) get an extra leading axis
        of length 3. Every volume takes its resolution from the RAW volume of
        ``chunk_batch``.
        """
        batch = Batch()
        for volume_type, roi in request.volumes.items():
            is_affinity = (volume_type == VolumeTypes.PRED_AFFINITIES
                           or volume_type == VolumeTypes.GT_AFFINITIES)
            shape = roi.get_shape()
            if is_affinity:
                shape = (3, ) + shape

            raw_resolution = chunk_batch.volumes[VolumeTypes.RAW].resolution
            batch.volumes[volume_type] = Volume(
                data=np.zeros(shape),
                roi=roi,
                resolution=raw_resolution)
        return batch
コード例 #10
0
    def provide(self, request):
        """Provide the requested synapse point graphs.

        PRESYN and POSTSYN points are read together by
        ``self.__read_syn_points`` so their ids are unique and partner
        locations can be found. If both are requested they must share the same
        ROI; if only one is requested, its ROI is used.

        Raises:
            AssertionError: if PRESYN and POSTSYN are both requested with
                different ROIs.
            RuntimeError: if a requested points key is not provided by this
                source, or if its ROI lies outside the provided ROI.
        """

        timing = Timing(self)
        timing.start()

        batch = Batch()

        has_pre = GraphKey.PRESYN in request.points
        has_post = GraphKey.POSTSYN in request.points

        # if pre and postsynaptic locations requested, their id : SynapseLocation dictionaries should be created
        # together s.t. the ids are unique and allow to find partner locations
        if has_pre or has_post:
            # either both have the same roi, or only one of them is requested.
            # Only compare when both are present: indexing the missing key
            # would raise a KeyError.
            if has_pre and has_post:
                assert request.points[GraphKey.PRESYN] == request.points[
                    GraphKey.POSTSYN]
            syn_roi = request.points[
                GraphKey.PRESYN if has_pre else GraphKey.POSTSYN]
            presyn_points, postsyn_points = self.__read_syn_points(
                roi=syn_roi)

        for (points_key, roi) in request.points.items():
            # check if requested points can be provided
            if points_key not in self.spec:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    points_key)
            # check if request roi lies within provided roi
            if not self.spec[points_key].roi.contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (points_key, roi, self.spec[points_key].roi))

            logger.debug("Reading %s in %s..." % (points_key, roi))
            id_to_point = {
                GraphKey.PRESYN: presyn_points,
                GraphKey.POSTSYN: postsyn_points
            }[points_key]

            batch.points[points_key] = Graph(data=id_to_point,
                                             spec=GraphSpec(roi=roi))

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #11
0
ファイル: clahe.py プロジェクト: pattonw/neurolight
    def process(self, batch, request):
        """Apply CLAHE to each input array and store it under its output key.

        Input data must lie in [0, 1]. Arrays whose min and max are (nearly)
        equal are passed through unchanged. Otherwise the data is optionally
        min-max normalized, and ``clahe`` is applied to every slice over the
        leading (non-kernel) dimensions, with a kernel size in voxels
        (``self.kernel_size / voxel_size``). Every output is cropped to the ROI
        requested for its output key.

        Raises:
            AssertionError: if the input or output data leave the range [0, 1].
        """
        output = Batch()

        for in_key, out_key in zip(self.arrays, self.output_arrays):
            array = batch[in_key]
            # work on a copy: slices are written back in place below, which
            # would otherwise modify the input array still held in ``batch``
            data = array.data.copy()
            d_min = data.min()
            d_max = data.max()
            assert (
                d_min >= 0 and d_max <= 1
            ), f"Clahe expects data in range (0,1), got ({d_min}, {d_max})"
            if np.isclose(d_max, d_min):
                # constant data: nothing to equalize, but crop to the requested
                # ROI like the regular path so the output has the same extent
                output[out_key] = Array(data, array.spec).crop(
                    request[out_key].roi)
                continue
            if self.normalize:
                data = (data - d_min) / (d_max - d_min)
            shape = data.shape
            data_dims = len(shape)
            kernel_dims = len(self.kernel_size)
            extra_dims = data_dims - kernel_dims
            voxel_size = array.spec.voxel_size
            # kernel size in voxels; the same for every slice
            kernel_size = Coordinate(self.kernel_size / voxel_size)

            for index in itertools.product(*[range(s) for s in shape[:extra_dims]]):
                data[index] = clahe(
                    data[index],
                    kernel_size=kernel_size,
                    clip_limit=self.clip_limit,
                    nbins=self.nbins,
                )
            assert (
                data.min() >= 0 and data.max() <= 1
            ), f"Clahe should output data in range (0,1), got ({data.min()}, {data.max()})"
            output[out_key] = Array(data, array.spec).crop(request[out_key].roi)
        return output
コード例 #12
0
ファイル: mclahe.py プロジェクト: pattonw/neurolight
    def process(self, batch, request):
        """Apply multi-dimensional CLAHE (``mclahe``) to every input array.

        With ``self.slice_wise``, ``mclahe`` is applied to each slice over the
        leading (non-kernel) dimensions. Otherwise it is applied to the whole
        array, with a kernel of size 1 along the leading dimensions, and the
        result is cast to the output spec's dtype. Each result is cropped to
        the ROI requested for the output key.
        """
        output = Batch()

        for in_key, out_key in zip(self.arrays, self.output_arrays):
            array = batch[in_key]
            data = array.data
            shape = data.shape
            data_dims = len(shape)
            kernel_dims = len(self.kernel_size)
            extra_dims = data_dims - kernel_dims
            if self.slice_wise:
                # slices are written back in place; copy first so the input
                # array still held in ``batch`` is not modified
                data = data.copy()
                for index in itertools.product(
                        *[range(s) for s in shape[:extra_dims]]):
                    data[index] = mclahe(
                        data[index],
                        kernel_size=self.kernel_size,
                        clip_limit=self.clip_limit,
                        n_bins=self.nbins,
                        use_gpu=False,
                        adaptive_hist_range=self.adaptive_hist_range,
                    )
            else:
                full_kernel = np.array(
                    (1, ) * extra_dims + tuple(self.kernel_size), dtype=int)
                # NOTE(review): unlike the slice-wise branch, use_gpu and
                # adaptive_hist_range are left at mclahe's defaults here;
                # confirm this is intended.
                data = mclahe(
                    data,
                    kernel_size=full_kernel,
                    clip_limit=self.clip_limit,
                    n_bins=self.nbins,
                ).astype(self.spec[out_key].dtype)
            output[out_key] = Array(data,
                                    array.spec).crop(request[out_key].roi)
        return output
コード例 #13
0
    def provide(self, request):
        """Return all CSV points that fall inside the requested ROI.

        A point is kept when, for each of the first ``self.ndims`` columns, its
        coordinate is ``>=`` the ROI begin and ``<`` the ROI end. Point ids are
        row indices into ``self.data``.
        """

        timing = Timing(self)
        timing.start()

        min_bb = request[self.points].roi.get_begin()
        max_bb = request[self.points].roi.get_end()

        logger.debug(
            "CSV points source got request for %s",
            request[self.points].roi)

        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
        point_filter = np.ones((self.data.shape[0],), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter, self.data[:,d] >= min_bb[d])
            point_filter = np.logical_and(point_filter, self.data[:,d] < max_bb[d])

        filtered = self.data[point_filter]
        ids = np.arange(len(self.data))[point_filter]

        points_data = {
            i: Point(p)
            for i, p in zip(ids, filtered)
        }
        points_spec = PointsSpec(roi=request[self.points].roi.copy())

        batch = Batch()
        batch.points[self.points] = Points(points_data, points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #14
0
    def process(self, batch, request):
        """Upsample ``self.source`` into ``self.target`` by voxel repetition.

        The source array is cropped to the ROI requested for the target, and
        spatial axis ``d`` is repeated ``self.factor[d]`` times. Spatial axes
        are the trailing axes; any leading axes, e.g. channels, are left as is.

        Returns:
            A ``Batch`` holding the target array, or ``None`` if the target was
            not requested.
        """
        if self.target not in request:
            return

        input_roi = batch.arrays[self.source].spec.roi
        request_roi = request[self.target].roi

        assert input_roi.contains(request_roi)

        # upsample

        logger.debug("upsampling %s with %s", self.source, self.factor)

        crop = batch.arrays[self.source].crop(request_roi)
        data = crop.data

        # factors apply to the trailing spatial axes only; for arrays without
        # leading (channel) axes this is identical to axis=d
        spatial_start = data.ndim - len(self.factor)
        for d, f in enumerate(self.factor):
            data = np.repeat(data, f, axis=spatial_start + d)

        # create output array
        spec = self.spec[self.target].copy()
        spec.roi = request_roi

        outputs = Batch()
        outputs.arrays[self.target] = Array(data, spec)
        return outputs
コード例 #15
0
    def provide(self, request):
        """Fetch, filter and subsample a graph for every requested key.

        Nodes whose ``self.dist_attribute`` is below ``self.min_dist`` are
        dropped; nodes without that attribute are kept. If more than
        ``self.num_nodes`` nodes remain, a random subset of that size is kept.
        Locations come from ``self.position_attribute``. The graph is converted
        to the requested directedness and cropped to the requested ROI.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(graph.nodes)} nodes and {len(graph.edges)} edges"
            )

            too_close = [
                node_id
                for node_id, data in graph.nodes.items()
                if self.dist_attribute in data
                and data[self.dist_attribute] < self.min_dist
            ]
            for node_id in too_close:
                graph.remove_node(node_id)

            logger.debug(
                f"{len(graph.nodes)} nodes remaining after filtering by distance"
            )

            if len(graph.nodes) > self.num_nodes:
                candidates = list(graph.nodes)
                keep = set(random.sample(candidates, self.num_nodes))
                discard = [n for n in graph.nodes() if n not in keep]
                for node_id in discard:
                    graph.remove_node(node_id)

            for node_id, data in graph.nodes.items():
                data["location"] = np.array(data[self.position_attribute],
                                            dtype=np.float32)
                data["id"] = node_id

            graph = (graph.to_directed() if spec.directed
                     else graph.to_undirected())

            logger.debug(
                f"providing {key} with {len(graph.nodes)} nodes and {len(graph.edges)} edges"
            )

            points = Graph.from_nx_graph(graph, spec)
            points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #16
0
    def provide(self, request):
        """Return a graph of the CSV points inside the requested ROI.

        A point is kept when, for each of the first ``self.ndims`` columns, its
        coordinate is ``>=`` the ROI begin and ``<`` the ROI end. Nodes are
        built by ``self._get_points``; the resulting graph has no edges.
        """

        timing = Timing(self)
        timing.start()

        min_bb = request[self.points].roi.get_begin()
        max_bb = request[self.points].roi.get_end()

        logger.debug("CSV points source got request for %s",
                     request[self.points].roi)

        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
        point_filter = np.ones((self.data.shape[0], ), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] < max_bb[d])

        points_data = self._get_points(point_filter)
        points_spec = GraphSpec(roi=request[self.points].roi.copy())

        batch = Batch()
        batch.graphs[self.points] = Graph(points_data, [], points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #17
0
    def provide(self, request):
        """Process all blocks via ``self.__get_chunks`` and return an empty batch.

        With more than one worker, ``self.num_workers`` processes each run
        ``self.__get_chunks`` and are joined before returning. Otherwise
        ``self.__get_chunks`` runs in this process.

        Raises:
            RuntimeError: if ``request`` is not empty.
        """
        if len(request) != 0:
            raise RuntimeError(
                "requests made to DaisyRequestBlocks have to be empty")

        if self.num_workers <= 1:
            self.__get_chunks()
            return Batch()

        self.workers = [
            multiprocessing.Process(target=self.__get_chunks)
            for _ in range(self.num_workers)
        ]

        for process in self.workers:
            process.start()

        for process in self.workers:
            process.join()

        return Batch()
コード例 #18
0
    def provide(self, request):
        """Fetch a graph from ``self.graph_provider`` for every requested key.

        Node locations are read from ``self.position_attribute``. Nodes that
        lack it are logged and removed, or trigger a ``ValueError`` when
        ``self.fail_on_inconsistent_node`` is set. The graph is converted to the
        requested directedness, its connected components are relabelled, and it
        is cropped to the requested ROI.

        Raises:
            ValueError: if a node has no location and
                ``self.fail_on_inconsistent_node`` is set.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            nx_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(nx_graph.nodes)} nodes and {len(nx_graph.edges)} edges"
            )

            unlocated = []
            for node_id, data in nx_graph.nodes.items():
                try:
                    data["location"] = np.array(
                        data[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node_id} was written (probably part of an edge), but never given coordinates!"
                    )
                    unlocated.append(node_id)
                data["id"] = node_id

            if unlocated and self.fail_on_inconsistent_node:
                raise ValueError(
                    f"Mongodb contains node {unlocated[0]} without location! "
                    f"It was probably written as part of an edge")
            for node_id in unlocated:
                nx_graph.remove_node(node_id)

            nx_graph = (nx_graph.to_directed() if spec.directed
                        else nx_graph.to_undirected())

            points = Graph.from_nx_graph(nx_graph, spec)
            points.relabel_connected_components()
            points.crop(spec.roi)
            batch[key] = points

            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #19
0
    def provide(self, request):
        """Read the requested CREMI synapse point sets from the HDF5 file.

        PRESYN and POSTSYN points are read together by ``self.__get_syn_points``
        so their ids are unique and partner locations can be found. If both are
        requested they must share the same ROI; if only one is requested, its
        ROI is used.

        Raises:
            AssertionError: if PRESYN and POSTSYN are requested with different
                ROIs.
        """

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre and postsynaptic locations required, their id
            # SynapseLocation dictionaries should be created together s.t. ids
            # are unique and allow to find partner locations
            has_pre = PointsKeys.PRESYN in request.points_specs
            has_post = PointsKeys.POSTSYN in request.points_specs
            if has_pre or has_post:
                # only compare ROIs when both are requested; indexing a key
                # that was not requested would raise a KeyError
                if has_pre and has_post:
                    assert request.points_specs[
                        PointsKeys.PRESYN].roi == request.points_specs[
                            PointsKeys.POSTSYN].roi
                syn_key = PointsKeys.PRESYN if has_pre else PointsKeys.POSTSYN
                # Cremi specific, ROI offset corresponds to offset present in the
                # synapse location relative to the raw data.
                dataset_offset = self.spec[syn_key].roi.get_offset()
                presyn_points, postsyn_points = self.__get_syn_points(
                    roi=request.points_specs[syn_key].roi,
                    syn_file=hdf_file,
                    dataset_offset=dataset_offset)

            for (points_key, request_spec) in request.points_specs.items():

                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                id_to_point = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }[points_key]

                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                batch.points[points_key] = Points(data=id_to_point,
                                                  spec=points_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #20
0
    def provide(self, request):
        """Read every requested volume through its type-specific reader.

        RAW, GT_LABELS and GT_MASK are read by ``self.__read_raw``,
        ``self.__read_gt`` and ``self.__read_gt_mask``; only RAW is marked as
        interpolatable. All volumes use ``self.resolution``.

        Raises:
            RuntimeError: if a requested volume type is not provided, or if its
                ROI lies outside the provided ROI.
        """
        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()
        logger.debug("providing batch with resolution of {}".format(
            self.resolution))

        for volume_type, roi in request.volumes.items():

            if volume_type not in spec.volumes:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    volume_type)

            provided_roi = spec.volumes[volume_type]
            if not provided_roi.contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (volume_type, roi, provided_roi))

            readers = {
                VolumeType.RAW: (self.__read_raw, True),
                VolumeType.GT_LABELS: (self.__read_gt, False),
                VolumeType.GT_MASK: (self.__read_gt_mask, False),
            }
            reader, interpolatable = readers[volume_type]

            logger.debug("Reading %s in %s..." % (volume_type, roi))
            data = reader(roi)
            batch.volumes[volume_type] = Volume(
                data,
                roi=roi,
                # TODO: get resolution from repository
                resolution=self.resolution,
                interpolate=interpolatable)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #21
0
    def provide(self, request):
        """Read every requested volume from the HDF5 file.

        RAW and ALPHA_MASK volumes are marked as interpolatable; GT_LABELS and
        GT_MASK are not. Resolutions come from ``self.resolutions``.

        Raises:
            RuntimeError: if a requested volume type is not provided, or if its
                ROI lies outside the provided ROI.
        """
        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf:

            for volume_type, roi in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                provided_roi = spec.volumes[volume_type]
                if not provided_roi.contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, provided_roi))

                interpolation_by_type = {
                    VolumeType.RAW: True,
                    VolumeType.GT_LABELS: False,
                    VolumeType.GT_MASK: False,
                    VolumeType.ALPHA_MASK: True,
                }
                interpolatable = interpolation_by_type[volume_type]

                logger.debug("Reading %s in %s..." % (volume_type, roi))
                data = self.__read(hdf, self.datasets[volume_type], roi)
                batch.volumes[volume_type] = Volume(
                    data,
                    roi=roi,
                    resolution=self.resolutions[volume_type],
                    interpolate=interpolatable)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #22
0
ファイル: chunk.py プロジェクト: mananshah99/gunpowder
    def __setup_batch(self, batch_spec, reference):
        """Allocate zero-filled volumes shaped after ``batch_spec``.

        RAW uses the input ROI's shape and is marked interpolatable. Affinity
        volumes get an extra leading axis of length 3 over the output ROI's
        shape. All other volumes use the output ROI's shape. Each volume keeps
        the dtype of the matching volume in ``reference``.
        """
        batch = Batch(batch_spec)

        for volume_type, template in reference.volumes.items():

            is_raw = volume_type == VolumeType.RAW
            is_affinity = (volume_type == VolumeType.GT_AFFINITIES
                           or volume_type == VolumeType.PRED_AFFINITIES)

            if is_raw:
                shape = batch_spec.input_roi.get_shape()
            elif is_affinity:
                shape = (3, ) + batch_spec.output_roi.get_shape()
            else:
                shape = batch_spec.output_roi.get_shape()

            zeros = np.zeros(shape, template.data.dtype)
            batch.volumes[volume_type] = Volume(zeros, is_raw)

        return batch
コード例 #23
0
    def provide(self, request):
        """Read each requested array from either a volume or a mask dataset.

        The request ROI is converted to voxel units relative to the provided
        ROI's offset. Keys in ``self.datasets`` are read with
        ``self.__read_array``, and keys in ``self.masks`` with
        ``self.__read_mask``.

        Raises:
            RuntimeError: if a requested key is neither a volume nor a mask.
        """

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for (array_key, request_spec) in request.array_specs.items():

            logger.debug("Reading %s in %s...", array_key, request_spec.roi)

            voxel_size = self.spec[array_key].voxel_size

            # scale request roi to voxel units
            dataset_roi = request_spec.roi / voxel_size

            # shift request roi into dataset
            dataset_roi = dataset_roi - self.spec[array_key].roi.get_offset(
            ) / voxel_size

            # create array spec
            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi

            # read the data
            if array_key in self.datasets:
                data = self.__read_array(self.datasets[array_key], dataset_roi)
            elif array_key in self.masks:
                data = self.__read_mask(self.masks[array_key], dataset_roi)
            else:
                # explicit raise: ``assert False`` is stripped under ``python
                # -O``, which would leave ``data`` unbound below
                raise RuntimeError(
                    "Encountered a request for %s that is neither a volume "
                    "nor a mask." % array_key)

            # add array to batch
            batch.arrays[array_key] = Array(data, array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #24
0
    def provide(self, request):
        """Read ``self.array_key`` from the CloudVolume at ``self.mip``.

        A ``CloudVolume`` handle is opened for every call. The request ROI is
        converted to voxel units relative to the provided ROI's offset and read
        through ``self.__read``.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        volume = CloudVolume(self.cloudvolume_url, use_https=True, mip=self.mip)

        array_key = self.array_key
        request_spec = request.array_specs[array_key]
        logger.debug("Reading %s in %s...", array_key, request_spec.roi)

        voxel_size = self.array_spec.voxel_size

        # NOTE(review): the offset comes from self.spec[array_key] while the
        # voxel size and returned spec come from self.array_spec; presumably
        # these describe the same array, confirm.
        offset_in_voxels = self.spec[array_key].roi.get_offset() / voxel_size
        dataset_roi = request_spec.roi / voxel_size - offset_in_voxels

        array_spec = self.array_spec.copy()
        array_spec.roi = request_spec.roi

        data = self.__read(volume, dataset_roi)
        batch.arrays[array_key] = Array(data, array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #25
0
    def provide(self, request):
        """Dispatch each requested key to its provider and merge the results.

        Keys in ``request`` are grouped by ``self.key_to_provider`` into one
        ``BatchRequest`` per upstream provider. Each provider is queried once,
        and the returned arrays and points are copied into a single batch.

        Returns:
            A ``Batch`` containing the union of all upstream arrays and points.
            Its ``profiling_stats`` include the stats of every upstream batch.
            Previously these were dropped, unlike in the graphs-based variant
            of this method.
        """

        # create upstream requests
        upstream_requests = {}
        for key, spec in request.items():

            provider = self.key_to_provider[key]
            if provider not in upstream_requests:
                upstream_requests[provider] = BatchRequest()

            upstream_requests[provider][key] = spec

        # execute requests, merge batches
        merged_batch = Batch()
        for provider, upstream_request in upstream_requests.items():

            batch = provider.request_batch(upstream_request)
            for key, array in batch.arrays.items():
                merged_batch.arrays[key] = array
            for key, points in batch.points.items():
                merged_batch.points[key] = points
            # keep upstream timing information instead of discarding it
            merged_batch.profiling_stats.merge_with(batch.profiling_stats)

        return merged_batch
コード例 #26
0
    def provide(self, request):
        """Split ``request`` by upstream provider and merge the answers.

        Every key is routed through ``self.key_to_provider``. Each distinct
        provider receives one ``BatchRequest`` holding just its keys. The
        arrays and graphs of all returned batches are collected into one
        ``Batch``, and their ``profiling_stats`` are merged into it.
        """

        # one sub-request per provider, created the first time it is needed
        per_provider = {}
        for key, spec in request.items():
            owner = self.key_to_provider[key]
            sub_request = per_provider.get(owner)
            if sub_request is None:
                sub_request = BatchRequest()
                per_provider[owner] = sub_request
            sub_request[key] = spec

        combined = Batch()
        for owner, sub_request in per_provider.items():
            partial = owner.request_batch(sub_request)

            for array_key, array in partial.arrays.items():
                combined.arrays[array_key] = array

            for graph_key, graph in partial.graphs.items():
                combined.graphs[graph_key] = graph

            combined.profiling_stats.merge_with(partial.profiling_stats)

        return combined
コード例 #27
0
ファイル: unsqueeze.py プロジェクト: yajivunev/gunpowder
    def process(self, batch, request):
        """Insert a new axis at ``self.axis`` into each listed array in ``batch``.

        Arrays in ``self.arrays`` that are missing from ``batch`` are skipped.
        For spatial arrays the new axis may only be placed among the leading
        non-spatial axes, i.e. at a position ``<= ndim - spatial_dims``.

        Args:
            batch: upstream batch holding the arrays.
            request: the request; used for the spatial dimensionality of each
                array's ROI.

        Returns:
            A new ``Batch`` with a deep copy of each processed array whose
            ``data`` is replaced by ``np.expand_dims(original_data, self.axis)``.

        Raises:
            ValueError: if, for a spatial array, the new axis would land among
                the spatial axes.
        """
        outputs = Batch()
        for array in self.arrays:
            if array not in batch:
                continue

            source = batch[array]
            ndim = source.data.ndim
            # np.expand_dims interprets a negative axis relative to the output
            # (ndim + 1 axes); normalize it so e.g. axis=-1 cannot slip past
            # the check below and append the axis after the spatial dims.
            axis = self.axis + ndim + 1 if self.axis < 0 else self.axis

            if not source.spec.nonspatial:
                spatial_dims = request[array].roi.dims()
                if axis > ndim - spatial_dims:
                    raise ValueError((
                        f"Unsqueeze.axis={self.axis} not permitted. "
                        "Unsqueeze only supported for "
                        "non-spatial dimensions of Array."
                    ))

            outputs[array] = copy.deepcopy(source)
            outputs[array].data = np.expand_dims(source.data, self.axis)
        return outputs
コード例 #28
0
    def provide(self, request: BatchRequest) -> Batch:
        """Generate ``self.n_obj`` random trees inside the requested ROI.

        Python's ``random`` and NumPy's global RNG are both re-seeded with
        ``request.random_seed`` first. Each tree is grown by
        ``self._grow_tree`` from a uniformly random root inside the ROI's
        shape. Up to 100 candidates are tried per tree. The first one whose
        node count lies within ``self.num_nodes[0]..self.num_nodes[1]`` is
        kept; if none qualifies, the last candidate is kept anyway. Node
        positions are floored and shifted by the ROI's begin.

        Returns:
            A ``Batch`` whose ``self.points`` entry is a ``Graph`` built from
            the disjoint union of all trees.
        """
        random.seed(request.random_seed)
        np.random.seed(request.random_seed)

        timing = Timing(self, "provide")
        timing.start()

        out = Batch()

        request_roi = request[self.points].roi
        shape = request_roi.get_shape()
        ndims = len(shape)

        def sample_tree():
            # retry until the node count is in range; fall back to the last try
            candidate = None
            for _attempt in range(100):
                origin = np.random.random(ndims) * shape
                candidate = self._grow_tree(origin, Roi((0,) * ndims, shape))
                if self.num_nodes[0] <= len(candidate.nodes) <= self.num_nodes[1]:
                    break
            return candidate

        union = nx.disjoint_union_all([sample_tree() for _ in range(self.n_obj)])

        nodes = {}
        for node_id, attrs in union.nodes.items():
            nodes[node_id] = Node(np.floor(attrs["pos"]) + request_roi.get_begin())

        out[self.points] = Graph(nodes, request[self.points], list(union.edges))

        timing.stop()
        out.profiling_stats.add(timing)

        return out
コード例 #29
0
    def process(self, batch, request):
        """Downsample ``self.source`` into ``self.target`` by strided slicing.

        ``self.factor`` is either one stride per spatial dimension (a tuple or
        list) or a single stride applied to every spatial dimension of the
        source ROI. The strides are applied to the *trailing* axes of the
        data, which hold the spatial dimensions. Any leading non-spatial axes
        (e.g. channels) are kept whole.

        Returns:
            A ``Batch`` containing ``self.target``, with spec
            ``self.spec[self.target]`` and ROI ``request[self.target].roi``.
        """
        outputs = Batch()

        source = batch.arrays[self.source]

        # downsample
        if isinstance(self.factor, (tuple, list)):
            steps = tuple(self.factor)
        else:
            steps = (self.factor,) * source.spec.roi.dims()
        slices = tuple(slice(None, None, k) for k in steps)

        logger.debug("downsampling %s with %s", self.source, slices)

        # leading Ellipsis: stride only the trailing spatial axes so that
        # non-spatial leading axes (channels etc.) are not subsampled
        data = source.data[(Ellipsis,) + slices]

        # create output array
        spec = self.spec[self.target].copy()
        spec.roi = request[self.target].roi
        outputs.arrays[self.target] = Array(data, spec)

        return outputs
コード例 #30
0
    def provide(self, request: BatchRequest) -> Batch:
        """Return the SWC points that fall inside the requested ROI.

        Points are looked up with a KD-tree over ``self.data.tree``. They are
        turned into a subgraph that also tracks edges crossing out of it, and
        those boundary crossings are handled against the requested ROI. The
        result is converted back into ``Points`` data.

        Returns:
            A ``Batch`` with ``self.points`` set to a ``Points`` object whose
            spec ROI is a copy of the requested ROI.
        """

        timing = Timing(self)
        timing.start()

        roi = request[self.points].roi
        logger.debug("Swc points source got request for %s", roi)

        # KD-tree lookup of everything between the ROI's begin and end corners
        bounds = (np.array(roi.get_begin()), np.array(roi.get_end()))
        found = self._query_kdtree(self.data.tree, bounds)

        # subgraph of the hits plus the edges leading out of it
        graph, preds, succs = self._points_to_graph(found)
        self._handle_boundary_crossings(graph, preds, succs, roi)

        payload = self._graph_to_data(graph)
        spec = PointsSpec(roi=roi.copy())

        result = Batch()
        result.points[self.points] = Points(payload, spec)

        timing.stop()
        result.profiling_stats.add(timing)

        return result