def layer_op(self, fixed_image, moving_image, is_training=True, **unused_kwargs):
    """
    Regress a global affine transform from a fixed/moving image pair and
    return the sampling grid it defines over the fixed image space.

    :param fixed_image: tensor, fixed image for registration
        (defines reference space)
    :param moving_image: tensor, moving image to be registered to fixed;
        resized to the fixed image's spatial shape before processing
    :param is_training: boolean, True if network is in training mode
    :param unused_kwargs: ignored
    :return: displacement fields transformed by estimating affine
    :raises NotImplementedError: if the spatial rank is neither 2 nor 3
    """
    spatial_rank = infer_spatial_rank(moving_image)
    # 2-D: 6 affine parameters (2x3); 3-D: 12 affine parameters (3x4)
    affine_size = {2: 6, 3: 12}.get(spatial_rank)
    if affine_size is None:
        # validate before any layer is built so that an unsupported input
        # fails immediately rather than after constructing the encoder
        tf.logging.fatal('Not supported spatial rank')
        raise NotImplementedError(
            'Not supported spatial rank: {}'.format(spatial_rank))

    spatial_shape = fixed_image.get_shape().as_list()[1:-1]

    # resize the moving image to match the fixed
    moving_image = Resize(spatial_shape)(moving_image)
    img = tf.concat([moving_image, fixed_image], axis=-1)

    # encoder: only the first output of each DownRes block is used here
    res_1 = DownRes(self.fea[0], kernel_size=7,
                    **self.res_param)(img, is_training)[0]
    res_2 = DownRes(self.fea[1], **self.res_param)(res_1, is_training)[0]
    res_3 = DownRes(self.fea[2], **self.res_param)(res_2, is_training)[0]
    res_4 = DownRes(self.fea[3], **self.res_param)(res_3, is_training)[0]

    conv_5 = Conv(n_output_chns=self.fea[4],
                  kernel_size=self.k_conv,
                  with_bias=False,
                  feature_normalization='batch',
                  **self.res_param)(res_4, is_training)

    # initialisers are created lazily on first call and cached on self
    if self.affine_w_initializer is None:
        self.affine_w_initializer = init_affine_w()
    if self.affine_b_initializer is None:
        self.affine_b_initializer = init_affine_b(spatial_rank)

    # fully connected regression of the affine parameters
    affine = FC(n_output_chns=affine_size,
                feature_normalization=None,
                w_initializer=self.affine_w_initializer,
                b_initializer=self.affine_b_initializer,
                **self.affine_param)(conv_5)
    grid_global = Grid(source_shape=spatial_shape,
                       output_shape=spatial_shape)(affine)
    return grid_global
def layer_op(self, tensor_a, tensor_b):
    """
    Bring ``tensor_a`` to the spatial shape of ``tensor_b`` and join them.

    ``tensor_a`` is first cropped by a margin derived from the size
    difference along the first spatial axis only (assumes the same
    difference on every spatial axis — TODO confirm at call sites), then
    resized to ``tensor_b``'s spatial shape, and finally concatenated
    with ``tensor_b``.

    :param tensor_a: tensor, input to be cropped and resized
    :param tensor_b: tensor, input whose spatial shape is the target
    :return: concatenated tensor
    """
    margin = (tensor_a.shape[1] - tensor_b.shape[1]) // 2
    cropped = Crop(border=margin)(tensor_a)
    matched = Resize(new_size=tensor_b.shape[1:-1])(cropped)
    return ElementWise('CONCAT')(matched, tensor_b)
def layer_op(self,
             fixed_image,
             moving_image,
             base_grid=None,
             is_training=True,
             **unused_kwargs):
    """
    Estimate a dense displacement field from a fixed/moving image pair.

    :param fixed_image: tensor, fixed image; defines the output spatial
        shape, and every spatial dimension must be divisible by 16
    :param moving_image: tensor, moving image; resized to the fixed
        image's spatial shape before processing
    :param base_grid: optional reference grid added to the estimated
        displacement; when None, one is built over the fixed image's
        spatial shape via ``_create_affine_features``
    :param is_training: boolean, True if network is in training mode
    :param unused_kwargs: ignored
    :return: estimated dense displacement fields (displacement + base grid)
    """
    spatial_rank = infer_spatial_rank(fixed_image)
    spatial_shape = fixed_image.get_shape().as_list()[1:-1]
    # four DownRes stages; presumably each halves the resolution, hence 16
    check_spatial_dims(fixed_image, lambda x: x % 16 == 0)

    # resize the moving image to match the fixed
    moving_image = Resize(spatial_shape)(moving_image)
    img = tf.concat([moving_image, fixed_image], axis=-1)

    # encoder: the second output of each DownRes is kept as a skip feature
    down_res_0, conv_0_0, _ = DownRes(
        self.fea[0], kernel_size=7,
        **self.down_res_param)(img, is_training)
    down_res_1, conv_0_1, _ = DownRes(
        self.fea[1], **self.down_res_param)(down_res_0, is_training)
    down_res_2, conv_0_2, _ = DownRes(
        self.fea[2], **self.down_res_param)(down_res_1, is_training)
    down_res_3, conv_0_3, _ = DownRes(
        self.fea[3], **self.down_res_param)(down_res_2, is_training)

    conv_4 = Conv(n_output_chns=self.fea[4],
                  kernel_size=self.k_conv,
                  **self.down_res_param)(down_res_3, is_training)

    # decoder: each UpRes consumes the matching encoder skip feature
    up_res_0 = UpRes(self.fea[3], **self.up_res_param)(
        conv_4, conv_0_3, is_training)
    up_res_1 = UpRes(self.fea[2], **self.up_res_param)(
        up_res_0, conv_0_2, is_training)
    up_res_2 = UpRes(self.fea[1], **self.up_res_param)(
        up_res_1, conv_0_1, is_training)
    up_res_3 = UpRes(self.fea[0], **self.up_res_param)(
        up_res_2, conv_0_0, is_training)

    if self.multi_scale_fusion:
        output_list = [up_res_3, up_res_2, up_res_1, up_res_0, conv_4]
    else:
        output_list = [up_res_3]

    # converting all output layers to displacement fields
    dense_fields = []
    for scale_out in output_list:
        # normalisation keyword matches the Conv/FC API used elsewhere
        # (feature_normalization) instead of the legacy with_bn flag
        field = Conv(n_output_chns=spatial_rank,
                     kernel_size=self.k_conv,
                     with_bias=True,
                     feature_normalization=None,
                     acti_func=None,
                     **self.disp_param)(scale_out)
        dense_fields.append(Resize(new_size=spatial_shape)(field))

    if base_grid is None:
        # adding a reference grid if it doesn't exist
        in_spatial_size = [None] * spatial_rank
        base_grid = _create_affine_features(output_shape=spatial_shape,
                                            source_shape=in_spatial_size)
        # drop the last row (presumably the homogeneous ones row)
        base_grid = np.asarray(base_grid[:-1])
        base_grid = np.reshape(base_grid.T,
                               [-1] + spatial_shape + [spatial_rank])
        # take the dtype from the collected fields rather than from a
        # variable that only exists because it leaked out of the loop
        base_grid = tf.constant(base_grid, dtype=dense_fields[0].dtype)

    if self.multi_scale_fusion and len(dense_fields) > 1:
        dense_field = tf.reduce_sum(dense_fields, axis=0)
    else:
        dense_field = dense_fields[0]

    # TODO filtering
    if self.smoothing_func is not None:
        dense_field = self.smoothing_func(dense_field, spatial_rank)

    # regularisation terms are computed on the displacement alone,
    # before the base grid is added
    tf.add_to_collection('bending_energy',
                         _computing_bending_energy(dense_field))
    tf.add_to_collection('gradient_norm',
                         _computing_gradient_norm(dense_field))

    dense_field = dense_field + base_grid
    return dense_field
def layer_op(self):
    """
    Draw a batch of randomly translated windows from fixed/moving volumes.

    The next element of ``self.iterator`` provides the image id, the fixed
    and moving inputs (last dim: 1 image + 1 label) and their shapes. The
    moving inputs are resized to the fixed spatial shape (assumed equal
    across the batch), concatenated with the fixed inputs along the last
    axis, and a window of ``self.window_size`` is resampled from the
    combined volume at a random per-sample shift.

    Only 3-D volumes are supported.

    :return: ``(windows, locations)``; ``windows`` has static shape
        ``[batch_size] + window_size[:spatial_rank] + [window_channels]``,
        ``locations`` concatenates the image id, a zero start location and
        the random shift for each batch element.
    :raises NotImplementedError: if ``self.spatial_rank`` is not 3
    """
    if self.spatial_rank != 3:
        # the affine constraints below are 3-D only; without this check a
        # different rank would leave `windows`/`locations` unbound
        tf.logging.fatal('Not supported spatial rank')
        raise NotImplementedError(
            'Only 3-D volumes are supported, got spatial rank {}'.format(
                self.spatial_rank))

    image_id, fixed_inputs, moving_inputs, fixed_shape, moving_shape = \
        self.iterator.get_next()
    # TODO preprocessing layer modifying
    #      image shapes will not be supported
    # assuming the same shape across modalities, using the first
    image_id.set_shape((self.batch_size,))
    image_id = tf.to_float(image_id)

    # last dim is 1 image + 1 label
    fixed_inputs.set_shape(
        (self.batch_size,) + (None,) * self.spatial_rank + (2,))
    moving_inputs.set_shape(
        (self.batch_size,) + self.moving_image_shape + (2,))
    fixed_shape.set_shape((self.batch_size, self.spatial_rank + 1))
    moving_shape.set_shape((self.batch_size, self.spatial_rank + 1))

    # resizing the moving_inputs to match the target
    # assumes the same shape across the batch
    target_spatial_shape = \
        tf.unstack(fixed_shape[0], axis=0)[:self.spatial_rank]
    moving_inputs = Resize(new_size=target_spatial_shape)(moving_inputs)
    combined_volume = tf.concat([fixed_inputs, moving_inputs], axis=-1)

    # TODO affine data augmentation here
    spatial_window = self.window_size[:self.spatial_rank]
    # 2 fixed channels (image + label) + 2 moving channels (image + label)
    window_channels = int(np.prod(self.window_size[self.spatial_rank:])) * 4

    # TODO if no affine augmentation:
    img_spatial_shape = target_spatial_shape
    win_spatial_shape = [tf.constant(dim) for dim in spatial_window]

    # when img==win make sure shift => 0.0
    # otherwise interpolation is out of bound
    batch_shift = [
        tf.random_uniform(
            shape=(self.batch_size, 1),
            minval=0,
            maxval=tf.maximum(tf.to_float(img - win - 1), 0.01))
        for (win, img) in zip(win_spatial_shape, img_spatial_shape)]
    batch_shift = tf.concat(batch_shift, axis=1)

    # identity linear part; presumably None marks the free translation
    # column that is filled from batch_shift
    affine_constraints = ((1.0, 0.0, 0.0, None),
                          (0.0, 1.0, 0.0, None),
                          (0.0, 0.0, 1.0, None))
    computed_grid = AffineGridWarperLayer(
        source_shape=(None, None, None),
        output_shape=spatial_window,
        constraints=affine_constraints)(batch_shift)
    computed_grid.set_shape(
        (self.batch_size,) + spatial_window + (self.spatial_rank,))

    resampler = ResamplerLayer(interpolation='linear', boundary='replicate')
    windows = resampler(combined_volume, computed_grid)
    windows.set_shape(
        [self.batch_size] + list(spatial_window) + [window_channels])

    image_id = tf.reshape(image_id, (self.batch_size, 1))
    start_location = tf.zeros((self.batch_size, self.spatial_rank))
    locations = tf.concat([image_id, start_location, batch_shift], axis=1)
    return windows, locations