Пример #1
0
    def setup(self):
        """Cache upstream data used to choose random locations.

        * If ``self.mask`` is set and ``self.min_masked > 0``: fetch the whole
          mask from upstream, binarize it (``> 0``) and store its integral
          image in ``self.mask_integral``.
        * If ``self.ensure_nonempty`` is set: fetch all of its points and index
          their locations in a ``KDTree`` stored in ``self.points``.
        * Finally, set the ROI shape of every provided spec to ``None`` and
          announce the updated specs via ``self.updates``.
        """
        upstream = self.get_upstream_provider()
        self.upstream_spec = upstream.spec

        build_mask_integral = self.mask and self.min_masked > 0
        if build_mask_integral:
            assert self.mask in self.upstream_spec, (
                "Upstream provider does not have %s" % self.mask)
            self.mask_spec = self.upstream_spec.array_specs[self.mask]

            logger.info("requesting complete mask...")
            mask_batch = upstream.request_batch(
                BatchRequest({self.mask: self.mask_spec}))

            logger.info("allocating mask integral array...")
            mask_data = mask_batch.arrays[self.mask].data
            logger.debug("mask size is %s", mask_data.size)

            # narrowest unsigned dtype whose range covers the mask's voxel count
            if mask_data.size < 2**16:
                integral_dtype = np.uint16
            elif mask_data.size < 2**32:
                integral_dtype = np.uint32
            else:
                integral_dtype = np.uint64
            logger.debug("chose %s as integral array dtype",
                         integral_dtype)

            self.mask_integral = np.array(mask_data > 0,
                                          dtype=integral_dtype)
            self.mask_integral = integral_image(self.mask_integral)

        if self.ensure_nonempty:
            points_key = self.ensure_nonempty
            assert points_key in self.upstream_spec, (
                "Upstream provider does not have %s" % points_key)
            all_points_spec = self.upstream_spec.points_specs[points_key]

            logger.info("requesting all %s points...", points_key)
            all_points_batch = upstream.request_batch(
                BatchRequest({points_key: all_points_spec}))

            locations = [
                point.location
                for point in all_points_batch[points_key].data.values()
            ]
            self.points = KDTree(locations)
            logger.info("retrieved %d points", len(self.points.data))

        # downstream may request anywhere (offsets are ignored), so drop the
        # bounding box of every provided array and point set
        for provided_key, provided_spec in self.spec.items():
            provided_spec.roi.set_shape(None)
            self.updates(provided_key, provided_spec)
Пример #2
0
    def prepare(self, request):
        """Request the labels (and optional masks) needed to compute affinities.

        The labels are requested with a copy of the affinities' spec whose
        dtype is cleared and whose ROI is
        ``affinities_roi.grow(-self.padding_neg, self.padding_pos)``.
        ``self.labels_mask`` and ``self.unlabelled`` (if set) are requested
        with copies of the labels' spec.

        Args:
            request: downstream request; must contain ``self.affinities``.

        Returns:
            ``BatchRequest`` with the upstream dependencies.

        Raises:
            AssertionError: if the labels and a mask are both requested
                downstream with different ROIs.
        """
        # The consistency checks only make sense when both arrays were
        # requested downstream; indexing a key that is not part of the request
        # would raise instead of checking anything.
        labels_requested = self.labels in request

        if self.labels_mask and labels_requested and self.labels_mask in request:
            assert (
                request[self.labels].roi == request[self.labels_mask].roi
            ), ("requested GT label roi %s and GT label mask roi %s are not "
                "the same." %
                (request[self.labels].roi, request[self.labels_mask].roi))

        if self.unlabelled and labels_requested and self.unlabelled in request:
            assert (
                request[self.labels].roi == request[self.unlabelled].roi
            ), ("requested GT label roi %s and GT unlabelled mask roi %s are not "
                "the same." %
                (request[self.labels].roi, request[self.unlabelled].roi))

        deps = BatchRequest()

        # grow labels ROI to accommodate padding
        labels_roi = request[self.affinities].roi.grow(-self.padding_neg,
                                                       self.padding_pos)
        deps[self.labels] = request[self.affinities].copy()
        # the affinities' dtype does not describe the labels
        deps[self.labels].dtype = None
        deps[self.labels].roi = labels_roi

        if self.labels_mask:
            deps[self.labels_mask] = deps[self.labels].copy()
        if self.unlabelled:
            deps[self.unlabelled] = deps[self.labels].copy()

        return deps
Пример #3
0
    def setup(self):
        """Check that upstream is bounded and prepare the mask integral volume.

        Stores the total upstream ROI in ``self.roi``. When
        ``self.min_masked > 0``, the complete ``self.mask_volume_type`` volume
        is fetched from upstream, binarized (``> 0``) and converted into an
        integral image stored in ``self.mask_integral``.

        Raises:
            RuntimeError: if the upstream spec has no total ROI.
        """
        self.roi = self.get_spec().get_total_roi()
        if self.roi is None:
            raise RuntimeError("Can not draw random samples from a provider that does not have a bounding box.")

        if not self.min_masked > 0:
            return

        mask_type = self.mask_volume_type
        assert mask_type in self.get_spec().volumes, "Upstream provider does not have %s"%mask_type
        self.mask_roi = self.get_spec().volumes[mask_type]

        logger.info("requesting complete mask...")

        mask_request = BatchRequest({mask_type: self.mask_roi})
        upstream = self.get_upstream_provider()
        mask_batch = upstream.request_batch(mask_request)

        logger.info("allocating mask integral volume...")

        mask_data = mask_batch.volumes[mask_type].data
        logger.debug("mask size is " + str(mask_data.size))
        # narrowest unsigned dtype whose range covers the voxel count
        if mask_data.size < 2**16:
            mask_integral_dtype = np.uint16
        elif mask_data.size < 2**32:
            mask_integral_dtype = np.uint32
        else:
            mask_integral_dtype = np.uint64
        logger.debug("chose %s as integral volume dtype"%mask_integral_dtype)

        self.mask_integral = np.array(mask_data>0, dtype=mask_integral_dtype)
        self.mask_integral = integral_image(self.mask_integral)
Пример #4
0
    def __init__(self,
                 output_dir='snapshots',
                 output_filename='{id}.hdf',
                 every=1,
                 additional_request=None):
        '''Store the snapshot configuration.

        Args:

            output_dir (string):

                Directory to save the snapshots in. Will be created if it
                does not exist.

            output_filename (string):

                Template for output file names. ``{id}`` is replaced with the
                ID of the batch, ``{iteration}`` with the training iteration
                (if training was performed on this batch).

            every (int):

                Save every ``every``-th batch: ``every=1`` (the default)
                stores every batch, ``every=2`` every second one, and so on.
                Values below 1 are treated as 1.

            additional_request (``BatchRequest``, optional):

                Extra request merged with the passing request when a snapshot
                is made. If not given, only the volumes that are in the batch
                anyway are recorded.
        '''
        self.output_dir = output_dir
        self.output_filename = output_filename
        self.every = max(1, every)
        if additional_request is None:
            additional_request = BatchRequest()
        self.additional_request = additional_request
        # number of batches seen so far
        self.n = 0
Пример #5
0
 def prepare(self, request):
     """Seed NumPy's global RNG from the request and forward ``self.array``.

     Args:
         request: downstream request; must contain ``self.array``.

     Returns:
         ``BatchRequest`` holding a copy of the spec requested for
         ``self.array``.
     """
     # TODO: move all randomness into the prepare method
     # TODO: write a test for this node
     np.random.seed(request.random_seed)
     dependencies = BatchRequest()
     array_spec = request[self.array].copy()
     dependencies[self.array] = array_spec
     return dependencies
Пример #6
0
    def prepare(self, request):
        """Request labels and masks with the spec requested for ``self.scales``.

        Each dependency receives its own copy of the scales spec. Sharing one
        spec object between several keys would let an upstream node that edits
        one entry in place (e.g. narrowing its ROI) silently change the others.
        The dtype is cleared because the scales' dtype does not describe labels
        or masks.

        Args:
            request: downstream request; must contain ``self.scales``.

        Returns:
            ``BatchRequest`` for ``self.labels`` and every key in ``self.masks``.
        """
        deps = BatchRequest()
        for key in [self.labels] + list(self.masks):
            spec = request[self.scales].copy()
            spec.dtype = None
            deps[key] = spec
        return deps
Пример #7
0
    def provide(self, request):
        """Split ``request`` by provider, query each one and merge the batches.

        Keys are grouped via ``self.key_to_provider``. Each provider gets a
        single ``BatchRequest`` with its keys and a fresh seed drawn from the
        ``random`` module.

        Args:
            request: downstream request; every key must be present in
                ``self.key_to_provider``.

        Returns:
            ``Batch`` with the arrays and graphs of all upstream batches and
            their merged profiling statistics.
        """
        # create upstream requests
        upstream_requests = {}
        for key, spec in request.items():

            provider = self.key_to_provider[key]
            if provider not in upstream_requests:
                # use new random seeds per upstream request.
                # seeds picked by random should be deterministic since
                # the provided request already has a random seed.
                # random.randint includes both end points, so the upper bound
                # is 2**32 - 1: the largest seed np.random.seed accepts.
                seed = random.randint(0, 2**32 - 1)
                upstream_requests[provider] = BatchRequest(random_seed=seed)

            upstream_requests[provider][key] = spec

        # execute requests, merge batches
        merged_batch = Batch()
        for provider, upstream_request in upstream_requests.items():

            batch = provider.request_batch(upstream_request)
            for key, array in batch.arrays.items():
                merged_batch.arrays[key] = array
            for key, graph in batch.graphs.items():
                merged_batch.graphs[key] = graph
            merged_batch.profiling_stats.merge_with(batch.profiling_stats)

        return merged_batch
Пример #8
0
    def __init__(
        self,
        dataset_names,
        output_dir="snapshots",
        output_filename="{id}.zarr",
        every=1,
        additional_request=None,
        compression_type=None,
        dataset_dtypes=None,
        store_value_range=False,
    ):
        """Store the snapshot configuration.

        Args:
            dataset_names: mapping whose keys are the keys to record
                (presumably key -> dataset name; confirm against callers).
            output_dir: directory for the snapshot files.
            output_filename: file name template (default ``"{id}.zarr"``).
            every: record every ``every``-th batch; values below 1 become 1.
            additional_request: extra ``BatchRequest`` used when recording;
                an empty ``BatchRequest`` if not given.
            compression_type: stored as-is in ``self.compression_type``.
            dataset_dtypes: optional mapping stored in
                ``self.dataset_dtypes``; ``{}`` if not given.
            store_value_range: stored as-is in ``self.store_value_range``.
        """
        self.dataset_names = dataset_names
        self.output_dir = output_dir
        self.output_filename = output_filename
        self.every = max(1, every)
        if additional_request is None:
            additional_request = BatchRequest()
        self.additional_request = additional_request
        # number of batches seen so far
        self.n = 0
        self.compression_type = compression_type
        self.store_value_range = store_value_range
        self.dataset_dtypes = {} if dataset_dtypes is None else dataset_dtypes

        # NOTE(review): presumably the file open mode for writing; confirm
        self.mode = "w"
Пример #9
0
    def prepare(self, request):
        """Request the graph (and optional mask) needed to rasterize ``self.array``.

        The graph is requested over the array's ROI grown on both sides by a
        context of ``ceil(radius)`` (mode ``'ball'``) or ``ceil(2 * radius)``
        (mode ``'peak'``), so that nodes just outside the ROI still contribute.
        The result is restricted to the ROI the graph is provided in. If a mask
        is configured, it is requested over that graph ROI snapped to the
        mask's voxel grid.

        Args:
            request: downstream request; must contain ``self.array``.

        Returns:
            ``BatchRequest`` with the graph (and mask) dependencies.

        Raises:
            RuntimeError: for an unknown rasterization mode.
            AssertionError: if mask and array voxel sizes differ.
        """
        # builtin ``int`` instead of the ``np.int`` alias, which was removed
        # in NumPy 1.24 (it was always identical to ``int``)
        if self.settings.mode == 'ball':
            context = np.ceil(self.settings.radius).astype(int)
        elif self.settings.mode == 'peak':
            context = np.ceil(2*self.settings.radius).astype(int)
        else:
            raise RuntimeError('unknown raster mode %s'%self.settings.mode)

        # a single radius applies to every dimension; atleast_1d also makes a
        # scalar radius usable with len()/repeat()
        context = np.atleast_1d(context)
        dims = self.array_spec.roi.dims()
        if len(context) == 1:
            context = context.repeat(dims)

        # request graph in a larger area to get rasterization from outside
        # graph
        graph_roi = request[self.array].roi.grow(
                Coordinate(context),
                Coordinate(context))

        # however, restrict the request to the graph actually provided
        graph_roi = graph_roi.intersect(self.spec[self.graph].roi)

        deps = BatchRequest()
        deps[self.graph] = GraphSpec(roi=graph_roi)

        if self.settings.mask is not None:

            mask_voxel_size = self.spec[self.settings.mask].voxel_size
            assert self.spec[self.array].voxel_size == mask_voxel_size, (
                "Voxel size of mask and rasterized volume need to be equal")

            new_mask_roi = graph_roi.snap_to_grid(mask_voxel_size)
            deps[self.settings.mask] = ArraySpec(roi=new_mask_roi)

        return deps
Пример #10
0
    def prepare(self, request):
        """Crop the request for ``self.key`` to the ROI provided upstream.

        ``request`` is modified in place: the ROI of ``self.key`` is
        intersected with the upstream ROI. If nothing is left, an empty ROI at
        the upstream offset is requested instead and a warning is logged.

        Args:
            request: downstream request.

        Returns:
            ``BatchRequest`` containing only ``self.key``, or ``None`` if
            ``self.key`` is not part of ``request``.
        """
        upstream_spec = self.get_upstream_provider().spec

        # lazy %-arguments: request and spec are only stringified when debug
        # logging is actually enabled
        logger.debug("request: %s", request)
        logger.debug("upstream spec: %s", upstream_spec)

        # TODO: remove this?
        if self.key not in request:
            return

        roi = request[self.key].roi.copy()
        upstream_roi = upstream_spec[self.key].roi

        # change request to fit into upstream spec
        request[self.key].roi = roi.intersect(upstream_roi)

        if request[self.key].roi.empty():

            logger.warning(
                "Requested %s ROI %s lies entirely outside of upstream "
                "ROI %s.", self.key, roi, upstream_roi)

            # ensure a valid request by asking for empty ROI
            request[self.key].roi = Roi(
                upstream_roi.get_offset(),
                (0, ) * upstream_roi.dims())

        logger.debug("new request: %s", request)

        deps = BatchRequest()
        deps[self.key] = request[self.key]
        return deps
Пример #11
0
 def prepare(self, request):
     """Request every input array with the spec requested for its output.

     Inputs and outputs are paired positionally via
     ``zip(self.arrays, self.output_arrays)``. Each input gets a copy of its
     output's spec; if ``self.context`` is set, the ROI is grown by it on
     both sides.
     """
     deps = BatchRequest()
     for source_key, target_key in zip(self.arrays, self.output_arrays):
         upstream_spec = request[target_key].copy()
         if self.context is not None:
             grown_roi = upstream_spec.roi.grow(self.context, self.context)
             upstream_spec.roi = grown_roi
         deps[source_key] = upstream_spec
     return deps
Пример #12
0
    def prepare(self, request):
        """Start the model on first use and request all inputs.

        ``self.start()`` is called once, on the first ``prepare``, unless
        ``self.spawn_subprocess`` is set. Every key in ``self.inputs`` is then
        requested with the spec requested downstream.
        """
        if not (self.initialized or self.spawn_subprocess):
            self.start()
            self.initialized = True

        dependencies = BatchRequest()
        for input_key in self.inputs.values():
            dependencies[input_key] = request[input_key]
        return dependencies
Пример #13
0
 def __init__(self,
              every=1,
              additional_request=None,
              ignore_key=lambda key: False):
     """Store the snapshot configuration.

     Args:
         every: take a snapshot every ``every``-th batch; values below 1
             become 1.
         additional_request: extra ``BatchRequest`` used for snapshots; an
             empty ``BatchRequest`` if not given.
         ignore_key: predicate on keys, stored in ``self.ignore_key``
             (presumably keys it returns True for are skipped; confirm).
             The default ignores nothing.
     """
     self.every = max(1, every)
     if additional_request is None:
         additional_request = BatchRequest()
     self.additional_request = additional_request
     # number of batches seen so far
     self.n = 0
     self.snapshots = {}
     self.ignore_key = ignore_key
Пример #14
0
    def prepare(self, request):
        """Derive input, target and weight requests from the network geometry.

        Without this method, whatever arrays happen to pass through the node
        would be used as inputs/targets. For example, requesting labels of size
        ``input_size`` would make the loss compare network output of size
        ``output_size`` against labels of the wrong size.

        Instead, the network output ROI is centred on the total ROI of the
        requested outputs and sized to ``self.output_size``. The input ROI is
        centred on that and sized to ``self.input_size``. Inputs are requested
        over the input ROI; targets and ``self.weights`` over the output ROI.

        Raises:
            AssertionError: if a size difference is odd (the ROI cannot be
                grown symmetrically) or the output ROI has the wrong shape.
        """
        deps = BatchRequest()

        # total ROI of the network outputs requested downstream
        requested_outputs = BatchRequest()
        for output_key in self.outputs.values():
            if output_key in request:
                requested_outputs[output_key] = request[output_key].copy()
        requested_roi = requested_outputs.get_total_roi()

        # centre the network output on the requested outputs
        padding = self.output_size - requested_roi.get_shape()
        all_even = Coordinate([0] * len(self.output_size))
        assert Coordinate([p % 2 for p in padding]) == all_even
        output_roi = requested_roi.grow(padding // 2, padding // 2)
        assert output_roi.get_shape() == self.output_size

        # centre the network input on the network output
        padding = self.input_size - output_roi.get_shape()
        assert Coordinate([p % 2 for p in padding]) == all_even
        input_roi = output_roi.grow(padding // 2, padding // 2)

        for input_key in self.inputs.values():
            deps[input_key] = ArraySpec(roi=input_roi)

        for target_key in self.targets.values():
            deps[target_key] = ArraySpec(roi=output_roi)
        deps[self.weights] = ArraySpec(roi=output_roi)

        return deps
Пример #15
0
    def prepare(self, request):
        """Request ``self.source`` over the ROI requested for ``self.target``.

        Returns ``None`` when ``self.target`` is not part of ``request``.
        """
        deps = BatchRequest()

        if self.target in request:
            logger.debug("preparing upsampling of " + str(self.source))

            target_roi = request[self.target].roi
            logger.debug("request ROI is %s" % target_roi)

            # add or merge to batch request
            deps[self.source] = ArraySpec(roi=target_roi)
            return deps
Пример #16
0
    def process(self, batch, request):
        """Randomly apply a defect augmentation to each slice of RAW.

        For every index ``c`` along ``self.axis`` one ``random.random()`` draw
        ``r`` selects, via cumulative thresholds, one of: zero-out
        (``r < prob_missing``), contrast reduction, alpha-blended artifact, or
        nothing. ``batch.volumes[VolumeTypes.RAW]`` is modified in place.
        """

        assert batch.get_total_roi().dims()==3, "DefectAugment works on 3D batches only"

        # cumulative probability thresholds for the mutually exclusive cases
        prob_missing_threshold = self.prob_missing
        prob_low_contrast_threshold = prob_missing_threshold + self.prob_low_contrast
        prob_artifact_threshold = prob_low_contrast_threshold + self.prob_artifact

        raw = batch.volumes[VolumeTypes.RAW]

        # NOTE(review): the slice count comes from the total ROI shape, not
        # raw.data.shape; these only agree if ROIs are in voxel units — confirm
        for c in range(batch.get_total_roi().get_shape()[self.axis]):

            r = random.random()

            # selects the single slice c along self.axis, keeping that
            # dimension (slice c:c+1) and everything in the other dimensions
            section_selector = tuple(
                    slice(None if d != self.axis else c, None if d != self.axis else c+1)
                    for d in range(batch.get_total_roi().dims())
            )

            if r < prob_missing_threshold:

                logger.debug("Zero-out " + str(section_selector))
                raw.data[section_selector] = 0

            elif r < prob_low_contrast_threshold:

                logger.debug("Lower contrast " + str(section_selector))
                # if raw.data is a NumPy array, this is a view (basic
                # slicing), so the in-place ops below already modify raw.data
                section = raw.data[section_selector]

                # scale values around the slice mean by self.contrast_scale
                # NOTE(review): in-place float arithmetic assumes a
                # floating-point RAW dtype; integer data would raise — confirm
                mean = section.mean()
                section -= mean
                section *= self.contrast_scale
                section += mean

                raw.data[section_selector] = section

            elif r < prob_artifact_threshold:

                logger.debug("Add artifact " + str(section_selector))
                section = raw.data[section_selector]

                # fetch an artifact (RAW + ALPHA_MASK) of the slice's shape
                artifact_request = BatchRequest()
                artifact_request.add_volume_request(VolumeTypes.RAW, section.shape)
                artifact_request.add_volume_request(VolumeTypes.ALPHA_MASK, section.shape)
                logger.debug("Requesting artifact batch " + str(artifact_request))

                artifact_batch = self.artifact_source.request_batch(artifact_request)
                artifact_alpha = artifact_batch.volumes[VolumeTypes.ALPHA_MASK].data
                artifact_raw   = artifact_batch.volumes[VolumeTypes.RAW].data

                assert artifact_raw.dtype == section.dtype
                assert artifact_alpha.dtype == np.float32
                assert artifact_alpha.min() >= 0.0
                assert artifact_alpha.max() <= 1.0

                # alpha blend: alpha == 1 replaces the slice by the artifact
                raw.data[section_selector] = section*(1.0 - artifact_alpha) + artifact_raw*artifact_alpha
Пример #17
0
    def prepare(self, request):
        """Request the points (and stay-inside arrays) needed for the arrays.

        For every requested array in ``self.array_to_src_trg_points``, the
        source points are requested over the array's ROI and the target points
        over that ROI grown by ``self.pad_for_partners``, so that partners just
        outside the ROI are found. For every requested array in
        ``self.array_keys_to_stayinside_array_keys``, the mapped array is
        requested with a deep copy of the array's spec.
        """
        deps = BatchRequest()

        for array_key, (src_points_key, trg_points_key) in self.array_to_src_trg_points.items():
            if array_key not in request:
                continue
            array_roi = request[array_key].roi
            deps[src_points_key] = PointsSpec(array_roi)
            padded_roi = array_roi.grow(self.pad_for_partners, self.pad_for_partners)
            deps[trg_points_key] = PointsSpec(padded_roi)

        for array_key, stayinside_key in self.array_keys_to_stayinside_array_keys.items():
            if array_key not in request:
                continue
            deps[stayinside_key] = copy.deepcopy(request[array_key])

        return deps
Пример #18
0
    def prepare(self, request):
        """Forward requested dataset keys and, when recording, extra requests.

        Every requested key contained in ``self.dataset_names`` is forwarded
        unchanged. ``self.record_snapshot`` is set when ``self.n`` is a
        multiple of ``self.every``; this method reads ``self.n`` but does not
        change it. When recording, array and graph specs from
        ``self.additional_request`` are added unless already present.
        """
        deps = BatchRequest()
        for key, spec in request.items():
            if key in self.dataset_names:
                deps[key] = spec

        self.record_snapshot = self.n % self.every == 0
        if not self.record_snapshot:
            return deps

        # append additional requests, don't overwrite existing ones
        extra = self.additional_request
        extra_specs = list(extra.array_specs.items()) + list(extra.graph_specs.items())
        for key, spec in extra_specs:
            if key not in deps:
                deps[key] = spec

        return deps
Пример #19
0
 def __init__(self,
              dataset_names,
              output_dir='snapshots',
              output_filename='{id}.hdf',
              every=1,
              additional_request=None,
              compression_type=None,
              dataset_dtypes=None):
     """Store the snapshot configuration.

     Args:
         dataset_names: mapping whose keys are the keys to record
             (presumably key -> dataset name; confirm against callers).
         output_dir: directory for the snapshot files.
         output_filename: file name template (default ``'{id}.hdf'``).
         every: record every ``every``-th batch; values below 1 become 1.
         additional_request: extra ``BatchRequest`` used when recording; an
             empty ``BatchRequest`` if not given.
         compression_type: stored as-is in ``self.compression_type``.
         dataset_dtypes: optional mapping stored in ``self.dataset_dtypes``;
             ``{}`` if not given.
     """
     self.dataset_names = dataset_names
     self.output_dir = output_dir
     self.output_filename = output_filename
     self.every = max(1, every)
     if additional_request is None:
         additional_request = BatchRequest()
     self.additional_request = additional_request
     # number of batches seen so far
     self.n = 0
     self.compression_type = compression_type
     self.dataset_dtypes = {} if dataset_dtypes is None else dataset_dtypes
Пример #20
0
    def provide(self, request):
        """Dispatch the request to the providers owning its keys and merge.

        Keys are grouped via ``self.key_to_provider``; each provider receives
        one ``BatchRequest`` with only its keys. The arrays and points of all
        resulting batches are collected into a single ``Batch``.
        """
        requests_by_provider = {}
        for key, spec in request.items():
            owner = self.key_to_provider[key]
            if owner not in requests_by_provider:
                requests_by_provider[owner] = BatchRequest()
            requests_by_provider[owner][key] = spec

        merged_batch = Batch()
        for owner, owner_request in requests_by_provider.items():
            owner_batch = owner.request_batch(owner_request)
            for array_key, array in owner_batch.arrays.items():
                merged_batch.arrays[array_key] = array
            for points_key, points in owner_batch.points.items():
                merged_batch.points[points_key] = points

        return merged_batch
Пример #21
0
    def provide(self, request):
        """Dispatch the request to the providers owning its keys and merge.

        Keys are grouped via ``self.key_to_provider``; each provider receives
        one ``BatchRequest`` with only its keys. Arrays, graphs and profiling
        statistics of all resulting batches are merged into a single ``Batch``.
        """
        requests_by_provider = {}
        for key, spec in request.items():
            owner = self.key_to_provider[key]
            if owner not in requests_by_provider:
                requests_by_provider[owner] = BatchRequest()
            requests_by_provider[owner][key] = spec

        merged_batch = Batch()
        for owner, owner_request in requests_by_provider.items():
            owner_batch = owner.request_batch(owner_request)
            for array_key, array in owner_batch.arrays.items():
                merged_batch.arrays[array_key] = array
            for graph_key, graph in owner_batch.graphs.items():
                merged_batch.graphs[graph_key] = graph
            merged_batch.profiling_stats.merge_with(owner_batch.profiling_stats)

        return merged_batch
Пример #22
0
    def __select_random_location_with_points(
            self,
            request,
            lcm_shift_roi,
            lcm_voxel_size):
        """Pick a random shift whose shifted request contains at least one point.

        Rejection sampling over the points of ``self.ensure_nonempty``:

        1. choose a point uniformly at random from ``self.points.data``;
        2. compute, in units of ``lcm_voxel_size``, all shifts for which the
           shifted request ROI of ``self.ensure_nonempty`` contains the point,
           and intersect them with ``lcm_shift_roi`` (retry if empty);
        3. draw one of these shifts with ``self.__select_random_location``;
        4. request the points inside the shifted ROI from upstream and accept
           the shift with probability ``1/num_points``, to compensate for the
           bias towards regions with many close-by points.

        The method only returns after accepting a shift, so it loops forever
        if no point admits a shift within ``lcm_shift_roi``.

        Args:
            request: downstream request; must contain ``self.ensure_nonempty``.
            lcm_shift_roi: ROI of admissible shifts, in units of
                ``lcm_voxel_size``.
            lcm_voxel_size: voxel size used to discretize shifts.

        Returns:
            The accepted shift, as returned by ``self.__select_random_location``.
        """

        request_points_roi = request[self.ensure_nonempty].roi

        while True:

            # How to pick shifts that ensure that a randomly chosen point is
            # contained in the request ROI:
            #
            #
            # request          point
            # [---------)      .
            # 0        +10     17
            #
            #         least shifted to contain point
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #
            #                  most shifted to contain point:
            #                  [---------)
            #                  17       +10
            #                  ==
            #                  point-request.begin
            #
            #         all possible shifts
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #                   ==
            #                   request.shape
            #
            # In the most shifted scenario, it could happen that the point lies
            # exactly at the lower boundary (17 in the example). This will cause
            # trouble if later we mirror the batch. The point would end up lying
            # on the other boundary, which is exclusive and thus not part of the
            # ROI. Therefore, we have to ensure that the point is well inside
            # the shifted ROI, not just on the boundary:
            #
            #         all possible shifts
            #         [--------)
            #         8       +9
            #                 ==
            #                 request.shape-1

            # pick a random point; random.choice indexes into its argument, and
            # dict key views are not indexable on Python 3, hence the list()
            point_id = choice(list(self.points.data.keys()))
            point = self.points.data[point_id]

            logger.debug(
                "select random point %d at %s",
                point_id,
                point.location)

            # get the lcm voxel that contains this point
            lcm_location = Coordinate(point.location/lcm_voxel_size)
            logger.debug(
                "belongs to lcm voxel %s",
                lcm_location)

            # mark all dimensions in which the point lies on the lower boundary
            # of the lcm voxel
            on_lower_boundary = lcm_location*lcm_voxel_size == point.location
            logger.debug(
                "lies on the lower boundary of the lcm voxel in dimensions %s",
                on_lower_boundary)

            # for each of these dimensions, we have to change the shape of the
            # shift ROI using the following correction
            lower_boundary_correction = Coordinate((
                -1 if o else 0
                for o in on_lower_boundary
            ))
            logger.debug(
                "lower bound correction for shape of shift ROI %s",
                lower_boundary_correction)

            # get the request ROI's shape in lcm
            lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
            lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
            logger.debug("Point request ROI: %s", request_points_roi)
            logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

            # get all possible starting points of lcm_roi_shape that contain
            # lcm_location
            lcm_shift_roi_begin = (
                lcm_location - lcm_roi_begin - lcm_roi_shape +
                Coordinate((1,)*len(lcm_location))
            )
            lcm_shift_roi_shape = (
                lcm_roi_shape + lower_boundary_correction
            )
            lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
            logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

            # intersect with total shift ROI
            if not lcm_point_shift_roi.intersects(lcm_shift_roi):
                logger.debug(
                    "reject random shift, random point %s shift ROI %s does "
                    "not intersect total shift ROI %s", point.location,
                    lcm_point_shift_roi, lcm_shift_roi)
                continue
            lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

            # select a random shift from all possible shifts
            random_shift = self.__select_random_location(
                lcm_point_shift_roi,
                lcm_voxel_size)
            logger.debug("random shift: %s", random_shift)

            # count all points inside the shifted ROI
            points_request = BatchRequest()
            points_request[self.ensure_nonempty] = PointsSpec(
                roi=request_points_roi.shift(random_shift))
            logger.debug("points request: %s", points_request)
            points_batch = self.get_upstream_provider().request_batch(points_request)

            point_ids = points_batch.points[self.ensure_nonempty].data.keys()
            assert point_id in point_ids, (
                "Requested batch to contain point %s, but got points "
                "%s"%(point_id, point_ids))
            num_points = len(point_ids)

            # accept this shift with p=1/num_points
            #
            # This is to compensate the bias introduced by close-by points.
            accept = random() <= 1.0/num_points
            if accept:
                return random_shift
Пример #23
0
    def process(self, batch, request):
        """Apply the per-slice defect augmentations selected in ``prepare``.

        Reads ``self.slice_to_augmentation`` (slice index along ``self.axis``
        -> augmentation name) and, for ``'deformed_slice'``,
        ``self.deform_slice_transformations[c]`` (``flow_x, flow_y,
        line_mask``). Slices of ``batch.arrays[self.intensities]`` are modified
        in place:

        * ``'zero_out'``: set to 0;
        * ``'low_contrast'`` / ``'lower_contrast'``: values scaled by
          ``self.contrast_scale`` around the slice mean;
        * ``'artifact'``: alpha-blended with a slice of the same shape fetched
          from ``self.artifact_source``;
        * ``'deformed_slice'``: resampled with ``map_coordinates`` along the
          stored flow fields, clipped to [0, 1] and zeroed under the line mask.

        If any slice was deformed, the array is cropped by
        ``self.deformation_strength`` voxels on both sides of every dimension
        except ``self.axis``, and its ROI is reset to the downstream request.
        """
        assert batch.get_total_roi().dims() == 3, "defectaugment works on 3d batches only"

        raw = batch.arrays[self.intensities]
        raw_voxel_size = self.spec[self.intensities].voxel_size

        for c, augmentation_type in self.slice_to_augmentation.items():

            # single slice c along self.axis, keeping that dimension
            section_selector = tuple(
                slice(None if d != self.axis else c, None if d != self.axis else c+1)
                for d in range(raw.spec.roi.dims())
            )

            if augmentation_type == 'zero_out':
                raw.data[section_selector] = 0

            # Accept both names for the contrast augmentation: slices labelled
            # 'lower_contrast' used to fall through every branch here, which
            # silently disabled the augmentation.
            elif augmentation_type in ('low_contrast', 'lower_contrast'):
                section = raw.data[section_selector]

                mean = section.mean()
                section -= mean
                section *= self.contrast_scale
                section += mean

                raw.data[section_selector] = section

            elif augmentation_type == 'artifact':

                section = raw.data[section_selector]

                alpha_voxel_size = self.artifact_source.spec[self.artifacts_mask].voxel_size

                assert raw_voxel_size == alpha_voxel_size, ("Can only alpha blend RAW with "
                                                            "ALPHA_MASK if both have the same "
                                                            "voxel size")

                artifact_request = BatchRequest()
                artifact_request.add(self.artifacts, Coordinate(section.shape) * raw_voxel_size, voxel_size=raw_voxel_size)
                artifact_request.add(self.artifacts_mask, Coordinate(section.shape) * alpha_voxel_size, voxel_size=raw_voxel_size)
                logger.debug("Requesting artifact batch %s", artifact_request)

                artifact_batch = self.artifact_source.request_batch(artifact_request)
                artifact_alpha = artifact_batch.arrays[self.artifacts_mask].data
                artifact_raw   = artifact_batch.arrays[self.artifacts].data

                assert artifact_alpha.dtype == np.float32
                assert artifact_alpha.min() >= 0.0
                assert artifact_alpha.max() <= 1.0

                raw.data[section_selector] = section*(1.0 - artifact_alpha) + artifact_raw*artifact_alpha

            elif augmentation_type == 'deformed_slice':

                section = raw.data[section_selector].squeeze()

                # cubic interpolation if the spec is interpolatable, else order 0
                interpolation = 3 if self.spec[self.intensities].interpolatable else 0

                # load the deformation fields that were prepared for this slice
                flow_x, flow_y, line_mask = self.deform_slice_transformations[c]

                # apply the deformation fields
                shape = section.shape
                section = map_coordinates(
                    section, (flow_y, flow_x), mode='constant', order=interpolation
                ).reshape(shape)

                # things can get smaller than 0 at the boundary, so we clip
                section = np.clip(section, 0., 1.)

                # zero-out data below the line mask
                section[line_mask] = 0.

                raw.data[section_selector] = section

        # in case we needed to change the ROI due to a deformation augment,
        # restore original ROI and crop the array data
        if 'deformed_slice' in self.slice_to_augmentation.values():
            old_roi = request[self.intensities].roi
            logger.debug("resetting roi to %s", old_roi)
            margin = self.deformation_strength
            # slice(m, -m) selects nothing for m == 0, so use an open end then
            crop = tuple(
                slice(None) if d == self.axis else slice(margin, -margin if margin else None)
                for d in range(raw.spec.roi.dims())
            )
            raw.data = raw.data[crop]
            raw.spec.roi = old_roi
Пример #24
0
    def prepare(self, request):
        """Choose a defect augmentation per slice and request the needed data.

        For every slice index along ``self.axis`` of the requested ROI (in
        voxels), one ``random.random()`` draw selects, via cumulative
        thresholds, one of ``'zero_out'``, ``'low_contrast'``, ``'artifact'``,
        ``'deformed_slice'`` or nothing. The choices are stored in
        ``self.slice_to_augmentation``. Deformed slices also get their
        transformation from ``self.__prepare_deform_slice``, stored in
        ``self.deform_slice_transformations``.

        If any slice is deformed, the upstream ROI is grown by
        ``self.deformation_strength`` voxels in every dimension except
        ``self.axis``, so that the deformation has enough context.

        Args:
            request: downstream request; must contain ``self.intensities``.

        Returns:
            ``BatchRequest`` for ``self.intensities``.
        """
        deps = BatchRequest()

        prob_missing_threshold = self.prob_missing
        prob_low_contrast_threshold = prob_missing_threshold + self.prob_low_contrast
        prob_artifact_threshold = prob_low_contrast_threshold + self.prob_artifact
        prob_deform_slice = prob_artifact_threshold + self.prob_deform

        spec = request[self.intensities].copy()
        roi = spec.roi
        logger.debug("downstream request ROI is %s" % roi)
        raw_voxel_size = self.spec[self.intensities].voxel_size

        # requested shape in voxels, and the shape of a single slice
        # (the voxel shape without self.axis)
        voxel_shape = (roi / raw_voxel_size).get_shape()
        slice_shape = voxel_shape[:self.axis] + voxel_shape[self.axis+1:]

        # store the mapping slice to augmentation type in a dict
        self.slice_to_augmentation = {}
        # store the transformations for deform slice
        self.deform_slice_transformations = {}
        for c in range(voxel_shape[self.axis]):
            r = random.random()

            if r < prob_missing_threshold:
                logger.debug("Zero-out " + str(c))
                self.slice_to_augmentation[c] = 'zero_out'

            elif r < prob_low_contrast_threshold:
                logger.debug("Lower contrast " + str(c))
                # must match the name ``process`` checks for
                self.slice_to_augmentation[c] = 'low_contrast'

            elif r < prob_artifact_threshold:
                logger.debug("Add artifact " + str(c))
                self.slice_to_augmentation[c] = 'artifact'

            elif r < prob_deform_slice:
                logger.debug("Add deformed slice " + str(c))
                self.slice_to_augmentation[c] = 'deformed_slice'
                self.deform_slice_transformations[c] = self.__prepare_deform_slice(slice_shape)

        # request bigger upstream roi for deformed slice
        if 'deformed_slice' in self.slice_to_augmentation.values():

            # create roi sufficiently large to feed deformation
            logger.debug("before growth: %s" % spec.roi)
            growth = Coordinate(
                tuple(0 if d == self.axis else raw_voxel_size[d] * self.deformation_strength
                      for d in range(spec.roi.dims()))
            )
            logger.debug("growing request by %s" % str(growth))
            source_roi = roi.grow(growth, growth)

            # update request ROI to get all voxels necessary to perform
            # transformation
            spec.roi = source_roi
            logger.debug("upstream request roi is %s" % spec.roi)

        deps[self.intensities] = spec
        # without this return the grown ROI computed above is never requested
        return deps
Пример #25
0
 def prepare(self, request):
     """Forward the downstream request for every key of ``self.dataset_names``.

     Args:
         request: the downstream ``BatchRequest``; it is indexed with each
             key of ``self.dataset_names``.

     Returns:
         A new ``BatchRequest`` that holds, for each of those keys, the spec
         read from ``request``.
     """
     deps = BatchRequest()
     # Iterating a mapping yields its keys directly, so ``.keys()`` is not
     # needed (assumes dataset_names is a dict-like mapping, as the original
     # ``.keys()`` call implies).
     for key in self.dataset_names:
         deps[key] = request[key]
     return deps
Пример #26
0
 def prepare(self, request):
     """Request from upstream every key listed in ``self.inputs``.

     The *values* of ``self.inputs`` serve as keys into ``request``. The spec
     found for each one is placed under the same key in a fresh
     ``BatchRequest``, which is returned.
     """
     upstream_request = BatchRequest()
     for input_key in self.inputs.values():
         input_spec = request[input_key]
         upstream_request[input_key] = input_spec
     return upstream_request
Пример #27
0
 def prepare(self, request):
     """Ask upstream for ``self.array`` with the spec downstream requested."""
     upstream_request = BatchRequest()
     array_spec = request[self.array]
     upstream_request[self.array] = array_spec
     return upstream_request
Пример #28
0
    def setup(self):
        """Fetch upstream data once and precompute lookup structures.

        * If ``self.mask`` is set and ``self.min_masked > 0``, the complete
          mask array is requested from upstream. ``self.mask_integral`` is
          then set to the integral image of the ``mask > 0`` indicator, stored
          in the smallest of uint16/uint32/uint64 that can hold the mask's
          voxel count.
        * If ``self.ensure_nonempty`` is set, all nodes of that graph are
          requested. Their locations are indexed in ``self.points`` (a
          ``cKDTree``). Each node gets the weight
          ``1 / (number of nodes within self.point_balance_radius)``, and the
          running sum of those weights is stored in
          ``self.cumulative_weights``.
        * Finally, the ROI shape of every provided spec is cleared and the spec
          is re-announced through ``self.updates``.

        Raises:
            AssertionError: if the upstream spec lacks ``self.mask`` or
                ``self.ensure_nonempty`` when those are requested.
        """
        upstream = self.get_upstream_provider()
        self.upstream_spec = upstream.spec

        if self.mask and self.min_masked > 0:

            assert self.mask in self.upstream_spec, (
                "Upstream provider does not have %s"%self.mask)
            self.mask_spec = self.upstream_spec.array_specs[self.mask]

            logger.info("requesting complete mask...")

            mask_request = BatchRequest({self.mask: self.mask_spec})
            mask_batch = upstream.request_batch(mask_request)

            logger.info("allocating mask integral array...")

            mask_data = mask_batch.arrays[self.mask].data
            # The largest integral-image entry is the number of masked voxels,
            # which is at most mask_data.size. Choose the narrowest unsigned
            # dtype whose maximum (2**bits - 1) can still represent that count.
            mask_integral_dtype = np.uint64
            logger.debug("mask size is %s", mask_data.size)
            if mask_data.size < 2**32:
                mask_integral_dtype = np.uint32
            if mask_data.size < 2**16:
                mask_integral_dtype = np.uint16
            logger.debug("chose %s as integral array dtype", mask_integral_dtype)

            self.mask_integral = np.array(mask_data > 0, dtype=mask_integral_dtype)
            # Cast back to the chosen dtype; presumably integral_image may
            # return a wider dtype for its cumulative sums — confirm.
            self.mask_integral = integral_image(self.mask_integral).astype(mask_integral_dtype)

        if self.ensure_nonempty:

            assert self.ensure_nonempty in self.upstream_spec, (
                "Upstream provider does not have %s"%self.ensure_nonempty)
            graph_spec = self.upstream_spec.graph_specs[self.ensure_nonempty]

            logger.info("requesting all %s points...", self.ensure_nonempty)

            nonempty_request = BatchRequest({self.ensure_nonempty: graph_spec})
            nonempty_batch = upstream.request_batch(nonempty_request)

            # Materialize the node locations once. They are needed both to
            # build the tree and as query points. Reusing one list guarantees
            # the query points match the tree's points in count and order,
            # and avoids iterating the graph's nodes a second time.
            locations = [
                node.location
                for node in nonempty_batch[self.ensure_nonempty].nodes
            ]
            self.points = cKDTree(locations)

            # For each node, list the indices of all nodes within the balance
            # radius. The node itself is included (distance 0), so every list
            # is non-empty for a non-negative radius. Weighting by the inverse
            # neighbour count down-weights nodes in dense regions.
            point_counts = self.points.query_ball_point(
                locations,
                r=self.point_balance_radius,
            )
            weights = [1 / len(point_count) for point_count in point_counts]
            self.cumulative_weights = list(itertools.accumulate(weights))

            logger.debug("retrieved %d points", len(self.points.data))

        # clear bounding boxes of all provided arrays and points --
        # RandomLocation does not have limits (offsets are ignored)
        for key, spec in self.spec.items():
            if spec.roi is not None:
                spec.roi.set_shape(None)
                self.updates(key, spec)
    def prepare(self, request):
        """Request the label array using the spec requested for the gradient.

        The returned ``BatchRequest`` asks upstream for
        ``self.label_array_key``. Its spec is the one downstream requested for
        ``self.gradient_array_key``.
        """
        upstream_request = BatchRequest()
        gradient_spec = request[self.gradient_array_key]
        # NOTE(review): the gradient's spec is reused verbatim for the labels,
        # including any dtype it carries — confirm that is intended.
        upstream_request[self.label_array_key] = gradient_spec
        return upstream_request
Пример #30
0
 def prepare(self, request):
     """Forward the request for ``self.array`` upstream with its dtype cleared."""
     upstream_request = BatchRequest()
     upstream_request[self.array] = request[self.array]
     # Remove the dtype constraint from the upstream spec. Presumably this
     # node produces its own output dtype, so upstream should not be held to
     # the requested one — confirm.
     # NOTE(review): whether this also alters ``request``'s spec depends on
     # whether BatchRequest.__setitem__ stores a copy — verify.
     upstream_spec = upstream_request[self.array]
     upstream_spec.dtype = None
     return upstream_request