def test_tps_param__tps_scal(self, tps_scal):
    """Warp the test image with TPS parameters built from ``tps_scal``.

    All other transformation settings are fixed constants, so ``tps_scal``
    is the only value that varies between invocations. The TensorFlow
    random seed is pinned to 42 before the parameters are generated.

    Parameters
    ----------
    tps_scal
        Value forwarded as the third positional argument of
        ``tps_parameters``.

    Returns
    -------
    The current matplotlib figure, with the input image in the left
    subplot and the transformed image in the right subplot.
    """
    from eddata.utils.tps import (
        ThinPlateSpline,
        make_input_tps_param,
        tps_parameters,
    )

    tf.random.set_random_seed(42)

    # Fixed settings; only ``tps_scal`` comes from the caller.
    batch_size = 1
    scale = 1.0
    rotation_scale = 0.1
    offset_scale = 0.1
    scale_variance = 0.05
    augmentation_scale = 1.0

    param_dict = tps_parameters(
        batch_size,
        scale,
        tps_scal,
        rotation_scale,
        offset_scale,
        scale_variance,
        augmentation_scale,
    )

    source_image = get_test_image()
    coord, vector = make_input_tps_param(param_dict)
    # The second return value (the mesh) is not needed for the plot.
    warped_images, _ = ThinPlateSpline(
        source_image,
        coord,
        vector,
        source_image.shape[1],
        source_image.shape[-1],
    )

    # Left panel: input image; right panel: transformed image.
    panels = ((121, source_image), (122, warped_images))
    with plt.rc_context({"figure.figsize": [10, 5]}):
        for position, batch in panels:
            plt.subplot(position)
            plt.imshow(np.squeeze(batch[0]))
    return plt.gcf()
def test_tps_no_transform_params(self):
    """Run the TPS pipeline with ``no_transformation_parameters(1)``.

    The keyword arguments returned by ``no_transformation_parameters`` are
    forwarded unchanged to ``tps_parameters``. The TensorFlow random seed
    is pinned to 42 beforehand.

    NOTE(review): the helper's name suggests the output should equal the
    input image, but this method performs no such assertion itself.

    Returns
    -------
    The current matplotlib figure, with the input image in the left
    subplot and the transformed image in the right subplot.
    """
    from eddata.utils.tps import (
        ThinPlateSpline,
        make_input_tps_param,
        no_transformation_parameters,
        tps_parameters,
    )

    tf.random.set_random_seed(42)

    identity_kwargs = no_transformation_parameters(1)
    param_dict = tps_parameters(**identity_kwargs)

    source_image = get_test_image()
    coord, vector = make_input_tps_param(param_dict)
    # The second return value (the mesh) is not needed for the plot.
    warped_images, _ = ThinPlateSpline(
        source_image,
        coord,
        vector,
        source_image.shape[1],
        source_image.shape[-1],
    )

    # Left panel: input image; right panel: transformed image.
    panels = ((121, source_image), (122, warped_images))
    with plt.rc_context({"figure.figsize": [10, 5]}):
        for position, batch in panels:
            plt.subplot(position)
            plt.imshow(np.squeeze(batch[0]))
    return plt.gcf()
def make_tps(self, views, tps_parameters):
    """Apply a thin-plate-spline transformation to a list of views.

    The views are concatenated along the batch axis, transformed with a
    single ``tps.ThinPlateSpline`` call, and the result is split back
    into one tensor per input view.

    Parameters
    ----------
    views : list of tf.Tensor
        Image tensors shaped [N, H, W, C]. The result is split into
        ``len(views)`` equal parts along the batch axis, so all views
        are expected to share the same batch size.
    tps_parameters : dict
        Keyword arguments forwarded to ``tps.tps_parameters`` after the
        combined batch size.

    Returns
    -------
    list of tf.Tensor
        ``len(views)`` tensors obtained by splitting the transformed
        batch evenly along axis 0. They correspond to the input views
        in order, assuming ``tps.ThinPlateSpline`` preserves batch order.
    """
    n_views = len(views)
    img_batch = tf.concat(views, axis=0)

    # Static size of the concatenated batch axis. This is None when the
    # batch dimension is unknown at graph-construction time; whether
    # tps.tps_parameters tolerates that is not visible from here.
    total_batch_size = img_batch.shape.as_list()[0]

    tps_param_dict = tps.tps_parameters(total_batch_size, **tps_parameters)
    coord, vector = tps.make_input_tps_param(tps_param_dict)
    # The returned mesh is not used by this method.
    t_images, _ = tps.ThinPlateSpline(
        img_batch,
        coord,
        vector,
        img_batch.shape[1],
        img_batch.shape[-1],
    )

    # Split by the number of views rather than a hard-coded 2, so the
    # method also works for more (or fewer) than two views.
    augmented_views = tf.split(t_images, n_views, axis=0)
    return augmented_views