Ejemplo n.º 1
0
    def test_tps_param__tps_scal(self, tps_scal):
        """Plot a test image next to its TPS-warped version for ``tps_scal``.

        The TF random seed is fixed to 42 before the TPS parameters are built
        (presumably so the warp is reproducible across runs). The returned
        matplotlib figure holds the original image in the left panel (121)
        and the warped image in the right panel (122).
        """
        from eddata.utils.tps import (
            ThinPlateSpline,
            make_input_tps_param,
            tps_parameters,
        )

        tf.random.set_random_seed(42)
        # Positional arguments for tps_parameters, in the order:
        # bs, scal, tps_scal, rot_scal, off_scal, scal_var, augm_scal
        param_args = (1, 1.0, tps_scal, 0.1, 0.1, 0.05, 1.0)
        params = tps_parameters(*param_args)
        source = get_test_image()

        control_points, displacements = make_input_tps_param(params)
        # The second returned value (the mesh) is not needed here.
        warped, _ = ThinPlateSpline(
            source,
            control_points,
            displacements,
            source.shape[1],
            source.shape[-1],
        )

        panels = ((121, source), (122, warped))
        with plt.rc_context({"figure.figsize": [10, 5]}):
            for slot, batch in panels:
                plt.subplot(slot)
                # Only the first element of each batch is displayed.
                plt.imshow(np.squeeze(batch[0]))
        return plt.gcf()
Ejemplo n.º 2
0
    def test_tps_no_transform_params(self):
        """Plot a test image next to its TPS output under "no transformation" args.

        The TPS arguments come from ``no_transformation_parameters(1)``;
        presumably the right-hand panel should look like the input (verify
        against the reference image). The returned matplotlib figure holds
        the original image in panel 121 and the TPS output in panel 122.
        """
        from eddata.utils.tps import (
            ThinPlateSpline,
            make_input_tps_param,
            no_transformation_parameters,
            tps_parameters,
        )

        tf.random.set_random_seed(42)
        passthrough_kwargs = no_transformation_parameters(1)
        params = tps_parameters(**passthrough_kwargs)
        source = get_test_image()

        control_points, displacements = make_input_tps_param(params)
        # The mesh output of ThinPlateSpline is unused.
        warped, _ = ThinPlateSpline(
            source,
            control_points,
            displacements,
            source.shape[1],
            source.shape[-1],
        )

        with plt.rc_context({"figure.figsize": [10, 5]}):
            for slot, batch in ((121, source), (122, warped)):
                plt.subplot(slot)
                # Show just the first image of the batch.
                plt.imshow(np.squeeze(batch[0]))
        return plt.gcf()
Ejemplo n.º 3
0
    def make_tps(self, views, tps_parameters):
        """Apply a thin-plate-spline warp to every view in one batched call.

        The views are concatenated along the batch axis, and TPS parameters
        are built for the combined batch size. The warp is applied once, and
        the result is split back into one tensor per input view.

        Parameters
        ----------
        views : list of tf.Tensor
            Image tensors shaped [N, H, W, C]. The output is split with
            ``tf.split(..., len(views), axis=0)``, so every view must share
            the same batch size N.
        tps_parameters : dict
            Keyword arguments forwarded to ``tps.tps_parameters`` after the
            batch size.

        Returns
        -------
        list of tf.Tensor
            The warped views, one per entry of ``views``, in the same order.
        """
        img_batch = tf.concat(views, axis=0)
        # Static size of the concatenated batch axis.
        batch_size = img_batch.shape.as_list()[0]
        tps_param_dict = tps.tps_parameters(batch_size, **tps_parameters)
        coord, vector = tps.make_input_tps_param(tps_param_dict)
        # NOTE(review): shape[1] is passed as the output size and shape[-1] as
        # the channel count; this presumably assumes square images - confirm.
        t_images, _ = tps.ThinPlateSpline(
            img_batch,
            coord,
            vector,
            img_batch.shape[1],
            img_batch.shape[-1],
        )
        # Split into as many chunks as there were input views (previously
        # hard-coded to 2), so outputs always line up one-to-one with inputs.
        return tf.split(t_images, len(views), axis=0)