def test_resample():
    """
    Check layer_util.resample on hand-computed linear interpolation results.

    Three scenarios are covered:
    - a volume without a feature channel (compared with check_equal);
    - the same volume duplicated into two identical feature channels
      (compared with check_equal);
    - vol/loc with incompatible shapes, which must raise ValueError.
    """
    interpolation = "linear"

    # Grid values follow 3x + y, with x indexing axis 1 and y indexing axis 2.
    vol_plain = np.array([[[0, 1, 2], [3, 4, 5]]], dtype=np.float32)  # shape = [1,2,3]
    # Sampling coordinates shared by both resampling scenarios.
    loc_values = np.array(
        [
            [
                [[0, 0], [0, 1], [0, 3]],  # [0, 3] lies outside the frame
                [[0.4, 0], [0.5, 1], [0.6, 2]],
                [[0.4, 0.7], [0.5, 0.5], [0.6, 0.3]],
            ]
        ],
        dtype=np.float32,
    )  # shape = [1,3,3,2]
    want_plain = np.array(
        [[[0, 1, 2], [1.2, 2.5, 3.8], [1.9, 2, 2.1]]], dtype=np.float32
    )  # shape = [1,3,3]

    def _two_channels(values):
        # Stack two identical copies along a new trailing feature axis.
        return np.repeat(values[..., None], 2, axis=-1)

    scenarios = [
        (vol_plain, want_plain),  # vol has no feature channel
        (_two_channels(vol_plain), _two_channels(want_plain)),  # [1,2,3,2] -> [1,3,3,2]
    ]
    for vol_values, want_values in scenarios:
        want = tf.constant(want_values)
        got = layer_util.resample(
            vol=tf.constant(vol_values),
            loc=tf.constant(loc_values),
            interpolation=interpolation,
        )
        assert check_equal(want, got)

    # Inconsistent shapes for resampling - must fail.
    bad_vol = tf.constant(np.zeros((1, 1), dtype=np.float32))  # shape = [1,1]
    bad_loc = tf.constant(np.zeros((2, 2), dtype=np.float32))  # shape = [2,2]
    with pytest.raises(ValueError) as execinfo:
        layer_util.resample(vol=bad_vol, loc=bad_loc, interpolation=interpolation)
    # Collapse whitespace so line breaks inside the message do not matter.
    normalized_msg = " ".join(execinfo.value.args[0].split())
    assert "vol shape inconsistent with loc" in normalized_msg
def test_shape_error(self):
    """resample must reject a loc whose shape is inconsistent with vol."""
    vol = tf.constant(np.zeros((1, 1), dtype=np.float32))  # shape = [1,1]
    loc = tf.constant(np.zeros((2, 2), dtype=np.float32))  # shape = [2,2]
    with pytest.raises(ValueError) as err_info:
        layer_util.resample(vol=vol, loc=loc)
    assert "vol shape inconsistent with loc" in str(err_info.value)
def test_interpolation_error(self):
    """resample must reject the "nearest" interpolation mode with a ValueError."""
    vol = tf.constant(np.zeros((1, 1), dtype=np.float32))  # shape = [1,1]
    loc = tf.constant(np.zeros((2, 2), dtype=np.float32))  # shape = [2,2]
    with pytest.raises(ValueError) as err_info:
        layer_util.resample(vol=vol, loc=loc, interpolation="nearest")
    expected_msg = "resample supports only linear interpolation"
    assert expected_msg in str(err_info.value)
def call(self, inputs, **kwargs) -> tf.Tensor:
    """
    Resample ``image`` at the reference grid shifted by ``ddf``.

    :param inputs: pair (ddf, image)
        - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        - image, shape = (batch, m_dim1, m_dim2, m_dim3)
    :param kwargs: additional arguments, not used here.
    :return: shape = (batch, f_dim1, f_dim2, f_dim3)
    """
    ddf, image = inputs
    # Shifted sampling coordinates, same shape as ddf.
    sample_grid = self.grid_ref + ddf
    return layer_util.resample(vol=image, loc=sample_grid)
def transform(image: tf.Tensor, grid_ref: tf.Tensor, params: tf.Tensor) -> tf.Tensor:
    """
    Shift the reference grid by ``params`` and resample ``image`` on it.

    :param image: shape = (batch, dim1, dim2, dim3)
    :param grid_ref: shape = (dim1, dim2, dim3, 3)
    :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
    :return: shape = (batch, dim1, dim2, dim3)
    """
    # Add a leading axis so the grid broadcasts against the batched params.
    batched_grid = grid_ref[None, ...]
    sample_locations = batched_grid + params
    return resample(vol=image, loc=sample_locations)
def _transform(image, grid_ref, transforms):
    """
    Warp ``grid_ref`` with ``transforms`` and resample ``image`` on the result.

    :param image: shape = [batch, dim1, dim2, dim3]
    :param grid_ref: shape = [dim1, dim2, dim3, 3]
    :param transforms: shape = [batch, 4, 3]
    :return: shape = [batch, dim1, dim2, dim3]
    """
    warped_grid = layer_util.warp_grid(grid_ref, transforms)
    return layer_util.resample(vol=image, loc=warped_grid)
def test_repeat_extrapolation(self, channel):
    """
    With zero_boundary=False, each sample is expected to equal 3x + y
    evaluated at its coordinates clipped into [min, max] on each axis.
    When channel > 0, vol and the expectation are repeated into
    ``channel`` identical feature channels.
    """

    def _add_channels(tensor):
        # Append a feature axis holding `channel` identical copies.
        return tf.repeat(tensor[..., None], channel, axis=-1)

    clipped_x = tf.clip_by_value(self.loc[..., 0], self.x_min, self.x_max)
    clipped_y = tf.clip_by_value(self.loc[..., 1], self.y_min, self.y_max)
    expected = 3 * clipped_x + clipped_y

    vol = self.vol
    if channel > 0:
        vol = _add_channels(vol)
        expected = _add_channels(expected)

    got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=False)
    assert is_equal_tf(expected, got)
def _transform(image, grid_ref, transforms):
    """
    Resample ``image`` on ``grid_ref`` after warping it by ``transforms``.

    :param image: shape = (batch, dim1, dim2, dim3)
    :param grid_ref: shape = [dim1, dim2, dim3, 3]
    :param transforms: shape = [batch, 4, 3]
    :return: shape = (batch, dim1, dim2, dim3)
    """
    sampling_grid = layer_util.warp_grid(grid_ref, transforms)
    resampled = layer_util.resample(vol=image, loc=sampling_grid)
    return resampled
def test_repeat_zero_bound(self, channel):
    """
    With zero_boundary=True, each sample is expected to equal 3x + y where
    both coordinates lie in (min, max], and 0 everywhere else.
    When channel > 0, vol and the expectation are repeated into
    ``channel`` identical feature channels.
    """
    loc_x = self.loc[..., 0]
    loc_y = self.loc[..., 1]

    # 1.0 / 0.0 masks for each half of the (min, max] interval per axis.
    x_above_min = tf.cast(loc_x > self.x_min, tf.float32)
    x_within_max = tf.cast(loc_x <= self.x_max, tf.float32)
    y_above_min = tf.cast(loc_y > self.y_min, tf.float32)
    y_within_max = tf.cast(loc_y <= self.y_max, tf.float32)
    expected = (3 * loc_x + loc_y) * x_above_min * x_within_max * y_above_min * y_within_max

    vol = self.vol
    if channel > 0:
        vol, expected = (
            tf.repeat(tensor[..., None], channel, axis=-1) for tensor in (vol, expected)
        )

    got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=True)
    assert is_equal_tf(expected, got)
def call(self, inputs, **kwargs):
    """
    Warp an image into a fixed-size grid using a ddf.

    Same functionality as ``transform`` in neuron
    (https://github.com/adalca/neuron/blob/master/neuron/utils.py),
    where vol = image and loc_shift = ddf.

    :param inputs: [ddf, image]
        ddf.shape = [batch, f_dim1, f_dim2, f_dim3, 3]
        image.shape = [batch, m_dim1, m_dim2, m_dim3]
    :param kwargs: not used
    :return: shape = [batch, f_dim1, f_dim2, f_dim3]
    """
    ddf, image = inputs[0], inputs[1]
    sampling_grid = self.grid_ref + ddf  # [batch, f_dim1, f_dim2, f_dim3, 3]
    warped = layer_util.resample(vol=image, loc=sampling_grid)  # [batch, f_dim1, f_dim2, f_dim3]
    return warped
def train_step(grid, weights, optimizer, mov, fix):
    """
    Run one optimisation step on the affine parameters.

    The moving image is resampled on ``grid`` warped by ``weights``, and its
    dissimilarity to the fixed image is computed with
    ``image_loss.dissimilarity_fn`` using ``image_loss_name`` from the
    enclosing scope. The optimizer then updates ``weights``.

    :param grid: reference grid returned by layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers instance
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity to minimise
    """
    with tf.GradientTape() as tape:
        warped_grid = layer_util.warp_grid(grid, weights)
        pred = layer_util.resample(vol=mov, loc=warped_grid)
        loss = image_loss.dissimilarity_fn(y_true=fix, y_pred=pred, name=image_loss_name)
    trainable = [weights]
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on every batch of a dataset, then save inputs, outputs and metrics.

    Any existing ``save_dir`` is deleted first. For each sample:
    - moving/fixed images are saved under the pair directory;
    - labels and predictions are saved under the label directory;
    - ddf/dvf/affine outputs are saved alongside them.
    Per-sample metrics are collected and written with ``save_metric_dict``.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: str, ddf / dvf / affine / conditional
    :param save_dir: str, path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if a sample with the same indices is seen twice,
        e.g. because the dataset has been repeated.
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    # identifiers of processed samples; a set keeps the repeat check O(1)
    seen_sample_index_strs = set()
    metric_lists = []
    # if model is conditional, the pred_fixed_image depends on the input label
    conditional = model_method == "conditional"

    for inputs_dict in dataset:
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)
        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)
        # indices to identify the pair
        # (batch, num_indices) last index is for label, -1 means unlabeled data
        indices = inputs_dict.get("indices")
        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        pred_fixed_image = (
            layer_util.resample(vol=moving_image, loc=fixed_grid_ref + ddf)
            if ddf is not None
            else None
        )

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir
            # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )
            # directory for outputs that depend on the label when conditional
            output_dir = label_dir if conditional else pair_dir

            # save image/label
            arr_save_dirs = [
                pair_dir,
                pair_dir,
                output_dir,
                label_dir,
                label_dir,
                label_dir,
            ]
            arrs = [
                moving_image,
                fixed_image,
                pred_fixed_image,
                moving_label,
                fixed_label,
                pred_fixed_label,
            ]
            names = [
                "moving_image",
                "fixed_image",
                "pred_fixed_image",  # or warped moving image
                "moving_label",
                "fixed_label",
                "pred_fixed_label",  # or warped moving label
            ]
            for arr_save_dir, arr, name in zip(arr_save_dirs, arrs, names):
                if arr is not None:
                    # for files under pair_dir, do not overwrite
                    save_array(
                        save_dir=arr_save_dir,
                        arr=arr[sample_index, :, :, :],
                        name=name,
                        gray=True,
                        save_nifti=save_nifti,
                        save_png=save_png,
                        overwrite=arr_save_dir == label_dir,
                    )

            # save ddf / dvf
            for arr, name in zip([ddf, dvf], ["ddf", "dvf"]):
                if arr is not None:
                    arr = normalize_array(arr=arr[sample_index, :, :, :])
                    save_array(
                        save_dir=output_dir,
                        arr=arr,
                        name=name,
                        gray=False,
                        save_nifti=save_nifti,
                        save_png=save_png,
                    )

            # save affine
            if affine is not None:
                # np.savetxt's array parameter is `X` (a lowercase `x` raises
                # TypeError). np.asarray accepts both numpy arrays (what
                # Keras Model.predict returns) and eager tensors.
                np.savetxt(
                    fname=os.path.join(output_dir, "affine.txt"),
                    X=np.asarray(affine[sample_index, :, :]),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join(str(x) for x in indices_i)
            if sample_index_str in seen_sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_index_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
raise ("Download the data using demo_data.py script") if not os.path.exists(FILE_PATH): raise ("Download the data using demo_data.py script") fid = h5py.File(FILE_PATH, "r") fixed_image = tf.cast(tf.expand_dims(fid["image"], axis=0), dtype=tf.float32) fixed_image = (fixed_image - tf.reduce_min(fixed_image)) / ( tf.reduce_max(fixed_image) - tf.reduce_min(fixed_image) ) # normalisation to [0,1] # generate a radomly-affine-transformed moving image fixed_image_size = fixed_image.shape transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2) grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4]) grid_random = layer_util.warp_grid(grid_ref, transform_random) moving_image = layer_util.resample(vol=fixed_image, loc=grid_random) # warp the labels to get ground-truth using the same random affine, for validation fixed_labels = tf.cast(tf.expand_dims(fid["label"], axis=0), dtype=tf.float32) moving_labels = tf.stack( [ layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random) for idx in range(fixed_labels.shape[4]) ], axis=4, ) ## optimisation @tf.function def train_step(grid, weights, optimizer, mov, fix): """
def test_resample(self):
    """
    Check layer_util.resample with linear interpolation on a volume without
    a feature channel, and on the same volume duplicated into two identical
    feature channels. Results are compared with self.check_equal.
    """
    interpolation = "linear"

    # Grid values follow 3x + y, with x indexing axis 1 and y indexing axis 2.
    vol_plain = np.array(
        [
            [
                [0, 1, 2],
                [3, 4, 5],
            ]
        ],
        dtype=np.float32,
    )  # shape = [1,2,3]
    # Sampling coordinates shared by both scenarios.
    loc_values = np.array(
        [
            [
                [[0, 0], [0, 1], [0, 3]],  # [0, 3] lies outside the frame
                [[0.4, 0], [0.5, 1], [0.6, 2]],
                [[0.4, 0.7], [0.5, 0.5], [0.6, 0.3]],
            ]
        ],
        dtype=np.float32,
    )  # shape = [1,3,3,2]
    want_plain = np.array(
        [
            [
                [0, 1, 2],
                [1.2, 2.5, 3.8],
                [1.9, 2, 2.1],
            ]
        ],
        dtype=np.float32,
    )  # shape = [1,3,3]

    def _two_channels(values):
        # Stack two identical copies along a new trailing feature axis.
        return np.repeat(values[..., None], 2, axis=-1)

    scenarios = (
        (vol_plain, want_plain),  # vol has no feature channel
        (_two_channels(vol_plain), _two_channels(want_plain)),  # [1,2,3,2] -> [1,3,3,2]
    )
    for vol_values, want_values in scenarios:
        want = tf.constant(want_values)
        got = layer_util.resample(
            vol=tf.constant(vol_values),
            loc=tf.constant(loc_values),
            interpolation=interpolation,
        )
        self.check_equal(want, got)