Example #1
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with self._open_file(self.filename) as data_file:
            for (array_key, request_spec) in request.array_specs.items():

                voxel_size = self.spec[array_key].voxel_size

                # scale request roi to voxel units
                dataset_roi = request_spec.roi / voxel_size

                # shift request roi into dataset
                dataset_roi = dataset_roi - self.spec[
                    array_key].roi.get_offset() / voxel_size

                # create array spec
                array_spec = self.spec[array_key].copy()
                array_spec.roi = request_spec.roi

                # add array to batch
                batch.arrays[array_key] = Array(
                    self.__read(data_file, self.datasets[array_key],
                                dataset_roi, self.channel_ids[array_key]),
                    array_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
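The two ROI transformations in this example, scale to voxel units and then shift by the dataset offset, are where off-by-one-voxel bugs usually creep in. A minimal sketch of the same arithmetic with plain tuples, assuming the offsets are exact multiples of the voxel size (all names here are illustrative, not part of the source above):

    def world_roi_to_voxels(roi_begin, roi_shape, dataset_offset, voxel_size):
        # scale to voxel units, then shift so the dataset starts at voxel 0
        begin = tuple(b // v - o // v
                      for b, o, v in zip(roi_begin, dataset_offset, voxel_size))
        shape = tuple(s // v for s, v in zip(roi_shape, voxel_size))
        return begin, shape

    # a request at world offset 40 with voxel size 4 starts at voxel 10
    assert world_roi_to_voxels((40,) * 3, (80,) * 3, (0,) * 3, (4,) * 3) == \
        ((10,) * 3, (20,) * 3)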
Example #2
    def provide(self, request):

        report_next_timeout = 2
        num_rejected = 0

        timing = Timing(self)
        timing.start()

        have_good_batch = False
        while not have_good_batch:

            batch = self.upstream_provider.request_batch(request)

            if batch.arrays[self.ensure_nonempty].data.size != 0:

                have_good_batch = True
                logger.debug("Accepted batch with shape: %s",
                             batch.arrays[self.ensure_nonempty].data.shape)

            else:

                num_rejected += 1

                if timing.elapsed() > report_next_timeout:
                    logger.info(
                        "rejected %d batches, been waiting for a good one "
                        "since %ds", num_rejected, report_next_timeout)
                    report_next_timeout *= 2

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
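The doubling of report_next_timeout keeps the log quiet during long waits: a message appears after 2 s, 4 s, 8 s, and so on, so a stall produces O(log t) lines rather than one per rejection. The same pattern in isolation (a sketch; the function names are illustrative):

    import time

    def retry_until(produce, accept, report_timeout=2.0):
        start = time.time()
        num_rejected = 0
        while True:
            result = produce()
            if accept(result):
                return result
            num_rejected += 1
            if time.time() - start > report_timeout:
                print(f"rejected {num_rejected} results so far, "
                      f"still waiting after {report_timeout:.0f}s")
                report_timeout *= 2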
Example #3
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        min_bb = request[self.points].roi.get_begin()
        max_bb = request[self.points].roi.get_end()

        logger.debug(
            "CSV points source got request for %s",
            request[self.points].roi)

        point_filter = np.ones((self.data.shape[0],), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d])

        filtered = self.data[point_filter]
        ids = np.arange(len(self.data))[point_filter]

        points_data = {
            i: Point(p)
            for i, p in zip(ids, filtered)
        }
        points_spec = PointsSpec(roi=request[self.points].roi.copy())

        batch = Batch()
        batch.points[self.points] = Points(points_data, points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
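The per-dimension loop above keeps exactly the points inside the half-open box [min_bb, max_bb). The same filter can be written as one broadcast comparison; a plain-numpy sketch:

    import numpy as np

    data = np.array([[1.0, 2.0], [5.0, 5.0], [9.0, 1.0]])
    min_bb, max_bb = np.array([0.0, 0.0]), np.array([6.0, 6.0])

    point_filter = np.all((data >= min_bb) & (data < max_bb), axis=1)
    # -> array([ True,  True, False])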
Example #4
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )
            for node, attrs in list(requested_graph.nodes.items()):
                if self.dist_attribute in attrs:
                    if attrs[self.dist_attribute] < self.min_dist:
                        requested_graph.remove_node(node)

            logger.debug(
                f"{len(requested_graph.nodes)} nodes remaining after filtering by distance"
            )

            if len(requested_graph.nodes) > self.num_nodes:
                nodes = list(requested_graph.nodes)
                nodes_to_keep = set(random.sample(nodes, self.num_nodes))

                for node in list(requested_graph.nodes()):
                    if node not in nodes_to_keep:
                        requested_graph.remove_node(node)

            for node, attrs in requested_graph.nodes.items():
                attrs["location"] = np.array(attrs[self.position_attribute],
                                             dtype=np.float32)
                attrs["id"] = node

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            logger.debug(
                f"providing {key} with {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            points = Graph.from_nx_graph(requested_graph, spec)
            points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
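Both filtering steps above use the same two-phase pattern: collect the nodes to drop (or keep) first, then remove them, since deleting from a graph while iterating over its live node view is an error. A condensed sketch of the random downsampling with networkx:

    import random
    import networkx as nx

    g = nx.path_graph(100)
    num_nodes = 10

    if len(g.nodes) > num_nodes:
        keep = set(random.sample(list(g.nodes), num_nodes))
        g.remove_nodes_from([n for n in g.nodes if n not in keep])

    assert len(g.nodes) == num_nodes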
Example #5
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        if request != self.current_request:

            if self.workers is not None:
                logger.info(
                    "new request received, stopping current workers...")
                self.workers.stop()

            self.current_request = copy.deepcopy(request)

            logger.info("starting new set of workers...")
            self.workers = ProducerPool(
                [
                    lambda i=i: self.__run_worker(i)
                    for i in range(self.num_workers)
                ],
                queue_size=self.cache_size)
            self.workers.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #6
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        _, request_spec = list(request.array_specs.items())[0]

        logger.debug("Reading %s in %s...", self.array, request_spec.roi)

        voxel_size = self.spec[self.array].voxel_size

        # scale request roi to voxel units
        dataset_roi = request_spec.roi / voxel_size

        # shift request roi into dataset
        dataset_roi = dataset_roi - self.spec[
            self.array].roi.get_offset() / voxel_size

        # create array spec
        array_spec = self.spec[self.array].copy()
        array_spec.roi = request_spec.roi

        # add array to batch
        batch.arrays[self.array] = Array(
            self.__read(dataset_roi),
            array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #7
    def provide(self, request):

        # operate on a copy of the request, to provide the original request to
        # 'process' for convenience
        upstream_request = copy.deepcopy(request)

        skip = self.__can_skip(request)

        timing_prepare = Timing(self, 'prepare')
        timing_prepare.start()

        if not skip:
            self.prepare(upstream_request)
            self.remove_provided(upstream_request)

        timing_prepare.stop()

        batch = self.get_upstream_provider().request_batch(upstream_request)

        timing_process = Timing(self, 'process')
        timing_process.start()

        if not skip:
            self.process(batch, request)

        timing_process.stop()

        batch.profiling_stats.add(timing_prepare)
        batch.profiling_stats.add(timing_process)

        return batch
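Every node in these examples plugs into this skeleton: prepare() adjusts the upstream request, process() adjusts the resulting batch, and __can_skip short-circuits both when the node has nothing to provide. A minimal subclass following that contract (a sketch, assuming a gunpowder-style BatchFilter; the class and key names are illustrative):

    class InvertIntensity(BatchFilter):

        def __init__(self, array_key):
            self.array_key = array_key

        def prepare(self, request):
            # point-wise operation: the upstream ROI equals the requested
            # ROI, so there is nothing to change
            pass

        def process(self, batch, request):
            array = batch.arrays[self.array_key]
            array.data = array.data.max() - array.data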
Example #8
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        min_bb = request[self.points].roi.get_begin()
        max_bb = request[self.points].roi.get_end()

        logger.debug(
            "CSV points source got request for %s",
            request[self.points].roi)

        point_filter = np.ones((self.locations.shape[0],), dtype=bool)
        for d in range(self.locations.shape[1]):
            point_filter = np.logical_and(point_filter,
                                          self.locations[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter,
                                          self.locations[:, d] < max_bb[d])

        points_data = self._get_points(point_filter)
        logger.debug("Points data: %s", points_data)
        logger.debug("Type of point: %s", type(list(points_data.values())[0]))
        points_spec = PointsSpec(roi=request[self.points].roi.copy())

        batch = Batch()
        batch.points[self.points] = Points(points_data, points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #9
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as f:

            for (volume_type, roi) in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                if not spec.volumes[volume_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, spec.volumes[volume_type]))

                logger.debug("Reading %s in %s...", volume_type, roi)

                # shift request roi into dataset
                dataset_roi = roi.shift(-spec.volumes[volume_type].get_offset())

                batch.volumes[volume_type] = Volume(
                        self.__read(f, self.datasets[volume_type], dataset_roi),
                        roi=roi,
                        resolution=self.resolutions[volume_type])

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #10
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        min_bb = request[self.points].roi.get_begin()
        max_bb = request[self.points].roi.get_end()

        logger.debug("CSV points source got request for %s",
                     request[self.points].roi)

        point_filter = np.ones((self.data.shape[0], ), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] < max_bb[d])

        points_data = self._get_points(point_filter)
        points_spec = GraphSpec(roi=request[self.points].roi.copy())

        batch = Batch()
        batch.graphs[self.points] = Graph(points_data, [], points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #11
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            failed_nodes = []

            for node, attrs in requested_graph.nodes.items():
                try:
                    attrs["location"] = np.array(
                        attrs[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node} was written (probably part of an edge), but never given coordinates!"
                    )
                    failed_nodes.append(node)
                attrs["id"] = node

            for node in failed_nodes:
                if self.fail_on_inconsistent_node:
                    raise ValueError(
                        f"Mongodb contains node {node} without location! "
                        f"It was probably written as part of an edge")
                requested_graph.remove_node(node)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            points = Graph.from_nx_graph(requested_graph, spec)
            points.relabel_connected_components()
            points.crop(spec.roi)
            batch[key] = points

            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
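The failed_nodes handling exists because a node that only ever appears as an edge endpoint is created without attributes, so the position lookup raises KeyError. A sketch of how such nodes arise in networkx:

    import networkx as nx

    g = nx.Graph()
    g.add_node(1, location=(0.0, 0.0))
    g.add_edge(1, 2)  # node 2 is created implicitly, with no attributes

    missing = [n for n, attrs in g.nodes.items() if "location" not in attrs]
    g.remove_nodes_from(missing)
    assert list(g.nodes) == [1]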
Example #12
    def provide(self, request):

        skip = self.__can_skip(request)

        timing_prepare = Timing(self, "prepare")
        timing_prepare.start()

        downstream_request = request.copy()

        if not skip:
            dependencies = self.prepare(request)
            if isinstance(dependencies, BatchRequest):
                upstream_request = request.merge(dependencies)
            elif dependencies is None:
                upstream_request = request.copy()
            else:
                raise BatchFilterError(
                    self, f"This BatchFilter returned a {type(dependencies)}! "
                    "Supported return types are: `BatchRequest` containing your exact "
                    "dependencies or `None`, indicating a dependency on the full request."
                )
        else:
            upstream_request = request.copy()
        self.remove_provided(upstream_request)

        timing_prepare.stop()

        batch = self.get_upstream_provider().request_batch(upstream_request)

        timing_process = Timing(self, "process")
        timing_process.start()

        if not skip:
            if dependencies is not None:
                dependencies.remove_placeholders()
                node_batch = batch.crop(dependencies)
            else:
                node_batch = batch
            downstream_request.remove_placeholders()
            processed_batch = self.process(node_batch, downstream_request)
            if processed_batch is None:
                processed_batch = node_batch
            batch = batch.merge(
                processed_batch,
                merge_profiling_stats=False).crop(downstream_request)

        timing_process.stop()

        batch.profiling_stats.add(timing_prepare)
        batch.profiling_stats.add(timing_process)

        return batch
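Under this contract, prepare() no longer mutates the upstream request; it returns the node's exact dependencies, and provide() merges, crops, and re-merges around them. A sketch of a matching prepare() for a filter that needs spatial context (the key and attribute names are illustrative):

    def prepare(self, request):
        deps = BatchRequest()
        # producing self.output requires self.input in a ROI grown by
        # self.context on both sides
        deps[self.input] = request[self.output].copy()
        deps[self.input].roi = deps[self.input].roi.grow(
            self.context, self.context)
        return deps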
Example #13
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        logger.debug("getting batch from queue...")
        batch = self.workers.get()

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #14
    def provide(self, request: BatchRequest) -> Batch:

        timing = Timing(self, "provide")
        timing.start()

        batch = Batch()

        for points_key in self.points:

            if points_key not in request:
                continue

            # Retrieve all points in the requested region using a kdtree for speed
            point_ids = self._query_kdtree(
                self.data.tree,
                (
                    np.array(request[points_key].roi.get_begin()),
                    np.array(request[points_key].roi.get_end()),
                ),
            )

            # To account for boundary crossings we must retrieve neighbors of
            # all selected points. That is too slow for large queries, where
            # boundary effects matter less, so neighbors are skipped there.
            points_subgraph = self._subgraph_points(
                point_ids,
                with_neighbors=len(point_ids) < len(self._graph.nodes) // 2)
            nodes = [
                Node(id=node, location=attrs["location"], attrs=attrs)
                for node, attrs in points_subgraph.nodes.items()
            ]
            edges = [Edge(u, v) for u, v in points_subgraph.edges]
            return_graph = Graph(nodes, edges,
                                 GraphSpec(roi=request[points_key].roi))

            # Handle boundary cases
            return_graph = return_graph.trim(request[points_key].roi)

            batch.points[points_key] = return_graph

            logger.debug(
                "Graph points source provided {} points for roi: {}".format(
                    len(list(batch.points[points_key].nodes)),
                    request[points_key].roi))

            logger.debug(
                f"Providing {len(list(points_subgraph.nodes))} nodes to {points_key}"
            )

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
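scipy's cKDTree offers ball queries but no axis-aligned box query, so a helper like _query_kdtree is typically either a custom tree walk or, as sketched below, a ball query around the box center followed by an exact filter (names are illustrative):

    import numpy as np
    from scipy.spatial import cKDTree

    def query_box(tree, points, begin, end):
        begin, end = np.asarray(begin), np.asarray(end)
        center = (begin + end) / 2
        radius = np.linalg.norm(end - begin) / 2  # circumscribed ball
        candidates = tree.query_ball_point(center, radius)
        return [i for i in candidates
                if np.all(points[i] >= begin) and np.all(points[i] < end)]

    points = np.random.random((1000, 3))
    ids = query_box(cKDTree(points), points, (0.2,) * 3, (0.4,) * 3)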
Example #15
    def provide(self, request):

        timing_process = Timing(self)
        timing_process.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre and postsynaptic locations required, their id
            # SynapseLocation dictionaries should be created together s.t. ids
            # are unique and allow to find partner locations

            if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
                assert self.kind == 'synapse'
                # If only PRESYN or POSTSYN requested, assume PRESYN ROI = POSTSYN ROI.
                pre_key = PointsKeys.PRESYN if PointsKeys.PRESYN in request.points_specs else PointsKeys.POSTSYN
                post_key = PointsKeys.POSTSYN if PointsKeys.POSTSYN in request.points_specs else PointsKeys.PRESYN
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[pre_key].roi,
                    post_roi=request.points_specs[post_key].roi,
                    syn_file=hdf_file)
                points = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }
            else:
                assert self.kind == 'presyn' or self.kind == 'postsyn'
                synkey = list(self.datasets.keys())[0]  # the only key of the dict
                presyn_points, postsyn_points = self.__get_syn_points(
                    pre_roi=request.points_specs[synkey].roi,
                    post_roi=request.points_specs[synkey].roi,
                    syn_file=hdf_file)
                points = {
                    synkey:
                    presyn_points if self.kind == 'presyn' else postsyn_points
                }

            for (points_key, request_spec) in request.points_specs.items():
                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                logger.debug("Number of points len()".format(
                    len(points[points_key])))
                batch.points[points_key] = Points(data=points[points_key],
                                                  spec=points_spec)

        timing_process.stop()
        batch.profiling_stats.add(timing_process)

        return batch
Example #16
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        # if pre and postsynaptic locations requested, their id : SynapseLocation dictionaries should be created
        # together s.t. the ids are unique and allow to find partner locations
        if GraphKey.PRESYN in request.points or GraphKey.POSTSYN in request.points:
            # either both have the same roi, or only one of them is requested
            if (GraphKey.PRESYN in request.points
                    and GraphKey.POSTSYN in request.points):
                assert request.points[GraphKey.PRESYN] == request.points[
                    GraphKey.POSTSYN]
            if GraphKey.PRESYN in request.points:
                presyn_points, postsyn_points = self.__read_syn_points(
                    roi=request.points[GraphKey.PRESYN])
            elif GraphKey.POSTSYN in request.points:
                presyn_points, postsyn_points = self.__read_syn_points(
                    roi=request.points[GraphKey.POSTSYN])

        for (points_key, roi) in request.points.items():
            # check if requested points can be provided
            if points_key not in self.spec:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    points_key)
            # check if request roi lies within provided roi
            if not self.spec[points_key].roi.contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (points_key, roi, self.spec[points_key].roi))

            logger.debug("Reading %s in %s..." % (points_key, roi))
            id_to_point = {
                GraphKey.PRESYN: presyn_points,
                GraphKey.POSTSYN: postsyn_points
            }[points_key]

            batch.points[points_key] = Graph(data=id_to_point,
                                             spec=GraphSpec(roi=roi))

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #17
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = gp.Batch()

        # If an Array is requested, we randomly choose the requested
        # number of points
        if isinstance(self.points, gp.ArrayKey):
            points = np.random.choice(self.data.shape[0], self.num_points)
            data = self.data[points][np.newaxis]
            if self.scale is not None:
                data = data * self.scale
            if self.label_data is not None:
                labels = self.label_data[points]
            batch[self.points] = gp.Array(data, self.spec[self.points])

        else:
            # If a Graph is requested, we must select points within the
            # requested ROI

            min_bb = request[self.points].roi.get_begin()
            max_bb = request[self.points].roi.get_end()

            logger.debug("Points source got request for %s",
                         request[self.points].roi)

            point_filter = np.ones((self.data.shape[0], ), dtype=bool)
            for d in range(self.ndims):
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] >= min_bb[d])
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] < max_bb[d])

            points_data, labels = self._get_points(point_filter)
            logger.debug(f"Found {len(points_data)} points")
            points_spec = gp.GraphSpec(roi=request[self.points].roi.copy())
            batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

        # Labels will always be an Array
        if self.label_data is not None:
            batch[self.labels] = gp.Array(labels, self.spec[self.labels])

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #18
    def provide(self, request: BatchRequest) -> Batch:
        """
        First request points with specific seeds, then request the rest if
        valid points are found.
        """

        logger.debug(f"growing request by {self._get_growth()}")

        has_component = False
        while not has_component:

            base_seed, add_seed, direction, prepare_profiling_stats = self.get_valid_seeds(
                request)

            timing_prepare = Timing(self, "prepare")
            timing_prepare.start()

            request_base = self.prepare(request, base_seed, -direction)
            if add_seed is not None:
                request_add = self.prepare(request, add_seed, direction)
            else:
                request_add = None
                logger.debug(f"No add_request needed!")

            timing_prepare.stop()

            base = self.upstream_provider.request_batch(request_base)
            if request_add is not None:
                add = self.upstream_provider.request_batch(request_add)
            else:
                add = self._empty_copy(base)

            has_component = True

            timing_process = Timing(self, "process")
            timing_process.start()

            base = self.process(base, Coordinate([0, 0, 0]), request=request)
            add = self.process(add, -Coordinate([0, 0, 0]), request=request)

            batch = self.merge_batches(base, add)

            timing_process.stop()
            batch.profiling_stats.merge_with(prepare_profiling_stats)
            batch.profiling_stats.add(timing_prepare)
            batch.profiling_stats.add(timing_process)

        return batch
Example #19
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            # if pre and postsynaptic locations required, their id
            # SynapseLocation dictionaries should be created together s.t. ids
            # are unique and allow to find partner locations
            if PointsKeys.PRESYN in request.points_specs or PointsKeys.POSTSYN in request.points_specs:
                assert request.points_specs[
                    PointsKeys.PRESYN].roi == request.points_specs[
                        PointsKeys.POSTSYN].roi
                # Cremi specific, ROI offset corresponds to offset present in the
                # synapse location relative to the raw data.
                dataset_offset = self.spec[PointsKeys.PRESYN].roi.get_offset()
                presyn_points, postsyn_points = self.__get_syn_points(
                    roi=request.points_specs[PointsKeys.PRESYN].roi,
                    syn_file=hdf_file,
                    dataset_offset=dataset_offset)

            for (points_key, request_spec) in request.points_specs.items():

                logger.debug("Reading %s in %s...", points_key,
                             request_spec.roi)
                id_to_point = {
                    PointsKeys.PRESYN: presyn_points,
                    PointsKeys.POSTSYN: postsyn_points
                }[points_key]

                points_spec = self.spec[points_key].copy()
                points_spec.roi = request_spec.roi
                batch.points[points_key] = Points(data=id_to_point,
                                                  spec=points_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #20
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for (array_key, request_spec) in request.array_specs.items():

            logger.debug("Reading %s in %s...", array_key, request_spec.roi)

            voxel_size = self.spec[array_key].voxel_size

            # scale request roi to voxel units
            dataset_roi = request_spec.roi / voxel_size

            # shift request roi into dataset
            dataset_roi = dataset_roi - self.spec[
                array_key].roi.get_offset() / voxel_size

            # create array spec
            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi

            # read the data
            if array_key in self.datasets:
                data = self.__read_array(self.datasets[array_key], dataset_roi)
            elif array_key in self.masks:
                data = self.__read_mask(self.masks[array_key], dataset_roi)
            else:
                assert False, (
                    "Encountered a request for %s that is neither a volume "
                    "nor a mask." % array_key)

            # add array to batch
            batch.arrays[array_key] = Array(data, array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #21
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()

        with h5py.File(self.filename, 'r') as f:

            for (volume_type, roi) in request.volumes.items():

                if volume_type not in spec.volumes:
                    raise RuntimeError(
                        "Asked for %s which this source does not provide" %
                        volume_type)

                if not spec.volumes[volume_type].contains(roi):
                    raise RuntimeError(
                        "%s's ROI %s outside of my ROI %s" %
                        (volume_type, roi, spec.volumes[volume_type]))

                interpolate = {
                    VolumeType.RAW: True,
                    VolumeType.GT_LABELS: False,
                    VolumeType.GT_MASK: False,
                    VolumeType.ALPHA_MASK: True,
                }[volume_type]

                logger.debug("Reading %s in %s..." % (volume_type, roi))
                batch.volumes[volume_type] = Volume(
                    self.__read(f, self.datasets[volume_type], roi),
                    roi=roi,
                    resolution=self.resolutions[volume_type],
                    interpolate=interpolate)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #22
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        spec = self.get_spec()

        batch = Batch()
        logger.debug("providing batch with resolution of {}".format(
            self.resolution))

        for (volume_type, roi) in request.volumes.items():

            if volume_type not in spec.volumes:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    volume_type)

            if not spec.volumes[volume_type].contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (volume_type, roi, spec.volumes[volume_type]))

            read, interpolate = {
                VolumeType.RAW: (self.__read_raw, True),
                VolumeType.GT_LABELS: (self.__read_gt, False),
                VolumeType.GT_MASK: (self.__read_gt_mask, False),
            }[volume_type]

            logger.debug("Reading %s in %s..." % (volume_type, roi))
            batch.volumes[volume_type] = Volume(
                read(roi),
                roi=roi,
                # TODO: get resolution from repository
                resolution=self.resolution,
                interpolate=interpolate)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #23
    def provide(self, request):

        report_next_timeout = 10
        num_rejected = 0

        timing = Timing(self)
        timing.start()

        assert self.mask in request, (
            "Reject can only be used if a GT mask is requested")

        have_good_batch = False
        while not have_good_batch:

            batch = self.upstream_provider.request_batch(request)
            mask_ratio = batch.arrays[self.mask].data.mean()
            have_good_batch = mask_ratio > self.min_masked

            if not have_good_batch and self.reject_probability < 1.:
                have_good_batch = random.random() > self.reject_probability

            if not have_good_batch:

                logger.debug("reject batch with mask ratio %f at %s",
                             mask_ratio, batch.arrays[self.mask].spec.roi)
                num_rejected += 1

                if timing.elapsed() > report_next_timeout:

                    logger.warning(
                        "rejected %d batches, been waiting for a good one "
                        "since %ds", num_rejected, report_next_timeout)
                    report_next_timeout *= 2

            else:

                logger.debug("accepted batch with mask ratio %f at %s",
                             mask_ratio, batch.arrays[self.mask].spec.roi)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
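A rejection node like this one sits directly after random sampling, so bad crops are redrawn rather than trained on. A sketch of typical placement, assuming gunpowder (the file name and keys are illustrative):

    raw = ArrayKey("RAW")
    mask = ArrayKey("GT_MASK")

    pipeline = (
        Hdf5Source("data.h5", datasets={raw: "raw", mask: "mask"}) +
        RandomLocation() +
        Reject(mask=mask, min_masked=0.25, reject_probability=0.95)
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)  # request built elsewhere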
Example #24
    def provide(self, request: BatchRequest) -> Batch:
        random.seed(request.random_seed)
        np.random.seed(request.random_seed)

        timing = Timing(self, "provide")
        timing.start()

        batch = Batch()

        roi = request[self.points].roi

        region_shape = roi.get_shape()

        trees = []
        for _ in range(self.n_obj):
            for _ in range(100):
                root = np.random.random(len(region_shape)) * region_shape
                tree = self._grow_tree(
                    root, Roi((0,) * len(region_shape), region_shape)
                )
                if self.num_nodes[0] <= len(tree.nodes) <= self.num_nodes[1]:
                    break
            trees.append(tree)

        # logger.info("{} trees got, expected {}".format(len(trees), self.n_obj))

        trees_graph = nx.disjoint_union_all(trees)

        points = {
            node_id: Node(np.floor(node_attrs["pos"]) + roi.get_begin())
            for node_id, node_attrs in trees_graph.nodes.items()
        }

        batch[self.points] = Graph(
            points, request[self.points], list(trees_graph.edges))

        timing.stop()
        batch.profiling_stats.add(timing)

        # self._plot_tree(tree)

        return batch
Example #25
    def provide(self, request):

        # operate on a copy of the request, to provide the original request to
        # 'process' for convenience
        upstream_request = copy.deepcopy(request)

        timing = Timing(self)

        timing.start()
        self.prepare(upstream_request)
        timing.stop()

        batch = self.get_upstream_provider().request_batch(upstream_request)

        timing.start()
        self.process(batch, request)
        timing.stop()

        batch.profiling_stats.add(timing)

        return batch
Example #26
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        # update recent requests
        self.last_5.popleft()
        self.last_5.append(request)

        if request != self.current_request:

            current_count = sum([
                recent_request == self.current_request
                for recent_request in self.last_5
            ])
            new_count = sum(
                [recent_request == request for recent_request in self.last_5])
            if new_count > current_count or self.current_request is None:

                if self.workers is not None:
                    logger.info(
                        "new request received, stopping current workers...")
                    self.workers.stop()

                self.current_request = copy.deepcopy(request)

                logger.info(
                    "starting new set of workers (%s, cache size %s)...",
                    self.num_workers, self.cache_size)
                self.workers = ProducerPool(
                    [
                        lambda i=i: self.__run_worker(i)
                        for i in range(self.num_workers)
                    ],
                    queue_size=self.cache_size,
                )
                self.workers.start()

                logger.debug("getting batch from queue...")
                batch = self.workers.get()

                timing.stop()
                batch.profiling_stats.add(timing)

            else:
                logger.debug("Resolving new request sequentially")
                batch = self.get_upstream_provider().request_batch(request)

                timing.stop()
                batch.profiling_stats.add(timing)

        else:
            logger.debug("getting batch from queue...")
            batch = self.workers.get()

            timing.stop()
            batch.profiling_stats.add(timing)

        return batch
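The request-comparison bookkeeping above exists because caching only pays off while the request stays constant; every genuinely new request tears the worker pool down and restarts it. A sketch of where such a node is placed (assuming gunpowder; the parameters are illustrative):

    pipeline = (
        source +
        RandomLocation() +
        PreCache(cache_size=40, num_workers=10)  # keep 40 batches ready
    )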
Example #27
    def provide(self, request):
        output = gp.Batch()

        timing_provide = Timing(self, "provide")
        timing_provide.start()

        spec = self.array_spec.copy()
        spec.roi = request[self.key].roi

        data = self.array[spec.roi]
        if "c" not in self.array.axes:
            # add a channel dimension
            data = np.expand_dims(data, 0)
        if np.any(np.isnan(data)):
            raise ValueError("INPUT DATA CAN'T BE NAN")
        output[self.key] = gp.Array(data, spec=spec)

        timing_provide.stop()

        output.profiling_stats.add(timing_provide)

        return output
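The channel handling above in isolation: spatial (z, y, x) data is promoted to (c, z, y, x) so downstream nodes can rely on a channel axis, and NaNs are rejected early rather than deep inside training. A plain-numpy sketch:

    import numpy as np

    data = np.random.random((10, 10, 10)).astype(np.float32)

    if data.ndim == 3:                  # no channel axis yet
        data = np.expand_dims(data, 0)  # -> (1, 10, 10, 10)

    assert not np.any(np.isnan(data)), "input data must not contain NaN"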
Example #28
    def provide(self, request):

        report_next_timeout = 10
        num_rejected = 0

        timing = Timing(self)
        timing.start()

        assert self.mask_volume_type in request.volumes, "Reject can only be used if a GT mask is requested"

        have_good_batch = False
        while not have_good_batch:

            batch = self.upstream_provider.request_batch(request)
            mask_ratio = batch.volumes[self.mask_volume_type].data.mean()
            have_good_batch = mask_ratio >= self.min_masked

            if not have_good_batch:

                logger.debug("reject batch with mask ratio %f at " %
                             mask_ratio +
                             str(batch.volumes[self.mask_volume_type].roi))
                num_rejected += 1

                if timing.elapsed() > report_next_timeout:

                    logger.warning(
                        "rejected %d batches, been waiting for a good one "
                        "since %ds", num_rejected, report_next_timeout)
                    report_next_timeout *= 2

        logger.debug("good batch with mask ratio %f found at " % mask_ratio +
                     str(batch.volumes[self.mask_volume_type].roi))

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #29
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:

            for (array_key, request_spec) in request.array_specs.items():

                logger.debug("Reading %s in %s...", array_key,
                             request_spec.roi)

                voxel_size = self.spec[array_key].voxel_size

                # scale request roi to voxel units
                dataset_roi = request_spec.roi / voxel_size

                # shift request roi into dataset
                dataset_roi = dataset_roi - self.spec[
                    array_key].roi.get_offset() / voxel_size

                # create array spec
                array_spec = self.spec[array_key].copy()
                array_spec.roi = request_spec.roi

                # add array to batch
                batch.arrays[array_key] = Array(
                    self.__read(hdf_file, self.datasets[array_key],
                                dataset_roi), array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #30
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        cv = CloudVolume(self.cloudvolume_url, use_https=True, mip=self.mip)

        request_spec = request.array_specs[self.array_key]
        array_key = self.array_key
        logger.debug("Reading %s in %s...", array_key, request_spec.roi)

        voxel_size = self.array_spec.voxel_size

        # scale request roi to voxel units
        dataset_roi = request_spec.roi / voxel_size

        # shift request roi into dataset
        dataset_roi = dataset_roi - self.spec[
            array_key].roi.get_offset() / voxel_size

        # create array spec
        array_spec = self.array_spec.copy()
        array_spec.roi = request_spec.roi

        # add array to batch
        batch.arrays[array_key] = Array(
            self.__read(cv, dataset_roi),
            array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
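For reference, the __read above boils down to slicing the CloudVolume in voxel space at the chosen mip; a sketch (the URL is illustrative):

    from cloudvolume import CloudVolume

    cv = CloudVolume("precomputed://gs://bucket/dataset",
                     use_https=True, mip=0)
    begin, end = (0, 0, 0), (64, 64, 64)
    cutout = cv[begin[0]:end[0], begin[1]:end[1], begin[2]:end[2]]
    # cutout is indexed (x, y, z, channels)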