def testWarpTensor(self):
  """Checks the three tensors produced by `warp_utils.warp_tensor`.

  An `ElasticImage` is given a zero rotation, a translation of `[3, 5]` and
  a non-rigid field filled with 4s; the returned tensors are then compared
  against the expected rotation, translation and non-rigid arrays, in that
  order.
  """
  elastic = elastic_image.ElasticImage(tf.ones([3, 4]), tf.ones([5, 2]))
  elastic.rotation = tf.constant(0.)
  elastic.translation = tf.constant([3, 5])
  elastic.non_rigid = tf.ones([5, 2]) * 4

  # NOTE(review): the three flags presumably switch on the rotation,
  # translation and non-rigid components -- confirm against `warp_utils`.
  tensors = warp_utils.warp_tensor(elastic, True, True, True)

  # Expected values, listed in the same order as the returned tensors.
  expected_values = [
      np.zeros([1, 5, 2]),
      np.array([[[3, 5]]]),
      np.ones([1, 5, 2]) * 4,
  ]
  with self.test_session():
    for index, expected in enumerate(expected_values):
      self.assertAllClose(tensors[index].eval(), expected)
def testWarpVariable(self):
  """Checks the values returned by `warp_utils.warp_variables`.

  The rotation, translation and non-rigid values assigned to the
  `ElasticImage` are expected to come back, in that order, evaluating to
  exactly what was assigned.
  """
  elastic = elastic_image.ElasticImage(tf.ones([3, 4]), tf.ones([5, 2]))
  elastic.rotation = tf.constant(7.)
  elastic.translation = tf.constant([3, 5])
  elastic.non_rigid = tf.zeros([5, 2])

  # NOTE(review): the three flags presumably select the rotation,
  # translation and non-rigid entries -- confirm against `warp_utils`.
  variables = warp_utils.warp_variables(elastic, True, True, True)

  # Expected values, listed in the same order as the returned entries.
  expected_values = [7, [3, 5], np.zeros([5, 2])]
  with self.test_session():
    for index, expected in enumerate(expected_values):
      self.assertAllClose(variables[index].eval(), expected)
def load_image(
    image: tf.Tensor,
    control_points: tf.Tensor,
    initial_rotation: float,
    initial_translation: List[float],
    initial_non_rigid_values_or_scale: Union[List[float], float]
):
  """Creates ElasticImage object with variables for registration.

  Convenience function that wraps instantiation of `ElasticImage` as well as
  variable creation using `warp_parameters`.

  Args:
    image: `tf.Tensor` of shape `[height, width]`.
    control_points: `tf.Tensor` of shape `[num_control_points, 2]`. Describes
      locations of control points in image used to parametrize warp.
    initial_rotation: Initial rotation of image in degrees.
    initial_translation: Initial translation of image.
    initial_non_rigid_values_or_scale: Either a list of initial displacements
      for the control point grid, or a float which parametrizes the scale of
      random initialization. See
      `warp_parameters.make_elastic_warp_variable`.

  Returns:
    ElasticImage object containing image and variables for rotation,
    translation, and non-rigid warp.

  Raises:
    ValueError: if `control_points` has an invalid shape.
  """
  if not control_points.shape.is_compatible_with([None, 2]):
    # Adjacent string literals are concatenated verbatim, so the ", "
    # separator must be part of the text or the shape is glued to "2]".
    raise ValueError("`control_points` must be compatible with [None, 2], "
                     "got {}".format(control_points.shape))

  eimage = elastic_image.ElasticImage(image, control_points)
  # Use the shape as stored on the ElasticImage so the non-rigid variable
  # matches the control points the object actually holds.
  control_points_shape = eimage.control_points.shape

  eimage.rotation = warp_parameters.make_rotation_warp_variable(
      initial_rotation)
  eimage.translation = warp_parameters.make_translation_warp_variable(
      initial_translation)
  eimage.non_rigid = warp_parameters.make_elastic_warp_variable(
      control_points_shape, initial_non_rigid_values_or_scale)

  return eimage
def setUp(self):
  """Builds the fixtures shared by the tests: image, control points, `ei`.

  The default graph is reset first, then a 3x3 integer image and two control
  points are created as constants and wrapped in an `ElasticImage`.
  """
  # Clear the default graph before any fixture tensors are created.
  tf.reset_default_graph()

  pixel_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  point_rows = [[0, 0], [1, 1]]

  self.image = tf.constant(pixel_rows)
  self.control_points = tf.constant(point_rows)
  self.ei = elastic_image.ElasticImage(self.image, self.control_points)
def testElasticImageBadControlPoints(self):
  """Expects `ElasticImage` to reject control points with one column.

  The control points are built with shape `[2, 1]`; constructing the
  `ElasticImage` with them must raise `ValueError`.
  """
  pixel_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  single_column_points = [[0], [1]]

  image = tf.constant(pixel_rows)
  control_points = tf.constant(single_column_points)

  # Callable form of assertRaises: the constructor is invoked by the
  # assertion helper with the two positional arguments.
  self.assertRaises(
      ValueError, elastic_image.ElasticImage, image, control_points)