Ejemplo n.º 1
0
    def __create_transformation(self, target_shape):
        """Assemble the transformation for a region of ``target_shape``.

        The result starts as the identity transformation and optional
        components are accumulated onto it in this order:

        1. an elastic component, only if the entries of
           ``self.jitter_sigma`` sum to a positive value;
        2. a rotation by a randomly drawn amount, only if that amount is
           non-zero;
        3. upscaling, only if ``self.subsample`` is greater than 1;
        4. misalignment, only if ``self.prob_slip + self.prob_shift`` is
           positive.

        Args:
            target_shape: Shape handed unchanged to every ``augment``
                helper (presumably the shape of the region to be
                transformed -- confirm against the ``augment`` API).

        Returns:
            The transformation object produced by the ``augment`` helpers
            after the steps above.
        """
        subsample = self.subsample

        field = augment.create_identity_transformation(
            target_shape, subsample=subsample)

        has_jitter = sum(self.jitter_sigma) > 0
        if has_jitter:
            field += augment.create_elastic_transformation(
                target_shape,
                self.control_point_spacing,
                self.jitter_sigma,
                subsample=subsample)

        # random.random() lies in [0, 1), so the drawn amount starts at
        # rotation_start and spans rotation_max_amount.
        fraction = random.random()
        angle = fraction * self.rotation_max_amount + self.rotation_start
        if angle != 0:
            field += augment.create_rotation_transformation(
                target_shape, angle, subsample=subsample)

        if subsample > 1:
            field = augment.upscale_transformation(field, target_shape)

        wants_misalignment = self.prob_slip + self.prob_shift > 0
        if wants_misalignment:
            # The return value is ignored; presumably __misalign edits the
            # transformation in place -- confirm in its definition.
            self.__misalign(field)

        return field
Ejemplo n.º 2
0
    def prepare(self, request):
        """Build the transformations for ``request`` and update its ROIs.

        A single transformation covering the total ROI of the request is
        created (identity + elastic + rotation, upscaled if
        ``self.subsample > 1`` and misaligned if
        ``self.prob_slip + self.prob_shift > 0``) and stored in
        ``self.total_transformation``. For every entry of
        ``request.volumes`` the matching part of it is copied into
        ``self.transformations`` and the entry's ROI is replaced in the
        request by the ROI returned from ``__recompute_roi``.

        Args:
            request: Object providing ``get_total_roi()`` and a
                ``volumes`` mapping from volume type to ROI. The mapping
                is modified in place.
        """
        total_roi = request.get_total_roi()
        logger.debug("total ROI is %s", total_roi)
        total_shape = total_roi.get_shape()

        # create a transformation for the total ROI; the rotation amount is
        # drawn before any augment helper runs
        rotation = (
            random.random() * self.rotation_max_amount + self.rotation_start)
        self.total_transformation = augment.create_identity_transformation(
            total_shape, subsample=self.subsample)
        self.total_transformation += augment.create_elastic_transformation(
            total_shape,
            self.control_point_spacing,
            self.jitter_sigma,
            subsample=self.subsample)
        self.total_transformation += augment.create_rotation_transformation(
            total_shape, rotation, subsample=self.subsample)

        if self.subsample > 1:
            self.total_transformation = augment.upscale_transformation(
                self.total_transformation, total_shape)

        if self.prob_slip + self.prob_shift > 0:
            self.__misalign()

        # crop the parts corresponding to the requested volume ROIs
        self.transformations = {}
        logger.debug("total ROI is %s", total_roi)
        for volume_type, roi in request.volumes.items():

            logger.debug(
                "downstream request ROI for %s is %s", volume_type, roi)

            # express the ROI relative to the origin of the total ROI so it
            # can index into the total transformation
            roi_in_total_roi = roi.shift(-total_roi.get_offset())

            # the leading slice(None) keeps the whole first axis; the
            # remaining axes are cut to the ROI's bounding box
            crop = (slice(None), ) + roi_in_total_roi.get_bounding_box()
            transformation = np.copy(self.total_transformation[crop])
            self.transformations[volume_type] = transformation

            # update request ROI to get all voxels necessary to perform
            # transformation (only existing keys are reassigned, so the
            # mapping does not change size while it is iterated)
            roi = self.__recompute_roi(roi, transformation)
            request.volumes[volume_type] = roi

            logger.debug(
                "upstream request roi for %s = %s", volume_type, roi)