示例#1
0
 def run_for_partition(self, partition: Partition, roi, corrections):
     """
     Run all UDFs of this task on a single partition.

     The whole run happens inside ``set_num_threads(1)``. If the list of
     cupy UDFs is non-empty, the CUDA device returned by
     ``get_use_cuda()`` is activated for the run and the device that was
     active before is re-activated afterwards, also when processing
     raises.

     Returns a tuple containing ``udf.results`` for each UDF, in the order
     of ``self._udfs``.
     """
     with set_num_threads(1):
         # Stays None unless the CUDA device is actually switched below
         restore_device_id = None
         try:
             device_class = get_device_class()
             # Both lists contain references to the objects in self._udfs
             numpy_udfs, cupy_udfs = self._udf_lists(device_class)
             # cupy_udfs will only be populated if actually on a CUDA
             # worker and any UDF supports 'cupy' (and not 'cuda')
             if cupy_udfs:
                 # Import lazily so cupy is only required when it is used
                 import cupy
                 target_device = get_use_cuda()
                 restore_device_id = cupy.cuda.Device().id
                 cupy.cuda.Device(target_device).use()
             _meta, tiling_scheme, dtype = self._init_udfs(
                 numpy_udfs, cupy_udfs, partition, roi, corrections,
                 device_class,
             )
             partition.set_corrections(corrections)
             self._run_udfs(
                 numpy_udfs, cupy_udfs, partition, tiling_scheme, roi, dtype,
             )
             self._wrapup_udfs(numpy_udfs, cupy_udfs, partition)
         finally:
             # cupy is guaranteed to be bound here: the id is only set
             # after the import above succeeded
             if restore_device_id is not None:
                 cupy.cuda.Device(restore_device_id).use()
         # Make sure results are in the same order as the UDFs
         per_udf_results = [udf.results for udf in self._udfs]
         return tuple(per_udf_results)
示例#2
0
 def enter(self):
     """
     Yield this object while the worker thread limit is in effect.

     Note: this is written as a generator for use with the
     @contextmanager decorator, because with separate `__enter__` and
     `__exit__` methods we can't easily delegate to `set_num_threads`,
     or to other context managers that may come later.
     """
     thread_limit = set_num_threads(self._threads_per_worker)
     with thread_limit:
         yield self
示例#3
0
    def __call__(self):
        """
        Apply the masks to every tile of the partition and sum up the results.

        For each tile returned by ``self.partition.get_tiles``, the tile's
        ``flat_data`` is multiplied with the masks obtained from
        ``self.masks.get`` and the reshaped product is added into a buffer of
        shape ``(number of masks,) + nav shape of the partition``.

        The product is computed with ``sparse.dot`` for ``sparse.SparseArray``
        masks, with ``*`` for scipy.sparse masks, with ``torch.mm`` if
        ``self.use_torch`` is set and torch can be imported, and with
        ``ndarray.dot`` otherwise.

        Returns
        -------
        list
            A single-element list holding a ``MaskResultTile`` with the
            accumulated buffer and the nav-only slice of the partition.
        """
        num_masks = len(self.masks)
        part = zeros_aligned((num_masks,) + tuple(self.partition.shape.nav), dtype=self.dtype)

        tiling_scheme = self.tiling_scheme
        if tiling_scheme is None:
            # FIXME: tileshape negotiation!
            # Only build the fallback tile shape (one frame deep, full
            # signal shape) when no tiling scheme was provided.
            shape = self.partition.shape
            tileshape = Shape(
                (1,) + tuple(shape.sig),
                sig_dims=shape.sig.dims
            )
            tiling_scheme = TilingScheme.make_for_shape(
                tileshape=tileshape,
                dataset_shape=shape,  # ...
            )

        tiles = self.partition.get_tiles(
            tiling_scheme=tiling_scheme,
            dest_dtype=self.read_dtype
        )

        with set_num_threads(1):
            try:
                import torch
            except ImportError:
                torch = None
            for data_tile in tiles:
                flat_data = data_tile.flat_data
                masks = self.masks.get(data_tile, self.mask_dtype)
                if isinstance(masks, sparse.SparseArray):
                    result = sparse.dot(flat_data, masks)
                elif scipy.sparse.issparse(masks):
                    # This is scipy.sparse using the old matrix interface
                    # where "*" is the dot product
                    result = flat_data * masks
                elif self.use_torch and torch is not None:
                    result = torch.mm(
                        torch.from_numpy(flat_data),
                        torch.from_numpy(masks),
                    ).numpy()
                else:
                    # Also reached when use_torch is set but torch could not
                    # be imported here: fall back to numpy instead of failing
                    # with an AttributeError on None.
                    result = flat_data.dot(masks)
                dest_slice = data_tile.tile_slice.shift(self.partition.slice)
                reshaped = self.reshaped_data(data=result, dest_slice=dest_slice)
                # Ellipsis to match the "number of masks" part of the result
                part[(...,) + dest_slice.get(nav_only=True)] += reshaped
            return [
                MaskResultTile(
                    data=part,
                    dest_slice=self.partition.slice.get(nav_only=True),
                )
            ]
示例#4
0
    def run_for_partition(self, partition, roi):
        """
        Run the UDF of this task on a single partition.

        The UDF is set up (meta, result buffers, per-partition buffers, task
        data, optional ``preprocess``), then fed with data according to the
        value of ``self._udf.get_method()``:

        - ``'tile'``: ``process_tile`` is called once per tile
        - ``'frame'``: ``process_frame`` is called once per frame of each tile
        - ``'partition'``: ``process_partition`` is called once with the
          macrotile of the partition

        Afterwards the optional ``postprocess`` and ``cleanup`` are run.

        Returns
        -------
        tuple
            ``(self._udf.results, partition)``

        Raises
        ------
        ValueError
            If ``get_method()`` returns an unsupported processing method.
        TypeError
            In debug mode, if pickling the partition or the results raises a
            ``TypeError``.
        """
        with set_num_threads(1):
            dtype = self._get_dtype(partition.dtype)
            meta = UDFMeta(
                partition_shape=partition.slice.adjust_for_roi(roi).shape,
                dataset_shape=partition.meta.shape,
                roi=roi,
                dataset_dtype=partition.dtype,
                input_dtype=dtype,
            )
            self._udf.set_meta(meta)
            self._udf.init_result_buffers()
            self._udf.allocate_for_part(partition, roi)
            self._udf.init_task_data()
            if hasattr(self._udf, 'preprocess'):
                self._udf.clear_views()
                self._udf.preprocess()
            method = self._udf.get_method()
            if method == 'tile':
                tiles = partition.get_tiles(full_frames=False,
                                            roi=roi,
                                            dest_dtype=dtype,
                                            mmap=True)
            elif method == 'frame':
                tiles = partition.get_tiles(full_frames=True,
                                            roi=roi,
                                            dest_dtype=dtype,
                                            mmap=True)
            elif method == 'partition':
                tiles = [
                    partition.get_macrotile(roi=roi,
                                            dest_dtype=dtype,
                                            mmap=True)
                ]
            else:
                # Without this branch `tiles` would be unbound and the loop
                # below would fail with a confusing UnboundLocalError.
                raise ValueError(
                    "unsupported UDF processing method: %r" % (method,)
                )

            for tile in tiles:
                if method == 'tile':
                    self._udf.set_views_for_tile(partition, tile)
                    self._udf.set_slice(tile.tile_slice)
                    self._udf.process_tile(tile.data)
                elif method == 'frame':
                    tile_slice = tile.tile_slice
                    for frame_idx, frame in enumerate(tile.data):
                        # Slice of depth 1 for this frame: the first origin
                        # coordinate is offset by the frame index, the rest
                        # is taken over from the tile slice.
                        frame_slice = Slice(
                            origin=(tile_slice.origin[0] + frame_idx, ) +
                            tile_slice.origin[1:],
                            shape=Shape((1, ) + tuple(tile_slice.shape)[1:],
                                        sig_dims=tile_slice.shape.sig.dims),
                        )
                        self._udf.set_slice(frame_slice)
                        self._udf.set_views_for_frame(partition, tile,
                                                      frame_idx)
                        self._udf.process_frame(frame)
                elif method == 'partition':
                    self._udf.set_views_for_tile(partition, tile)
                    self._udf.set_slice(partition.slice)
                    self._udf.process_partition(tile.data)

            if hasattr(self._udf, 'postprocess'):
                self._udf.clear_views()
                self._udf.postprocess()

            self._udf.cleanup()
            self._udf.clear_views()

            if self._debug:
                # Round-trip through cloudpickle to detect unpicklable
                # partitions or results early; keep the original error as
                # the explicit cause.
                try:
                    cloudpickle.loads(cloudpickle.dumps(partition))
                except TypeError as exc:
                    raise TypeError("could not pickle partition") from exc
                try:
                    cloudpickle.loads(cloudpickle.dumps(self._udf.results))
                except TypeError as exc:
                    raise TypeError("could not pickle results") from exc

            return self._udf.results, partition
示例#5
0
    def run_for_partition(self, partition: Partition, roi):
        """
        Run all UDFs of this task on a single partition.

        Each UDF first receives a meta object without tiling scheme and is
        initialized (result buffers, per-partition buffers, task data,
        optional ``preprocess``). A tiling scheme is then obtained from a
        ``Negotiator`` and a second meta object that carries it is set on
        every UDF. Every tile read from the partition is passed to every
        UDF, dispatched on ``udf.get_method()`` (``'tile'``, ``'frame'`` or
        ``'partition'``). Finally each UDF is flushed, optionally
        post-processed and cleaned up.

        Returns a tuple containing ``udf.results`` for each UDF, in the
        order of ``self._udfs``. In debug mode, a ``TypeError`` is raised if
        pickling the partition or the results raises a ``TypeError``.
        """
        with set_num_threads(1):
            dtype = self._get_dtype(partition.dtype)

            def build_meta(scheme):
                # The two meta objects only differ in the tiling scheme
                return UDFMeta(
                    partition_shape=partition.slice.adjust_for_roi(roi).shape,
                    dataset_shape=partition.meta.shape,
                    roi=roi,
                    dataset_dtype=partition.dtype,
                    input_dtype=dtype,
                    tiling_scheme=scheme,
                )

            initial_meta = build_meta(None)
            all_udfs = self._udfs
            for udf in all_udfs:
                udf.set_meta(initial_meta)
                udf.init_result_buffers()
                udf.allocate_for_part(partition, roi)
                udf.init_task_data()
                if hasattr(udf, 'preprocess'):
                    udf.clear_views()
                    udf.preprocess()

            tiling_scheme = Negotiator().get_scheme(
                udfs=all_udfs,
                partition=partition,
                read_dtype=dtype,
                roi=roi,
            )

            # FIXME: don't fully re-create?
            negotiated_meta = build_meta(tiling_scheme)
            for udf in all_udfs:
                udf.set_meta(negotiated_meta)

            tiles = partition.get_tiles(
                tiling_scheme=tiling_scheme,
                roi=roi,
                dest_dtype=dtype,
            )

            for tile in tiles:
                for udf in all_udfs:
                    method = udf.get_method()
                    if method == 'tile':
                        udf.set_contiguous_views_for_tile(partition, tile)
                        udf.set_slice(tile.tile_slice)
                        udf.process_tile(tile)
                    elif method == 'frame':
                        tile_slice = tile.tile_slice
                        for frame_idx, frame in enumerate(tile):
                            # Depth-1 slice for this frame: first origin
                            # coordinate offset by the frame index, the
                            # rest taken over from the tile slice
                            frame_origin = (
                                (tile_slice.origin[0] + frame_idx, )
                                + tile_slice.origin[1:]
                            )
                            frame_shape = Shape(
                                (1, ) + tuple(tile_slice.shape)[1:],
                                sig_dims=tile_slice.shape.sig.dims,
                            )
                            udf.set_slice(
                                Slice(origin=frame_origin, shape=frame_shape)
                            )
                            udf.set_views_for_frame(partition, tile, frame_idx)
                            udf.process_frame(frame)
                    elif method == 'partition':
                        udf.set_views_for_tile(partition, tile)
                        udf.set_slice(partition.slice)
                        udf.process_partition(tile)

            for udf in all_udfs:
                udf.flush()
                if hasattr(udf, 'postprocess'):
                    udf.clear_views()
                    udf.postprocess()

                udf.cleanup()
                udf.clear_views()

            if self._debug:
                # Round-trip through cloudpickle to detect unpicklable
                # partitions or results early
                try:
                    cloudpickle.loads(cloudpickle.dumps(partition))
                except TypeError:
                    raise TypeError("could not pickle partition")
                try:
                    collected = [u.results for u in all_udfs]
                    cloudpickle.loads(cloudpickle.dumps(collected))
                except TypeError:
                    raise TypeError("could not pickle results")

            return tuple(udf.results for udf in all_udfs)