Пример #1
0
  def test_homomorphism(self):
    """Check that `pre_compose` agrees with composing two affines by hand.

    Two affines are built from random quaternions/translations. The second
    one is pre-composed with an update vector made of the imaginary part of
    the first quaternion and the first translation, and the result is
    compared with explicitly chaining the two affines (quaternion product,
    rotation product, transformed point, and inverse of the transformed
    point).
    """
    root_key = jax.random.PRNGKey(42)
    q1_key, q2_key, t1_key, t2_key = jax.random.split(root_key, 4)

    # First quaternion: random imaginary part with the real part fixed to 1.
    imag_q1 = jax.random.normal(q1_key, (2, 3))
    real_q1 = jnp.ones_like(imag_q1)[:, :1]
    q1 = jnp.concatenate([real_q1, imag_q1], axis=-1)

    q2 = jax.random.normal(q2_key, (2, 4))
    t1 = jax.random.normal(t1_key, (2, 3))
    t2 = jax.random.normal(t2_key, (2, 3))

    first = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
    second = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
    update = jnp.concatenate([imag_q1, t1], axis=-1)
    composed = second.pre_compose(update)

    # NOTE(review): `root_key` was already split above, so this second split
    # reuses it. Kept as-is so the sampled values stay unchanged; consider
    # deriving all five keys from a single split.
    _, point_key = jax.random.split(root_key)
    x = jax.random.normal(point_key, (2, 3))
    # Coordinate axis moved to the front before applying the affines.
    x_coords_first = jnp.moveaxis(x, -1, 0)
    via_composed = composed.apply_to_point(x_coords_first)
    via_chain = second.apply_to_point(first.apply_to_point(x_coords_first))

    product_quat = quat_affine.quat_multiply(
        second.quaternion, first.quaternion)
    product_rot = jnp.matmul(r2t(second.rotation), r2t(first.rotation))
    recovered_x = v2t(composed.invert_point(via_composed))

    self._assert_check({
        'quat': (q2r(product_quat), q2r(composed.quaternion)),
        'rot': (product_rot, r2t(composed.rotation)),
        'point': (v2t(via_chain), v2t(via_composed)),
        'inverse': (x, recovered_x),
    })
Пример #2
0
  def test_batching(self):
    """Test that a batched affine matches affines built one element at a time.

    The batched result for quaternions of shape (5, 2, 4), translations of
    shape (2, 3) and points of shape (5, 1, 3) is compared with the stack of
    results obtained by building a separate affine for every (i, j) pair.
    """
    q_key, t_key, x_key = jax.random.split(jax.random.PRNGKey(42), 3)
    q = jax.random.uniform(q_key, (5, 2, 4))
    t = jax.random.uniform(t_key, (2, 3))
    x = jax.random.uniform(x_key, (5, 1, 3))

    batched = quat_affine.QuatAffine(q, t, unstack_inputs=True)
    y = v2t(batched.apply_to_point(jnp.moveaxis(x, -1, 0)))

    def apply_single(i, j):
      """Transforms point x[i, 0] with the affine built from q[i, j], t[j]."""
      single = quat_affine.QuatAffine(q[i, j], t[j], unstack_inputs=True)
      return v2t(single.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))

    # Row-major order over (i, j) so the reshape below restores (5, 2, 3).
    per_element = [apply_single(i, j) for i in range(5) for j in range(2)]
    y_combine = jnp.reshape(jnp.stack(per_element, axis=0), (5, 2, 3))

    roundtrip_quat = quat_affine.rot_to_quat(batched.rotation)
    self._assert_check({
        'batch': (y_combine, y),
        'quat': (q2r(batched.quaternion), q2r(roundtrip_quat)),
    })
Пример #3
0
  def test_double_cover(self):
    """Test that negating the quaternion leaves the affine unchanged.

    Affines built from `q` and `-q` with the same translation must report the
    same rotation and the same translation.
    """
    quat_key, trans_key = jax.random.split(jax.random.PRNGKey(42))
    q = jax.random.normal(quat_key, (2, 4))
    trans = jax.random.normal(trans_key, (2, 3))

    positive = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
    negative = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)

    rotations = (r2t(positive.rotation), r2t(negative.rotation))
    translations = (v2t(positive.translation), v2t(negative.translation))
    self._assert_check({
        'rot': rotations,
        'trans': translations,
    })
Пример #4
0
def generate_new_affine(sequence_mask):
    """Build a QuatAffine with a fixed quaternion and zero translation.

    Args:
      sequence_mask: Two-dimensional array; only the size of its first axis
        (the number of residues) is used.

    Returns:
      A `quat_affine.QuatAffine` whose quaternion is [1, 0, 0, 0] for every
      residue (presumably the identity rotation) and whose translation is all
      zeros.
    """
    num_residues, _ = sequence_mask.shape

    # One [1, 0, 0, 0] row, repeated once per residue.
    unit_row = jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4])
    quaternion = jnp.tile(unit_row, [num_residues, 1])
    translation = jnp.zeros([num_residues, 3])

    return quat_affine.QuatAffine(
        quaternion, translation, unstack_inputs=True)
Пример #5
0
def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine:
    """Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'.

    The nine rotation components of `r.rot` are passed as a nested 3x3 list
    and the three components of `r.trans` as a flat list; no quaternion is
    supplied.
    """
    rot, trans = r.rot, r.trans
    rotation_rows = [
        [rot.xx, rot.xy, rot.xz],
        [rot.yx, rot.yy, rot.yz],
        [rot.zx, rot.zy, rot.zz],
    ]
    translation_components = [trans.x, trans.y, trans.z]
    return quat_affine.QuatAffine(
        quaternion=None,
        rotation=rotation_rows,
        translation=translation_components)
Пример #6
0
  def test_conversion(self):
    """Check a QuatAffine against a hand-written rotation matrix.

    The affine built from a fixed quaternion and translation must report the
    reference rotation and translation, and must map a fixed point to
    `rotation @ point + translation`.
    """
    quat = jnp.array([-2., 5., -1., 4.])
    translation = jnp.array([1., -3., 4.])
    point = jnp.array([0.7, 3.2, -2.9])

    # Reference rotation matrix the affine is expected to report for `quat`.
    rotation = jnp.array([
        [0.26087, 0.130435, 0.956522],
        [-0.565217, -0.782609, 0.26087],
        [0.782609, -0.608696, -0.130435]])

    a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)

    rotated_point = jnp.matmul(rotation, point[:, None])[:, 0]
    true_new_point = rotated_point + translation
    applied_point = v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))
    roundtrip_quat = quat_affine.rot_to_quat(a.rotation)

    self._assert_check({
        'rot': (rotation, r2t(a.rotation)),
        'trans': (translation, v2t(a.translation)),
        'point': (true_new_point, applied_point),
        # q and -q describe the same rotation (double cover), so quaternions
        # are compared only after being passed through q2r.
        'quat': (q2r(a.quaternion), q2r(roundtrip_quat)),
    })