def frozen_flow_transform(t, y, x0, bottom, wind_velocity=None):
    """
    Apply the frozen-flow transform to a coordinate.

    The layer is treated as rotating rigidly about the Earth's centre, so the
    coordinate is rotated backwards by the angle swept during ``t``.

    Args:
        t: time in seconds; ``None`` disables the transform
        y: position in km, origin at centre of Earth
        x0: position of reference point.
        bottom: bottom of ionosphere in km
        wind_velocity: layer velocity in km/s; ``None`` disables the transform

    Returns:
        ``y`` rotated by ``-theta_dot * t`` about the unit axis along
        ``wind_velocity x x0``, or ``y`` itself when ``t`` or
        ``wind_velocity`` is ``None``.
    """
    if t is None or wind_velocity is None:
        return y

    # Axis of rotation passes through the Earth's centre.
    axis = jnp.cross(wind_velocity, x0)
    axis = axis / jnp.linalg.norm(axis)

    # v(r) = theta_dot * r, i.e. (km/s) / km gives the angular rate.
    layer_radius = bottom + jnp.linalg.norm(x0)
    theta_dot = jnp.linalg.norm(wind_velocity) / layer_radius
    angle = -theta_dot * t

    # Rodrigues' rotation formula, split into its three terms.
    axis_cross_y = jnp.cross(axis, y)
    along_axis = axis * (axis @ y)
    in_plane = jnp.cross(axis_cross_y, axis)
    return along_axis + jnp.cos(angle) * in_plane + jnp.sin(angle) * axis_cross_y
示例#2
0
def vector_rotate(a, b, theta):
    """
    Rotate vector a around vector b by an angle theta (radians).

    ``a`` may be a single vector (1d) or an array of vectors (2d, one vector
    per row); ``b`` is a single vector and need not be normalised.

    Programming Notes:
      parallel: projection of a onto the unit axis.
      perpendicular: remainder of a, orthogonal to the axis.
      sideways: axis x perpendicular, orthogonal to both.

    Raises:
      Exception: if ``a`` is neither 1d nor 2d.
    """
    if a.ndim not in (1, 2):
        raise Exception(
            'Input array must be 1d (vector) or 2d (array of vectors)')

    axis = b / np.linalg.norm(b)
    if a.ndim == 1:
        parallel = axis * np.dot(a, axis)
    else:
        # Row-wise dot product, then scale the axis by each row's result.
        row_dots = np.einsum('ij,j->i', a, axis)
        parallel = np.einsum('i,j->ij', row_dots, axis)

    perpendicular = a - parallel
    sideways = np.cross(axis, perpendicular)
    return parallel + perpendicular * np.cos(theta) + sideways * np.sin(theta)
示例#3
0
    def extend(tri, pt, multi_m):
        """
        Place ``pt`` in the local frame defined by the three coordinates in ``tri``.

        Args:
            tri: NUM_DIHEDRALS x [NUM_FRAGS/0,         BATCH_SIZE, NUM_DIMENSIONS]
            pt:                  [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
            multi_m: bool indicating whether the frame matrix (and tri) is
                higher rank. pt is always higher rank; what changes is what
                the first rank is.

        Returns:
            ``pt`` expressed in the frame (bc, n x bc, n) and translated by
            ``tri.c``; same shape as ``pt``.
        """
        unit_bc = normalize(tri.c - tri.b,
                            axis=-1)  # [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMS]
        normal = normalize(np.cross(tri.b - tri.a, unit_bc),
                           axis=-1)  # [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMS]
        # The three frame vectors, stacked on a new leading axis.
        frame = np.stack([unit_bc, np.cross(normal, unit_bc), normal])

        if multi_m:
            # Multiple fragments, one atom at a time.
            # [NUM_FRAGS, BATCH_SIZE, NUM_DIMS, 3 TRANS]
            basis = np.transpose(frame, [1, 2, 3, 0])
        else:
            # Single fragment, reconstructed entirely at once.
            # target_shape is pt.shape with a trailing 3 appended.
            target_shape = onp.pad(pt.shape, [[0, 1]], constant_values=3)
            per_batch = np.transpose(frame, [1, 2, 0])  # [BATCH_SIZE, NUM_DIMS, 3 TRANS]
            # Repeat the same frame for every position in the fragment.
            basis = np.reshape(np.tile(per_batch, [target_shape[0], 1, 1]),
                               target_shape)

        placed = np.matmul(basis, np.expand_dims(pt, 3))
        return np.squeeze(placed, axis=3) + tri.c
示例#4
0
def _dihedral(pos: jnp.ndarray, indices: jnp.ndarray,
              tvecs: jnp.ndarray) -> float:
    """Return the signed dihedral angle (radians) of four indexed positions.

    ``indices`` selects four rows of ``pos``; ``tvecs[k]`` is added to the
    k-th consecutive difference vector before the angle is evaluated.
    """
    first, middle, last = (
        pos[indices[k + 1]] - pos[indices[k]] + tvecs[k] for k in range(3)
    )
    # Normals of the two planes sharing the middle bond.
    normal_a = jnp.cross(first, middle)
    normal_b = jnp.cross(middle, last)

    sin_term = middle @ jnp.cross(normal_a, normal_b)
    cos_term = (jnp.linalg.norm(middle) * normal_a) @ normal_b
    return jnp.arctan2(sin_term, cos_term)
示例#5
0
def calc_torsions(xyz: np.ndarray, struct_descr_dict):
    """Return the signed torsion angle (radians) for every listed atom quadruple.

    ``struct_descr_dict['torsions']`` is indexed column-wise (columns 0..3) to
    pick the four atoms of each torsion out of ``xyz``.
    """
    quads = struct_descr_dict['torsions']
    atom_0, atom_1, atom_2, atom_3 = (xyz[quads[:, col]] for col in range(4))

    bond_1 = atom_1 - atom_0
    bond_2 = atom_2 - atom_1
    bond_3 = atom_3 - atom_2

    # Normals of the planes on either side of the central bond.
    normal_12 = np.cross(bond_1, bond_2)
    normal_23 = np.cross(bond_2, bond_3)

    sin_part = (np.cross(normal_12, normal_23) * bond_2).sum(axis=-1) / np.linalg.norm(bond_2, axis=-1)
    cos_part = (normal_12 * normal_23).sum(axis=-1)
    return np.arctan2(sin_part, cos_part)
示例#6
0
 def __call__(self, r, h, p):
     """Blend the base function at ``r`` and at a shifted copy of ``r``.

     The last seven entries of ``p`` are read as a reference point (3), a
     vector ``n`` (3) and a scalar shift size; the remaining leading entries
     are forwarded to ``self.base_gfunc``. The shift direction is
     ``(z x n) x n`` normalised to unit length.
     """
     base_params = p[:-7]
     ref_point = jnp.array(p[-7:-4])
     n_vec = jnp.array(p[-4:-1])
     shift_size = p[-1]

     z_axis = jnp.array([0.0, 0.0, 1.0])
     direction = jnp.cross(jnp.cross(z_axis, n_vec), n_vec)
     offset = shift_size * direction / jnp.sqrt(jnp.dot(direction, direction))

     unshifted = self.base_gfunc(r, h, base_params)
     shifted = self.base_gfunc(r + offset, h, base_params)
     # Which of the two values is used depends on this projection onto n_vec.
     side = jnp.dot(r - ref_point, n_vec)
     return soft_if_then_logjit(side, unshifted, shifted, h)
示例#7
0
def dihedral_angle(p1, p2, p3, p4):
    """
    Return the signed dihedral angle (radians) defined by four points,
    measured around the line through the two central points p2 and p3.
    """
    axis = p3 - p2
    # Normals of the planes (p1, p2, p3) and (p2, p3, p4).
    normal_front = np.cross(p2 - p1, axis)
    normal_back = np.cross(axis, p4 - p3)

    sin_scaled = np.dot(np.cross(normal_front, normal_back), axis)
    cos_scaled = np.dot(normal_front, normal_back) * linalg.norm(axis)
    return np.arctan2(sin_scaled, cos_scaled)
示例#8
0
  def torsionVecs_(self, P):
      """Return the two plane normals of a single torsion as rows 0 and 2.

      ``P`` is indexed as ``P[0]`` .. ``P[3]`` for the four points. The result
      has the normal of the first three points in row 0, zeros in row 1 and
      the normal of the last three points in row 2.
      """
      first, second, third, fourth = P[0], P[1], P[2], P[3]

      middle = second - third
      normal_a = np.cross(first - second, middle)
      normal_b = np.cross(fourth - third, middle)

      filler = np.zeros(normal_a.shape)
      # dstack puts the three vectors on the last axis; squeeze drops the
      # leading singleton and the transpose turns them into rows.
      stacked = np.dstack((normal_a, filler, normal_b))
      return stacked.squeeze().transpose([1, 0])
示例#9
0
def signed_torsion_angle(ci, cj, ck, cl):
    """
    Batch compute the signed angle of a torsion angle.  The torsion angle
    between two planes should be periodic but not necessarily symmetric.

    Parameters
    ----------
    ci: shape [num_torsions, 3] np.array
        coordinates of the 1st atom in the 1-4 torsion angle

    cj: shape [num_torsions, 3] np.array
        coordinates of the 2nd atom in the 1-4 torsion angle

    ck: shape [num_torsions, 3] np.array
        coordinates of the 3rd atom in the 1-4 torsion angle

    cl: shape [num_torsions, 3] np.array
        coordinates of the 4th atom in the 1-4 torsion angle

    Returns
    -------
    shape [num_torsions,] np.array
        array of torsion angles.

    """

    # Taken from the wikipedia arctan2 implementation:
    # https://en.wikipedia.org/wiki/Dihedral_angle

    # We use an identical but numerically stable arctan2
    # implementation as opposed to the OpenMM energy function to
    # avoid a singularity when the angle is zero.

    # Displacements between consecutive atoms; the exact sign convention is
    # whatever delta_r implements.
    rij = delta_r(cj, ci)
    rkj = delta_r(cj, ck)
    rkl = delta_r(cl, ck)

    # Normals of the two planes that share the central bond.
    n1 = np.cross(rij, rkj)
    n2 = np.cross(rkj, rkl)

    # arctan2 only needs the sine-like and cosine-like terms up to a common
    # positive factor, so the norms of n1 and n2 are never required.
    rkj_unit = rkj / np.linalg.norm(rkj, axis=-1, keepdims=True)
    y = np.sum(np.multiply(np.cross(n1, n2), rkj_unit), axis=-1)
    x = np.sum(np.multiply(n1, n2), -1)

    return np.arctan2(y, x)
示例#10
0
  def torsionVecs(self, P):
      """Return the two plane normals of each torsion, batched.

      ``P`` holds the four points of each torsion on its second-to-last axis
      and their x, y, z components on the last axis. For every torsion the
      output has the normal of the first three points in row 0, zeros in
      row 1 and the normal of the last three points in row 2, so the result
      has shape ``P.shape[:-2] + (3, 3)``.

      The vectors are stacked directly on axis -2 rather than via
      ``dstack(...).squeeze().transpose(...)``: the squeeze collapsed a
      batch of size one and made the transpose fail.
      """
      p0 = P[..., 0, [0, 1, 2]]
      p1 = P[..., 1, [0, 1, 2]]
      p2 = P[..., 2, [0, 1, 2]]
      p3 = P[..., 3, [0, 1, 2]]

      r1 = p0 - p1
      r2 = p1 - p2
      r3 = p3 - p2
      cp_12 = np.cross(r1, r2)
      cp_32 = np.cross(r3, r2)
      return np.stack((cp_12, np.zeros(cp_12.shape), cp_32), axis=-2)
示例#11
0
    def biot_savart_saddle(r, I, dl, l):
        """
        Sum the Biot-Savart contributions of all coil segments on a surface grid.

        Inputs:

        r : Position we want to evaluate at, NZ x NT x 3
        I : Current in ith coil, length NC
        dl : Vector which has coil segment length and direction, NC x NS x 3
        l : Positions of center of each coil segment, NC x NS x 3

        Returns:

        A NZ x NT x 3 array which is the magnetic field vector on the surface
        points (mu_0 is taken as 1 and no further prefactor is applied).
        """
        mu_0 = 1.
        current_elements = (I * mu_0)[:, np.newaxis, np.newaxis] * dl  # NC x NS x 3

        # Broadcast evaluation points against segment centres.
        eval_points = r[np.newaxis, :, :, np.newaxis, :]
        segment_centres = l[:, np.newaxis, np.newaxis, :, :]
        separation = eval_points - segment_centres  # NC x NZ x NT x NS x 3

        numerator = np.cross(
            current_elements[:, np.newaxis, np.newaxis, :, :],
            separation)  # NC x NZ x NT x NS x 3
        distance_cubed = np.linalg.norm(separation, axis=-1)**3  # NC x NZ x NT x NS

        contributions = numerator / distance_cubed[:, :, :, :, np.newaxis]
        # Collapse the coil and segment axes.
        return np.sum(contributions, axis=(0, 3))  # NZ x NT x 3
示例#12
0
    def computeB(I, dl, l, r, zeta, z):
        """
        Evaluate the coil field at one point given in cylindrical coordinates.

        Inputs:

        I : coil currents, length NC
        dl : coil segment vectors, NC x NS x NNR x NBR x 3
        l : coil segment positions, NC x NS x NNR x NBR x 3
        r, zeta, z : The coordinates of the point we want the magnetic field
            at. Cylindrical coordinates.

        Outputs:

        B_r, B_zeta, B_z : the magnetic field components at the input
            coordinates created by the currents in the coils. Cylindrical
            coordinates.
        """
        cos_zeta = np.cos(zeta)
        sin_zeta = np.sin(zeta)
        # Evaluation point converted to Cartesian coordinates.
        xyz = np.asarray([r * cos_zeta, r * sin_zeta, z])

        mu_0 = 1.
        current_elements = (I * mu_0)[:, np.newaxis, np.newaxis, np.newaxis,
                                      np.newaxis] * dl  # NC x NS x NNR x NBR x 3
        separation = xyz[np.newaxis, np.newaxis, np.newaxis,
                         np.newaxis, :] - l[:, :, :, :, :]  # NC x NS x NNR x NBR x 3
        numerator = np.cross(current_elements, separation)  # NC x NS x NNR x NBR x 3
        distance_cubed = np.linalg.norm(separation, axis=-1)**3  # NC x NS x NNR x NBR

        # Sum over every coil / segment / filament axis, leaving the xyz vector.
        B_xyz = np.sum(numerator / distance_cubed[:, :, :, :, np.newaxis],
                       axis=(0, 1, 2, 3))
        B_x, B_y, B_z = B_xyz[0], B_xyz[1], B_xyz[2]

        # Rotate the Cartesian x/y components into the cylindrical basis.
        B_r = B_x * cos_zeta + B_y * sin_zeta
        B_zeta = -B_x * sin_zeta + B_y * cos_zeta
        return B_r, B_zeta, B_z
示例#13
0
def biot_savart_oncoil(r_eval, dl, ll, I_arr):
    """
    Calculate the Biot-Savart sum over all coil segments for a point that may
    itself lie on a coil segment.

    Arguments:
    *r_eval*: (length 3 array) the point where the field is to be evaluated,
    in cartesian coordinates. Has to be on a coil.
    *dl*: (n_coils, nsegments, 3)-array of the segment vector of every coil
    line segment
    *ll*: (n_coils, nsegments, 3)-array of the position of each coil segment
    *I_arr*: (n_coils,)-array of coil currents

    Note on algorithm: indexing with None adds new axes in-line so the arrays
    broadcast to a common shape. The Biot-Savart integral is calculated as a
    sum over all segments.

    returns:
    *B*: magnetic field at position r_eval
    """
    offset = r_eval[None, None, :] - ll
    numerator = np.cross(dl, offset) * I_arr[:, None, None]
    distance_cubed = np.linalg.norm(offset, axis=-1)**3

    # A segment located exactly at r_eval yields 0/0; nan_to_num turns that
    # NaN into a zero contribution before summing over all segments.
    per_segment = np.nan_to_num(numerator / distance_cubed[:, :, None])
    return np.sum(per_segment, axis=(0, 1))
示例#14
0
def angle(p1, p2, p3):
    """
    Return the angle (radians) at the middle point p2 between the directions
    towards p1 and p3.
    """
    arm_a = p1 - p2
    arm_b = p3 - p2
    # |a x b| and a . b are the sine and cosine scaled by the same factor.
    sin_scaled = linalg.norm(np.cross(arm_a, arm_b))
    cos_scaled = np.dot(arm_a, arm_b)
    return np.arctan2(sin_scaled, cos_scaled)
示例#15
0
def geo(joint):
    """Geometric regularisation loss over five four-joint chains.

    ``joint`` is indexed as ``joint[:, idx, :]`` with joint indices 1..20,
    grouped into five chains (1-4, 5-8, 9-12, 13-16, 17-20).

    Two terms are combined:
      * loss_1: mean absolute triple product of the three consecutive bone
        vectors of each chain (zero when the chain is planar).
      * loss_2: mean of the negative part of the dot product between the two
        consecutive plane normals of each chain.

    The dot product for loss_2 is reduced over the coordinate axis only, like
    loss_1, so each chain is clipped on its own; summing over every axis
    first would let positive and negative chains cancel before the clip.

    Returns:
        Scalar ``10000 * loss_1 + 100000 * loss_2``.
    """
    idx_a = npj.array([1, 5, 9, 13, 17])
    idx_b = npj.array([2, 6, 10, 14, 18])
    idx_c = npj.array([3, 7, 11, 15, 19])
    idx_d = npj.array([4, 8, 12, 16, 20])
    p_a = joint[:, idx_a, :]
    p_b = joint[:, idx_b, :]
    p_c = joint[:, idx_c, :]
    p_d = joint[:, idx_d, :]
    v_ab = p_a - p_b  # (B, 5, 3)
    v_bc = p_b - p_c  # (B, 5, 3)
    v_cd = p_c - p_d  # (B, 5, 3)

    normal_abc = npj.cross(v_ab, v_bc, axis=-1)
    normal_bcd = npj.cross(v_bc, v_cd, axis=-1)

    loss_1 = npj.abs(npj.sum(normal_abc * v_cd, -1)).mean()
    # Per-chain dot product, keeping only the negative part.
    normal_alignment = npj.sum(normal_abc * normal_bcd, -1)
    loss_2 = -npj.clip(normal_alignment, -npj.inf, 0).mean()

    return 10000 * loss_1 + 100000 * loss_2
示例#16
0
def biot_savart(r, I, dl, l):
    """Sum the Biot-Savart contributions of all coil segments at one point.

    ``r`` is a single position, ``I`` holds one current per coil, and ``dl`` /
    ``l`` hold the segment vectors and segment positions indexed as
    (coil, segment, component). mu_0 is taken as 1 and no further prefactor
    is applied.
    """
    mu_0 = 1.0
    current_elements = (I * mu_0)[:, np.newaxis, np.newaxis] * dl
    separation = r[np.newaxis, np.newaxis, :] - l

    numerator = np.cross(current_elements, separation)
    distance_cubed = np.linalg.norm(separation, axis=-1)**3

    # Collapse the coil and segment axes, leaving the field vector.
    return np.sum(numerator / distance_cubed[:, :, np.newaxis], axis=(0, 1))
示例#17
0
 def compute_intersection_point(denom):
     """Return the intersection point, or ``[inf, inf]`` when there is none.

     Relies on ``v1``, ``v2``, ``v3``, ``ray_origin`` and ``ray_direction``
     from the enclosing scope. The intersection is rejected when the ray
     parameter is negative or the second parameter lies outside [0, 1].
     """
     ray_param = np.cross(v2, v1) / denom
     other_param = (v1 @ v3) / denom

     behind_origin = ray_param < 0.0
     below_range = other_param < 0.0
     above_range = other_param > 1.0
     no_hit = np.logical_or(np.logical_or(behind_origin, below_range), above_range)

     # Both branches ignore the operand and read ray_param from the closure.
     return jax.lax.cond(
         no_hit,
         true_fun=lambda _: np.array([np.inf, np.inf]),
         false_fun=lambda _: ray_origin + ray_param * ray_direction,
         operand=ray_param,
     )
示例#18
0
    def extend(prev_three_coords, point, multi_m):
        """
        Aligns an atom or an entire fragment depending on value of `multi_m`
        with the preceding three atoms.

        :param prev_three_coords: Named tuple storing the last three atom
        coordinates ("a", "b", "c") where "c" is the current end of the
        structure (i.e. closest to the atom/ fragment that will be added now).
        Shape NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMENSIONS].
        First rank depends on value of `multi_m`.
        :param point: Point describing the atom that is added to the structure.
        Shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
        First rank depends on value of `multi_m`.
        :param multi_m: If True, a single atom is added to the chain for
        multiple fragments in parallel. If False, a single fragment is added.
        Note the different parameter dimensions.
        :return: Coordinates of the atom/ fragment.
        """
        def _unit_rows(vectors):
            # Scale every vector on the last axis to unit length.
            return vectors / onp.linalg.norm(vectors, axis=-1, keepdims=True)

        anchor_a = prev_three_coords.a
        anchor_b = prev_three_coords.b
        anchor_c = prev_three_coords.c

        bc = _unit_rows(anchor_c - anchor_b)
        n = _unit_rows(
            onp.cross(anchor_b - anchor_a, bc, axisa=-1, axisb=-1, axisc=-1))
        # Local frame vectors stacked on a new leading axis.
        frame = onp.stack([bc, onp.cross(n, bc), n])

        if multi_m:
            # Multiple fragments, one atom at a time.
            m = onp.transpose(frame, (1, 2, 3, 0))
        else:
            # Single fragment, reconstructed entirely at once: repeat the same
            # frame for every position of the fragment.
            target_shape = point.shape + (3, )
            per_batch = onp.transpose(frame, (1, 2, 0))
            m = onp.tile(per_batch, (target_shape[0], 1, 1)).reshape(target_shape)

        placed = onp.matmul(m, onp.expand_dims(point, axis=3))
        return onp.squeeze(placed, axis=3) + anchor_c
示例#19
0
    def _recov_axis_batch(hand_joints, transf, joints_mapping, up_axis_base):
        """
        Recover three unit axes per joint from joint positions and transforms.

        input: hand_joints[B, 21, 3], transf[B, 16, 4, 4]
        output: b_axis[B, 15, 3], u_axis[B, 15, 3], l_axis[B, 15, 3]

        ``joints_mapping`` lists the joint indices used as the start of each
        bone; the end of each bone is the joint with the next index.
        ``up_axis_base`` is crossed with the bone axis and must broadcast
        against it.
        """
        def _unit(vectors):
            # Normalise along the coordinate axis, keeping it for broadcasting.
            return vectors / np.linalg.norm(vectors, axis=2, keepdims=True)

        next_joints = [i + 1 for i in joints_mapping]
        bone = hand_joints[:, joints_mapping] - hand_joints[:, next_joints]

        # Apply the transposed 3x3 block of every transform except the first.
        rot_transposed = np.transpose(transf[:, 1:, :3, :3], (0, 1, 3, 2))
        b_axis = (rot_transposed @ np.expand_dims(bone, -1)).squeeze(-1)

        l_axis = np.cross(b_axis, up_axis_base)
        u_axis = np.cross(l_axis, b_axis)

        return _unit(b_axis), _unit(u_axis), _unit(l_axis)
示例#20
0
 def __init__(
     self,
     pinhole: np.ndarray,
     sensor: np.ndarray,
     left: np.ndarray,
     resolution: Tuple[int, int],
 ):
     """Store the camera geometry and derive unit ``left`` and ``up`` vectors.

     ``up`` is the normalised cross product of the pinhole-to-sensor vector
     with ``left`` as passed in; ``left`` itself is stored normalised.
     """
     pinhole_to_sensor = sensor - pinhole
     up_direction = np.cross(pinhole_to_sensor, left)

     self.pinhole = pinhole
     self.sensor = sensor
     self.left = left / np.linalg.norm(left)
     self.up = up_direction / np.linalg.norm(up_direction)
     self.resolution = resolution
        def dot_cross_product(x):
            """Scalar triple product of the three unit vectors from atom 0.

            Evaluates
                (R_N - R_H1) . [(R_N - R_H2) x (R_N - R_H3)] / (r_NH1 * r_NH2 * r_NH3)
            where atom 0 is N and atoms 1..3 are H1..H3 after reshaping ``x``
            to ``(self.n_atoms, 3)``.
            """
            coords = jnp.reshape(x, (self.n_atoms, 3))
            centre = coords[0, :]

            unit_bonds = []
            for neighbour in (1, 2, 3):
                bond = centre - coords[neighbour, :]
                unit_bonds.append(bond / jnp.linalg.norm(bond))
            u_nh1, u_nh2, u_nh3 = unit_bonds

            return jnp.dot(u_nh1, jnp.cross(u_nh2, u_nh3))
示例#22
0
def coil_force(ll, dl):
    """
    Calculate the Lorentz Force on the coils.

    ``ll`` holds the segment positions and ``dl`` the segment vectors, both
    indexed as (coil, segment, component). The field is evaluated at every
    segment position and crossed with that segment's ``dl``; the magnitudes
    of the resulting per-segment forces are summed into one scalar. The coil
    currents are read from the module-level ``I_arr``.
    """
    # Only the evaluation point is mapped; dl, ll and I_arr are passed whole.
    eval_point_only = (0, None, None, None)
    # Inner map: over the segments of one coil.
    field_over_segments = vmap(biot_savart_oncoil, eval_point_only, 0)
    # Outer map: over the coils.
    field_over_coils = vmap(field_over_segments, eval_point_only, 0)

    # Argument order follows biot_savart_oncoil(r_eval, dl, ll, I_arr): the
    # evaluation points are the segment positions themselves.
    B_on_segments = field_over_coils(ll, dl, ll, I_arr)

    elementwise_force = np.cross(dl, B_on_segments)
    Total_force = np.sum(np.linalg.norm(elementwise_force, axis=-1),
                         axis=(0, 1))
    return Total_force
示例#23
0
 def _state_dot(self, s, thrusts):
     """Return the time derivative of the 12-element state ``s``.

     The state is sliced as position-rate ``s[3:6]``, angles ``s[6:9]`` and
     angular rates ``s[9:12]``; ``thrusts`` holds the four rotor thrusts.
     The result concatenates velocities, linear accelerations, angular rates
     and angular accelerations.
     """
     t1, t2, t3, t4 = thrusts

     # The velocities and angular rates are carried over from the state.
     velocity = s[3:6]
     angular_rate = s[9:12]

     # Linear acceleration: constant vertical term plus rotated total thrust.
     rotation = self._rotation_matrix(s[6:9])
     vertical_term = np.array([0, 0, -self.weight * self.g])
     body_thrust = np.array([0, 0, np.sum(thrusts)])
     acceleration = vertical_term + np.dot(rotation, body_thrust) / self.weight

     # Angular acceleration from the rotor torques and the gyroscopic term.
     torque = np.array([
         self.L * (t1 - t3),
         self.L * (t2 - t4),
         self.b * (t1 - t2 + t3 - t4),
     ])
     gyroscopic = np.cross(angular_rate, np.dot(self.I, angular_rate))
     angular_acceleration = np.dot(self.invI, (torque - gyroscopic))

     return np.concatenate(
         (velocity, acceleration, angular_rate, angular_acceleration))
示例#24
0
文件: backend.py 项目: dpanici/DESC
def cross(a, b, axis):
    """Compute the cross product of two batches of vectors.

    Parameters
    ----------
    a : array-like
        first array of vectors
    b : array-like
        second array of vectors
    axis : int
        axis of ``a``, ``b`` and the result along which the vector
        components are stored

    Returns
    -------
    y : array-like
        y = a x b

    """
    product = jnp.cross(a, b, axis=axis)
    return product
示例#25
0
def _dihedral_angle(p1, p2, p3, p4):
    """Return the signed dihedral angle (radians) about the p2-p3 line."""
    central = p3 - p2
    # Normals of the planes (p1, p2, p3) and (p2, p3, p4).
    plane_normals = (np.cross(p2 - p1, central), np.cross(central, p4 - p3))

    y = np.dot(np.cross(*plane_normals), central)
    x = np.dot(*plane_normals) * linalg.norm(central)
    return np.arctan2(y, x)
示例#26
0
def _angle(p1, p2, p3):
    """Return the angle (radians) at p2 between the directions to p1 and p3."""
    to_first, to_third = p1 - p2, p3 - p2
    cross_magnitude = linalg.norm(np.cross(to_first, to_third))
    return np.arctan2(cross_magnitude, np.dot(to_first, to_third))
示例#27
0
def torsion_pure(d1gamma, d2gamma, d3gamma):
    """
    Torsion of a curve from its first three derivatives (Python+Jax version).

    Each argument stores one vector per row; the result has one value per row:
    ``(d1 x d2) . d3 / |d1 x d2|^2``.
    """
    d1_cross_d2 = jnp.cross(d1gamma, d2gamma, axis=1)
    triple_product = jnp.sum(d1_cross_d2 * d3gamma, axis=1)
    cross_norm_squared = jnp.sum(d1_cross_d2**2, axis=1)
    return triple_product / cross_norm_squared
示例#28
0
def kappa_pure(d1gamma, d2gamma):
    """
    Curvature of a curve from its first two derivatives (Python+Jax version).

    Each argument stores one vector per row; the result has one value per row:
    ``|d1 x d2| / |d1|^3``.
    """
    cross_norm = jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1)
    speed_cubed = jnp.linalg.norm(d1gamma, axis=1)**3
    return cross_norm / speed_cubed
示例#29
0
    def __call__(self, rng_0, rng_1, batch, randomized):
        """Generalizable Patch-Based Neural Rendering Model.

        Projects the target rays into the reference views, re-expresses rays,
        world coordinates and reference cameras in a per-ray canonical frame,
        and predicts colours by attending over the projected features.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
            Deleted on entry; not used by this method.
          batch: data batch. data_types.Batch
          randomized: bool, use randomized stratified sampling.

        Returns:
          ret: list, [(rgb_coarse, None, None), (rgb, None, None)]. When
            `self.return_attn` is set, a tuple `(ret, attn)` is returned
            instead, where `attn` is a dict with keys "e_attn", "n_attn" and
            "p_coord".

        Side effects:
          Sets `batch.reference_views.rgb` to None and overwrites the
          `directions` and `origins` attributes of `batch.target_view.rays`.
        """
        del rng_1

        # Get the batch rays
        batch_rays = batch.target_view.rays

        #---------------------------------------------------------------------------------------
        # Get image height and width. To be used to clip projections
        # outside the image.
        image_height, image_width = batch.reference_views.rgb.shape[1:3]
        # Take the min and max depth from the first entry.
        min_depth = batch.reference_views.min_depth[0][0]
        max_depth = batch.reference_views.max_depth[0][0]

        projected_coordinates, _, wcoords = self.projector.epipolar_projection(
            rng_0,
            batch_rays,
            batch.reference_views.ref_worldtocamera,
            batch.reference_views.intrinsic_matrix,
            image_height,
            image_width,
            min_depth,
            max_depth,
        )

        #---------------------------------------------------------------------------------------
        # Add a [0, 0, 0, 1] row to the ref cam to world, one copy per
        # reference camera, so each matrix becomes 4x4.
        bottom = jnp.tile(
            jnp.array([[[0, 0, 0, 1.]]]),
            (batch.reference_views.ref_cameratoworld.shape[0], 1, 1))
        ref_cameratoworld = jnp.concatenate([
            batch.reference_views.ref_cameratoworld[Ellipsis, :3, :4], bottom
        ], -2)

        #--------------------------------------------------------------------------------------
        # Compute the canonical transformation.
        # Get the cam to world of the target rays.
        camtoworld = jnp.linalg.inv(
            batch.reference_views.target_worldtocam)  # (1, 4, 4)
        # Get the up vector (second column of the 3x3 block).
        upv = camtoworld[Ellipsis, :3, 1]
        # Get the ray direction.
        raydir = batch_rays.directions
        # Remove the component of the up vector along the ray direction.
        rdotup = (raydir * upv).sum(-1, keepdims=True)
        orthoup = upv - rdotup * raydir
        # NOTE(review): this divides by the squared norm (no sqrt), so the
        # result is unit length only if it already was -- confirm intended.
        orthoup = orthoup / (orthoup**2).sum(-1, keepdims=True)
        vec0 = jnp.cross(orthoup, raydir)
        # NOTE(review): same squared-norm division as above -- confirm.
        vec0 = vec0 / (vec0**2).sum(-1, keepdims=True)
        # Stack it such that the matrix is the transpose (inverse).
        r_relative = jnp.stack([vec0, orthoup, raydir], axis=1)
        #--------------------------------------------------------------------------------------

        # Translation that moves each ray origin to the canonical origin.
        translation_relative = -jnp.matmul(r_relative,
                                           batch_rays.origins[Ellipsis, None])

        # Assemble [R | t] and append the [0, 0, 0, 1] row to get 4x4 matrices.
        canonical_transform = jnp.concatenate(
            [r_relative, translation_relative], axis=-1)
        bottom = jnp.tile(jnp.array([[[0, 0, 0, 1.]]]),
                          (r_relative.shape[0], 1, 1))

        canonical_transform = jnp.concatenate([canonical_transform, bottom],
                                              axis=-2)

        # Append a homogeneous 1 to the world coordinates, apply each ray's
        # canonical transform to its points, then drop the homogeneous entry.
        wcoords = jnp.pad(wcoords, ((0, 0), (0, 0), (0, 1)), constant_values=1)
        wcoords = jnp.einsum("nij,nkj->nki", canonical_transform, wcoords)
        wcoords = wcoords[Ellipsis, :3]
        # Compose every canonical transform with every reference
        # camera-to-world matrix.
        canonical_ref_cameratoworld = jnp.einsum("bxz, nzy->nbxy",
                                                 canonical_transform,
                                                 ref_cameratoworld)

        projected_rays = self.projector.get_near_rays_canonical(
            projected_coordinates, canonical_ref_cameratoworld,
            batch.reference_views.intrinsic_matrix)

        # Next we need to get the rgb values and the rays corresponding to these
        # projections.
        ref_images = model_utils.uint2float(batch.reference_views.rgb)
        if self.epipolar_config.normalize_ref_image:
            ref_images = (ref_images - self.mean) / self.std

        projected_rgb_and_feat = self._get_pixel_projection(
            projected_coordinates, ref_images)

        # The reference images are not read again below; drop the reference
        # held by the batch.
        batch.reference_views.rgb = None
        #----------------------------------------------------------------------------------------
        # Canonicalize the batch rays.
        # Multiply the directions with the inverse of the target camera to world
        # matrix

        batch_rays.directions = (
            r_relative[Ellipsis, :3, :3]
            @ batch_rays.directions[Ellipsis, None])[Ellipsis, 0]
        batch_rays.origins = jnp.zeros_like(batch_rays.origins)

        # Get LF representation of the batch and the projected rays.
        # Below we consider the representation extracted from the batch rays as the
        # query and representation extracted from the projected rays as keys and the
        # projected rgb as the values.
        _, input_q, _ = self._get_query(batch_rays)
        _, input_k, _ = self._get_key(projected_rays, projected_rgb_and_feat,
                                      wcoords)

        #----------------------------------------------------------------------------------------
        # Append the relative orientation between the target and reference cameras.
        # The per-camera 3x3 product is flattened and repeated to match the
        # first and third dimensions of input_k.
        rel_camera_rot = einshape(
            "nxy->bnp(xy)",
            jnp.matmul(
                batch.reference_views.target_worldtocam[Ellipsis, :3, :3],
                batch.reference_views.ref_cameratoworld[Ellipsis, :3, :3]),
            b=input_k.shape[0],
            p=input_k.shape[2])
        # NOTE(review): this subtracts the last column of a camera-to-world
        # matrix from that of a world-to-camera matrix -- confirm both are
        # meant to be in the same frame.
        rel_camera_tran = einshape(
            "nx->bnpx",
            (batch.reference_views.target_worldtocam[Ellipsis, :3, -1] -
             batch.reference_views.ref_cameratoworld[Ellipsis, :3, -1]),
            b=input_k.shape[0],
            p=input_k.shape[2])
        input_k = jnp.concatenate([input_k, rel_camera_rot, rel_camera_tran],
                                  axis=-1)

        # Get the average feature over each epipolar line
        avg_projection_features, e_attn = self._get_avg_features(
            input_q, input_k, randomized=randomized)

        rgb, n_attn = self._predict_color(input_q, avg_projection_features,
                                          randomized)
        rgb_coarse = self._get_reg_prediction(projected_rgb_and_feat, e_attn,
                                              n_attn)

        # Coarse prediction first, then the final prediction.
        ret = [(rgb_coarse, None, None)]
        ret.append((rgb, None, None))

        if self.return_attn:
            return ret, {
                "e_attn": e_attn,
                "n_attn": n_attn,
                "p_coord": projected_coordinates.swapaxes(0, 1)
            }
        else:
            return ret
示例#30
0
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
  """Cross product that accepts ``JaxArray`` inputs and returns a ``JaxArray``.

  ``JaxArray`` arguments are replaced by their ``.value`` before calling
  ``jnp.cross``; the axis arguments are forwarded unchanged.
  """
  left = a.value if isinstance(a, JaxArray) else a
  right = b.value if isinstance(b, JaxArray) else b
  product = jnp.cross(left, right, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
  return JaxArray(product)