def test_tps_param__tps_scal(self, tps_scal):
    """Plot the test image next to its thin-plate-spline warped version.

    TPS parameters are drawn with fixed settings, except for ``tps_scal``
    which is forwarded as the third positional argument of
    ``tps_parameters``. The warped result is shown in the right panel.

    Args:
        tps_scal: Value passed through to ``tps_parameters``.
            NOTE(review): presumably injected by a pytest parametrization
            or fixture defined outside this view -- confirm.

    Returns:
        The current matplotlib figure holding both panels.
    """
    from supermariopy.tfutils.tps import (
        ThinPlateSpline,
        make_input_tps_param,
        tps_parameters,
    )

    # Seed first: every random op created afterwards depends on it.
    tf.random.set_random_seed(42)

    # Positional order expected by tps_parameters:
    # bs, scal, tps_scal, rot_scal, off_scal, scal_var, augm_scal
    sampling_args = (1, 1.0, tps_scal, 0.1, 0.1, 0.05, 1.0)
    sampled_params = tps_parameters(*sampling_args)

    original = get_test_image()
    control_points, displacements = make_input_tps_param(sampled_params)
    warped, _ = ThinPlateSpline(
        original,
        control_points,
        displacements,
        original.shape[1],
        original.shape[-1],
    )

    # Left panel: untouched input, right panel: transformed output.
    panels = ((121, original), (122, warped))
    with plt.rc_context({"figure.figsize": [10, 5]}):
        for position, batch in panels:
            plt.subplot(position)
            plt.imshow(np.squeeze(batch[0]))
    return plt.gcf()
def test_input_tps_param_compatibility():
    """Check that the TF and PyTorch ``make_input_tps_param`` outputs agree.

    Both backends are fed the parameters produced by their own
    ``no_transformation_parameters`` helper for a batch of one, and the
    resulting coordinates and vectors are compared with ``np.allclose``.
    """
    n_samples = 1

    # TensorFlow side.
    tf.set_random_seed(42)
    tf_kwargs = tf_tps.no_transformation_parameters(n_samples)
    tf_coords, tf_vectors = tf_tps.make_input_tps_param(
        tf_tps.tps_parameters(**tf_kwargs)
    )

    # PyTorch side, built the same way.
    torch.manual_seed(42)
    pt_kwargs = pt_tps.no_transformation_parameters(n_samples)
    pt_coords, pt_vectors = pt_tps.make_input_tps_param(
        pt_tps.tps_parameters(**pt_kwargs)
    )

    # Coordinates are checked first, then the vectors.
    for tf_value, pt_value in ((tf_coords, pt_coords), (tf_vectors, pt_vectors)):
        assert np.allclose(tf_value, pt_value.numpy())
def test_tps_parameter_compatiblity():
    """Check that TF and PyTorch ``tps_parameters`` return equally shaped entries.

    Both implementations are called with the same positional settings and
    the shapes of the ``"coord"`` and ``"vector"`` entries are compared.

    Raises:
        AssertionError: If an entry's shape differs between the backends.
    """
    bs = 1
    scal = 1.0
    tps_scal = 0.05
    rot_scal = 0.1
    off_scal = 0.15
    scal_var = 0.05
    augm_scal = 1.0

    tps_param_dic_tf = tf_tps.tps_parameters(
        bs, scal, tps_scal, rot_scal, off_scal, scal_var, augm_scal
    )
    tps_param_dic_pt = pt_tps.tps_parameters(
        bs, scal, tps_scal, rot_scal, off_scal, scal_var, augm_scal
    )

    for k in ("coord", "vector"):
        pt_shape = tps_param_dic_pt[k].numpy().shape
        tf_shape = tps_param_dic_tf[k].shape
        # The original comparison was a bare expression whose result was
        # discarded, so the test could never fail; assert it explicitly.
        assert pt_shape == tf_shape, "shape mismatch for '{}': {} != {}".format(
            k, pt_shape, tf_shape
        )
def test_tps_no_transform_params(self):
    """Plot the test image next to its TPS output for no-transformation parameters.

    The arguments for ``tps_parameters`` come from
    ``no_transformation_parameters(1)``. The input image is shown in the
    left panel and the ``ThinPlateSpline`` output in the right panel.

    Returns:
        The current matplotlib figure holding both panels.
    """
    from supermariopy.tfutils.tps import (
        ThinPlateSpline,
        make_input_tps_param,
        no_transformation_parameters,
        tps_parameters,
    )

    # Seed first: every random op created afterwards depends on it.
    tf.random.set_random_seed(42)

    identity_kwargs = no_transformation_parameters(1)
    sampled_params = tps_parameters(**identity_kwargs)

    original = get_test_image()
    control_points, displacements = make_input_tps_param(sampled_params)
    warped, _ = ThinPlateSpline(
        original,
        control_points,
        displacements,
        original.shape[1],
        original.shape[-1],
    )

    # Left panel: untouched input, right panel: transformed output.
    panels = ((121, original), (122, warped))
    with plt.rc_context({"figure.figsize": [10, 5]}):
        for position, batch in panels:
            plt.subplot(position)
            plt.imshow(np.squeeze(batch[0]))
    return plt.gcf()