Example #1
    def layer_op(self,
                 fixed_image,
                 moving_image,
                 is_training=True,
                 **unused_kwargs):
        """

        :param fixed_image: tensor, fixed image for registration (defines reference space)
        :param moving_image: tensor, moving image to be registered to fixed
        :param is_training: boolean, True if network is in training mode
        :return: displacement fields transformed by estimating affine
        """

        spatial_rank = infer_spatial_rank(moving_image)
        spatial_shape = fixed_image.get_shape().as_list()[1:-1]

        # resize the moving image to match the fixed
        moving_image = Resize(spatial_shape)(moving_image)
        img = tf.concat([moving_image, fixed_image], axis=-1)
        res_1 = DownRes(self.fea[0], kernel_size=7,
                        **self.res_param)(img, is_training)[0]
        res_2 = DownRes(self.fea[1], **self.res_param)(res_1, is_training)[0]
        res_3 = DownRes(self.fea[2], **self.res_param)(res_2, is_training)[0]
        res_4 = DownRes(self.fea[3], **self.res_param)(res_3, is_training)[0]

        conv_5 = Conv(n_output_chns=self.fea[4],
                      kernel_size=self.k_conv,
                      with_bias=False,
                      feature_normalization='batch',
                      **self.res_param)(res_4, is_training)

        # a 2-D affine transform has 6 parameters (2x3), a 3-D one has 12 (3x4)
        if spatial_rank == 2:
            affine_size = 6
        elif spatial_rank == 3:
            affine_size = 12
        else:
            tf.logging.fatal('Unsupported spatial rank: %s', spatial_rank)
            raise NotImplementedError('only 2-D and 3-D images are supported')

        if self.affine_w_initializer is None:
            self.affine_w_initializer = init_affine_w()
        if self.affine_b_initializer is None:
            self.affine_b_initializer = init_affine_b(spatial_rank)
        affine = FC(n_output_chns=affine_size,
                    feature_normalization=None,
                    w_initializer=self.affine_w_initializer,
                    b_initializer=self.affine_b_initializer,
                    **self.affine_param)(conv_5)
        grid_global = Grid(source_shape=spatial_shape,
                           output_shape=spatial_shape)(affine)
        return grid_global
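A minimal sketch (not part of the example above) of why the fully connected layer outputs 6 or 12 values: a 2-D affine transform is a 2x3 matrix and a 3-D affine transform is a 3x4 matrix applied to homogeneous voxel coordinates, which the grid warper then turns into a dense sampling grid. The function name apply_affine_3d and the sizes below are illustrative assumptions.

import numpy as np

def apply_affine_3d(theta, coords):
    """theta: 12 affine parameters, coords: (N, 3) voxel coordinates."""
    matrix = np.asarray(theta, dtype=np.float32).reshape(3, 4)  # 3x4 affine matrix
    ones = np.ones((coords.shape[0], 1), dtype=np.float32)
    homogeneous = np.concatenate([coords, ones], axis=1)        # (N, 4) homogeneous coords
    return homogeneous @ matrix.T                               # (N, 3) warped coordinates

# the identity parameters leave the sampling grid unchanged
identity = [1., 0., 0., 0.,  0., 1., 0., 0.,  0., 0., 1., 0.]
grid = np.array([[0., 0., 0.], [1., 2., 3.]], dtype=np.float32)
print(apply_affine_3d(identity, grid))  # same coordinates as the input grid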
Example #2
    def layer_op(self, tensor_a, tensor_b):
        """
        match the spatial shape and concatenate the tensors
        tensor_a will be cropped and resized to match tensor_b.

        :param tensor_a: tensor, input
        :param tensor_b: tensor, input
        :return: concatenated tensor
        """
        # symmetric crop so tensor_a's spatial size approaches tensor_b's
        crop_border = (tensor_a.shape[1] - tensor_b.shape[1]) // 2
        tensor_a = Crop(border=crop_border)(tensor_a)
        # resize to remove any remaining off-by-one difference
        output_spatial_shape = tensor_b.shape[1:-1]
        tensor_a = Resize(new_size=output_spatial_shape)(tensor_a)
        return ElementWise('CONCAT')(tensor_a, tensor_b)
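A quick check of the border arithmetic above, with hypothetical spatial sizes: cropping crop_border voxels from each side brings tensor_a to roughly tensor_b's size, and the final Resize removes any remaining one-voxel difference when the size gap is odd.

size_a, size_b = 64, 56                  # hypothetical spatial sizes of tensor_a / tensor_b
crop_border = (size_a - size_b) // 2     # 4 voxels removed from each side
cropped_size = size_a - 2 * crop_border  # 64 - 2 * 4 = 56, already matches tensor_b
print(crop_border, cropped_size)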
Example #3
    def layer_op(self,
                 fixed_image,
                 moving_image,
                 base_grid=None,
                 is_training=True,
                 **unused_kwargs):
        """

        :param fixed_image:
        :param moving_image:
        :param base_grid:
        :param is_training:
        :return: estimated dense displacement fields
        """

        spatial_rank = infer_spatial_rank(fixed_image)
        spatial_shape = fixed_image.get_shape().as_list()[1:-1]
        # four downsampling levels require spatial dims divisible by 16
        check_spatial_dims(fixed_image, lambda x: x % 16 == 0)

        # resize the moving image to match the fixed image
        moving_image = Resize(spatial_shape)(moving_image)
        img = tf.concat([moving_image, fixed_image], axis=-1)
        down_res_0, conv_0_0, _ = \
            DownRes(self.fea[0], kernel_size=7, **self.down_res_param)(img, is_training)
        down_res_1, conv_0_1, _ = \
            DownRes(self.fea[1], **self.down_res_param)(down_res_0, is_training)
        down_res_2, conv_0_2, _ = \
            DownRes(self.fea[2], **self.down_res_param)(down_res_1, is_training)
        down_res_3, conv_0_3, _ = \
            DownRes(self.fea[3], **self.down_res_param)(down_res_2, is_training)

        conv_4 = Conv(n_output_chns=self.fea[4],
                      kernel_size=self.k_conv,
                      **self.down_res_param)(down_res_3, is_training)

        up_res_0 = UpRes(self.fea[3], **self.up_res_param)(conv_4, conv_0_3,
                                                           is_training)
        up_res_1 = UpRes(self.fea[2], **self.up_res_param)(up_res_0, conv_0_2,
                                                           is_training)
        up_res_2 = UpRes(self.fea[1], **self.up_res_param)(up_res_1, conv_0_1,
                                                           is_training)
        up_res_3 = UpRes(self.fea[0], **self.up_res_param)(up_res_2, conv_0_0,
                                                           is_training)

        if self.multi_scale_fusion:
            output_list = [up_res_3, up_res_2, up_res_1, up_res_0, conv_4]
        else:
            output_list = [up_res_3]

        # converting all output layers to displacement fields
        dense_fields = []
        for scale_out in output_list:
            field = Conv(n_output_chns=spatial_rank,
                         kernel_size=self.k_conv,
                         with_bias=True,
                         with_bn=False,
                         acti_func=None,
                         **self.disp_param)(scale_out)
            resized_field = Resize(new_size=spatial_shape)(field)
            dense_fields.append(resized_field)

        if base_grid is None:
            # adding a reference grid if it doesn't exist
            in_spatial_size = [None] * spatial_rank
            base_grid = _create_affine_features(output_shape=spatial_shape,
                                                source_shape=in_spatial_size)
            base_grid = np.asarray(base_grid[:-1])
            base_grid = np.reshape(base_grid.T,
                                   [-1] + spatial_shape + [spatial_rank])
            base_grid = tf.constant(base_grid, dtype=resized_field.dtype)

        if self.multi_scale_fusion and len(dense_fields) > 1:
            dense_field = tf.reduce_sum(dense_fields, axis=0)
        else:
            dense_field = dense_fields[0]

        # TODO filtering
        if self.smoothing_func is not None:
            dense_field = self.smoothing_func(dense_field, spatial_rank)

        tf.add_to_collection('bending_energy',
                             _computing_bending_energy(dense_field))
        tf.add_to_collection('gradient_norm',
                             _computing_gradient_norm(dense_field))

        dense_field = dense_field + base_grid
        return dense_field
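The multi-scale fusion step above sums displacement fields predicted at several decoder resolutions (after resizing each to the fixed image grid) and then adds the reference grid to obtain absolute sampling coordinates. Below is a minimal NumPy sketch of that combination, with made-up shapes and a plain voxel-index grid standing in for the reference grid.

import numpy as np

# three hypothetical displacement fields already resized to an 8x8x8 fixed-image grid
fields = [s * np.random.randn(1, 8, 8, 8, 3).astype(np.float32)
          for s in (1.0, 0.5, 0.25)]
dense_field = np.sum(fields, axis=0)      # multi-scale fusion by summation

# reference grid of voxel coordinates, shape (1, 8, 8, 8, 3)
base_grid = np.stack(
    np.meshgrid(*[np.arange(8, dtype=np.float32)] * 3, indexing='ij'),
    axis=-1)[np.newaxis]
sampling_grid = dense_field + base_grid   # absolute coordinates to resample the moving image
print(sampling_grid.shape)                # (1, 8, 8, 8, 3)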
Example #4
    def layer_op(self):
        """
        This function concatenate image and label volumes at the last dim
        and randomly cropping the volumes (also the cropping margins)
        """
        image_id, fixed_inputs, moving_inputs, fixed_shape, moving_shape = \
            self.iterator.get_next()
        # TODO preprocessing layer modifying
        #      image shapes will not be supported
        # assuming the same shape across modalities, using the first
        image_id.set_shape((self.batch_size, ))
        image_id = tf.to_float(image_id)

        fixed_inputs.set_shape((self.batch_size, ) +
                               (None, ) * self.spatial_rank + (2, ))
        # last dim is 1 image + 1 label
        moving_inputs.set_shape((self.batch_size, ) + self.moving_image_shape +
                                (2, ))
        fixed_shape.set_shape((self.batch_size, self.spatial_rank + 1))
        moving_shape.set_shape((self.batch_size, self.spatial_rank + 1))

        # resizing the moving_inputs to match the target
        # assumes the same shape across the batch
        target_spatial_shape = \
            tf.unstack(fixed_shape[0], axis=0)[:self.spatial_rank]
        moving_inputs = Resize(new_size=target_spatial_shape)(moving_inputs)
        combined_volume = tf.concat([fixed_inputs, moving_inputs], axis=-1)

        # smoothing_layer = Smoothing(
        #     sigma=1, truncate=3.0, type_str='gaussian')
        # combined_volume = tf.unstack(combined_volume, axis=-1)
        # combined_volume[0] = tf.expand_dims(combined_volume[0], axis=-1)
        # combined_volume[1] = smoothing_layer(
        #     tf.expand_dims(combined_volume[1], axis=-1))
        # combined_volume[2] = tf.expand_dims(combined_volume[2], axis=-1)
        # combined_volume[3] = smoothing_layer(
        #     tf.expand_dims(combined_volume[3], axis=-1))
        # combined_volume = tf.stack(combined_volume, axis=-1)

        # TODO affine data augmentation here
        if self.spatial_rank == 3:
            # 4 channels: fixed image + fixed label + moving image + moving label
            window_channels = np.prod(self.window_size[self.spatial_rank:]) * 4
            # TODO if no affine augmentation:
            img_spatial_shape = target_spatial_shape
            win_spatial_shape = [
                tf.constant(dim)
                for dim in self.window_size[:self.spatial_rank]
            ]
            # when img == win, make sure the shift stays close to 0.0,
            # otherwise the interpolation would sample out of bounds
            batch_shift = [
                tf.random_uniform(shape=(self.batch_size, 1),
                                  minval=0,
                                  maxval=tf.maximum(tf.to_float(img - win - 1),
                                                    0.01))
                for (win, img) in zip(win_spatial_shape, img_spatial_shape)
            ]
            batch_shift = tf.concat(batch_shift, axis=1)
            affine_constraints = ((1.0, 0.0, 0.0, None), (0.0, 1.0, 0.0, None),
                                  (0.0, 0.0, 1.0, None))
            computed_grid = AffineGridWarperLayer(
                source_shape=(None, None, None),
                output_shape=self.window_size[:self.spatial_rank],
                constraints=affine_constraints)(batch_shift)
            computed_grid.set_shape((self.batch_size, ) +
                                    self.window_size[:self.spatial_rank] +
                                    (self.spatial_rank, ))
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            windows = resampler(combined_volume, computed_grid)
            out_shape = [self.batch_size] + \
                        list(self.window_size[:self.spatial_rank]) + \
                        [window_channels]
            windows.set_shape(out_shape)

            image_id = tf.reshape(image_id, (self.batch_size, 1))
            start_location = tf.zeros((self.batch_size, self.spatial_rank))
            locations = tf.concat([image_id, start_location, batch_shift],
                                  axis=1)
        return windows, locations
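A small sketch of the random-shift sampling used above, with hypothetical sizes: each item in the batch gets a uniform translation in [0, img - win - 1] so the cropped window stays inside the volume, and a 0.01 upper bound stands in when the window already covers the whole image.

import numpy as np

batch_size, img, win = 4, 64, 32             # hypothetical batch and spatial sizes
max_shift = max(float(img - win - 1), 0.01)  # guard against a zero or negative range
batch_shift = np.random.uniform(0.0, max_shift, size=(batch_size, 1))
print(batch_shift)                           # one random window offset per batch item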