예제 #1
0
  def test_unnormalized_quaternion_weights_exception_raised(self):
    """Checks that quaternion_weights rejects a non-unit quaternion argument.

    Only the second endpoint is L2-normalized; the first is used exactly as
    produced by the random helper. Building and evaluating the weights is
    expected to raise tf.errors.InvalidArgumentError.
    """
    # NOTE(review): relies on _pick_random_quaternion returning a quaternion
    # that is not unit-norm — confirm against the helper's definition.
    raw_quaternion = self._pick_random_quaternion()
    unit_quaternion = tf.nn.l2_normalize(raw_quaternion, axis=-1)
    half = tf.constant(0.5, dtype=raw_quaternion.dtype)

    # Construct the op inside the context manager too: in eager mode the
    # check may fire at call time rather than at evaluate time.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      weights = slerp.quaternion_weights(raw_quaternion, unit_quaternion, half)
      self.evaluate(weights)
예제 #2
0
  def test_interpolate_with_weights_quaternion_preset(self):
    """Checks interpolate() against the explicit weights-then-blend path.

    Computing slerp weights with quaternion_weights and blending with
    interpolate_with_weights must agree with a single interpolate() call
    using the QUATERNION method, for the same endpoints and fraction.
    """
    fraction = 0.25
    start = self._pick_random_quaternion()
    # Shift every component by one before normalizing so the two unit
    # endpoints generally point in different directions.
    end = tf.nn.l2_normalize(start + tf.ones_like(start), axis=-1)
    start = tf.nn.l2_normalize(start, axis=-1)

    start_weight, end_weight = slerp.quaternion_weights(start, end, fraction)
    two_step = slerp.interpolate_with_weights(
        start, end, start_weight, end_weight)
    one_step = slerp.interpolate(
        start, end, fraction, method=slerp.InterpolationType.QUATERNION)

    self.assertAllClose(two_step, one_step, atol=1e-9)