def frozen_flow_transform(t, y, x0, bottom, wind_velocity=None):
    """
    Apply the frozen-flow inverse rotation to a coordinate.

    The coordinate is rotated about an axis through the origin that is
    perpendicular to both ``wind_velocity`` and ``x0``.

    Args:
        t: time in seconds; ``None`` disables the transform
        y: position in km, origin at centre of Earth
        x0: position of reference point.
        bottom: bottom of ionosphere in km
        wind_velocity: layer velocity in km/s; ``None`` disables the
            transform

    Returns:
        ``y`` rotated by the angle ``-theta_dot * t`` about the flow axis,
        or ``y`` itself (unchanged) when ``t`` or ``wind_velocity`` is
        ``None``.
    """
    if t is None or wind_velocity is None:
        return y

    # Unit rotation axis: perpendicular to the wind and the reference point.
    axis = jnp.cross(wind_velocity, x0)
    axis = axis / jnp.linalg.norm(axis)

    # v(r) = theta_dot * r  =>  theta_dot = |v| / r, i.e. (km/s) / km.
    radius = bottom + jnp.linalg.norm(x0)
    angular_speed = jnp.linalg.norm(wind_velocity) / radius
    phi = -angular_speed * t

    # Rodrigues' rotation, written as the sum of the component along the
    # axis, the in-plane component scaled by cos, and the normal by sin.
    axis_cross_y = jnp.cross(axis, y)
    along_axis = axis * (axis @ y)
    in_plane = jnp.cos(phi) * jnp.cross(axis_cross_y, axis)
    out_of_plane = jnp.sin(phi) * axis_cross_y
    return along_axis + in_plane + out_of_plane
def vector_rotate(a, b, theta):
    """
    Rotate vector(s) ``a`` about the axis ``b`` by ``theta`` radians.

    Parameters
    ----------
    a : array-like
        A single vector (1d) or an array of vectors, one per row (2d).
    b : array-like
        Rotation axis; it does not need to be normalised.
    theta : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray
        The rotated vector(s), with the same shape as ``a``.

    Raises
    ------
    ValueError
        If ``a`` is neither 1d nor 2d.

    Programming Notes:
        u: parallel projection of a on b_hat.
        v: perpendicular projection of a on b_hat.
        w: a vector perpendicular to both a and b.
    """
    a = np.asarray(a)
    if a.ndim not in (1, 2):
        raise ValueError(
            'Input array must be 1d (vector) or 2d (array of vectors)')

    b_hat = np.asarray(b) / np.linalg.norm(b)

    # ``a @ b_hat`` is a scalar for one vector and one value per row for a
    # batch; the trailing axis lets it scale b_hat in both cases.
    u = np.expand_dims(a @ b_hat, -1) * b_hat
    v = a - u
    w = np.cross(b_hat, v)
    return u + v * np.cos(theta) + w * np.sin(theta)
def extend(tri, pt, multi_m):
    """
    Place ``pt`` in the local frame defined by the three points in ``tri``.

    Args:
      tri: NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE, NUM_DIMENSIONS];
        accessed through its ``a``, ``b`` and ``c`` attributes.
      pt: [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS]
      multi_m: bool indicating whether m (and tri) is higher rank. pt is
        always higher rank; what changes is what the first rank is.

    Returns:
      Coordinates of shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMS]:
      ``pt`` expressed in the frame (bc, n x bc, n) and offset by ``tri.c``.
    """
    # Frame axes: unit b->c direction and the unit normal built from a->b.
    bond_dir = normalize(tri.c - tri.b, axis=-1)
    plane_normal = normalize(np.cross(tri.b - tri.a, bond_dir), axis=-1)
    frame_axes = np.stack(
        [bond_dir, np.cross(plane_normal, bond_dir), plane_normal])

    if multi_m:
        # Multiple fragments, one atom at a time:
        # [NUM_FRAGS, BATCH_SIZE, NUM_DIMS, 3 TRANS]
        basis = np.transpose(frame_axes, [1, 2, 3, 0])
    else:
        # Single fragment, reconstructed entirely at once: repeat the
        # per-batch basis for every row of ``pt``.
        target_shape = onp.pad(pt.shape, [[0, 1]], constant_values=3)
        per_batch = np.transpose(frame_axes, [1, 2, 0])
        repeated = np.tile(per_batch, [target_shape[0], 1, 1])
        basis = np.reshape(repeated, target_shape)

    in_frame = np.matmul(basis, np.expand_dims(pt, 3))
    return np.squeeze(in_frame, axis=3) + tri.c
def _dihedral(pos: jnp.ndarray, indices: jnp.ndarray,
              tvecs: jnp.ndarray) -> float:
    """Return the signed dihedral angle (radians) of four points.

    Args:
        pos: array of positions, indexed by the entries of ``indices``.
        indices: the four position indices defining the dihedral.
        tvecs: three offset vectors, one added to each consecutive
            difference (``tvecs[0]`` to the first, and so on).

    Returns:
        ``arctan2`` of the sine-like and cosine-like terms built from the
        three consecutive difference vectors.
    """
    points = [pos[indices[n]] for n in range(4)]
    b1, b2, b3 = (points[n + 1] - points[n] + tvecs[n] for n in range(3))

    # Normals of the two planes that share the central vector b2.
    n12 = jnp.cross(b1, b2)
    n23 = jnp.cross(b2, b3)

    sin_term = b2 @ jnp.cross(n12, n23)
    # `*` and `@` associate left to right: (|b2| * n12) @ n23.
    cos_term = jnp.linalg.norm(b2) * n12 @ n23
    return jnp.arctan2(sin_term, cos_term)
def calc_torsions(xyz: np.ndarray, struct_descr_dict):
    """Compute signed torsion angles (radians) for a set of index quadruples.

    Args:
        xyz: array of coordinates, indexed along its first axis.
        struct_descr_dict: mapping whose ``'torsions'`` entry is a 2d
            integer array; columns 0..3 index the four points of each
            torsion.

    Returns:
        One angle per row of ``struct_descr_dict['torsions']``.
    """
    quads = struct_descr_dict['torsions']
    p0, p1, p2, p3 = (xyz[quads[:, col]] for col in range(4))
    bond1, bond2, bond3 = p1 - p0, p2 - p1, p3 - p2

    normal_a = np.cross(bond1, bond2)
    normal_b = np.cross(bond2, bond3)

    # Sine-like term: projection of normal_a x normal_b on the unit bond2.
    sin_part = np.sum(np.cross(normal_a, normal_b) * bond2, axis=-1)
    sin_part = sin_part / np.linalg.norm(bond2, axis=-1)
    cos_part = np.sum(normal_a * normal_b, axis=-1)
    return np.arctan2(sin_part, cos_part)
def __call__(self, r, h, p):
    """Blend the base function at ``r`` and at a shifted copy of ``r``.

    The last seven entries of ``p`` are read as a point ``r0`` (3 values),
    a vector ``n`` (3 values) and a scalar ``s``; everything before them
    is forwarded to ``self.base_gfunc``.

    Args:
        r: evaluation point; ``n`` and the shift are added to / dotted
            with it.
        h: forwarded unchanged to ``self.base_gfunc`` and to
            ``soft_if_then_logjit``.
        p: parameter sequence laid out as described above.

    Returns:
        ``soft_if_then_logjit(dot(r - r0, n), g0, g1, h)`` where ``g0`` is
        the base function at ``r`` and ``g1`` at ``r`` shifted by ``s``
        along the unit vector of ``(z_hat x n) x n``.
    """
    base_params = p[:-7]
    r0 = jnp.array(p[-7:-4])
    n = jnp.array(p[-4:-1])
    s = p[-1]

    # Shift direction: (z_hat x n) x n, normalised, then scaled by s.
    z_hat = jnp.array([0.0, 0.0, 1.0])
    v = jnp.cross(jnp.cross(z_hat, n), n)
    shift = s * v / jnp.sqrt(jnp.dot(v, v))

    g_here = self.base_gfunc(r, h, base_params)
    g_shifted = self.base_gfunc(r + shift, h, base_params)
    signed_offset = jnp.dot(r - r0, n)
    return soft_if_then_logjit(signed_offset, g_here, g_shifted, h)
def dihedral_angle(p1, p2, p3, p4):
    """
    Return the signed dihedral angle (radians) defined by four points in
    space, measured around the line through the two central points.
    """
    axis = p3 - p2
    # Normals of the planes (p1, p2, p3) and (p2, p3, p4).
    normal_first = np.cross(p2 - p1, axis)
    normal_second = np.cross(axis, p4 - p3)

    sin_like = np.dot(np.cross(normal_first, normal_second), axis)
    cos_like = np.dot(normal_first, normal_second) * linalg.norm(axis)
    return np.arctan2(sin_like, cos_like)
def torsionVecs_(self, P):
    """Return the two plane normals of a four-point torsion.

    ``P[0]`` .. ``P[3]`` are the four points. The result stacks, as rows,
    ``(P0 - P1) x (P1 - P2)``, a row of zeros, and
    ``(P3 - P2) x (P1 - P2)``.
    """
    p0, p1, p2, p3 = P[0], P[1], P[2], P[3]
    r1, r2, r3 = p0 - p1, p1 - p2, p3 - p2

    normal_12 = np.cross(r1, r2)
    normal_32 = np.cross(r3, r2)
    filler = np.zeros(normal_12.shape)

    # dstack puts the three vectors along the last axis; after squeezing,
    # the transpose turns them into rows.
    stacked = np.dstack((normal_12, filler, normal_32))
    return stacked.squeeze().transpose([1, 0])
def signed_torsion_angle(ci, cj, ck, cl):
    """
    Batch compute the signed angle of a torsion angle. The torsion angle
    between two planes should be periodic but not necessarily symmetric.

    Parameters
    ----------
    ci: shape [num_torsions, 3] np.array
        coordinates of the 1st atom in the 1-4 torsion angle

    cj: shape [num_torsions, 3] np.array
        coordinates of the 2nd atom in the 1-4 torsion angle

    ck: shape [num_torsions, 3] np.array
        coordinates of the 3rd atom in the 1-4 torsion angle

    cl: shape [num_torsions, 3] np.array
        coordinates of the 4th atom in the 1-4 torsion angle

    Returns
    -------
    shape [num_torsions,] np.array
        array of torsion angles.

    """
    # Taken from the wikipedia arctan2 implementation:
    # https://en.wikipedia.org/wiki/Dihedral_angle
    # We use an identical but numerically stable arctan2 implementation,
    # as opposed to the OpenMM energy function, to avoid a singularity
    # when the angle is zero.
    rij = delta_r(cj, ci)
    rkj = delta_r(cj, ck)
    rkl = delta_r(cl, ck)

    # Normals of the two planes sharing the central j-k displacement.
    n1 = np.cross(rij, rkj)
    n2 = np.cross(rkj, rkl)

    # arctan2 only needs the ratio of y to x, so the norms of n1 and n2
    # are not required and are deliberately not computed.
    rkj_hat = rkj / np.linalg.norm(rkj, axis=-1, keepdims=True)
    y = np.sum(np.multiply(np.cross(n1, n2), rkj_hat), axis=-1)
    x = np.sum(np.multiply(n1, n2), -1)

    return np.arctan2(y, x)
def torsionVecs(self, P):
    """Batched version of the torsion plane-normal computation.

    The second-to-last axis of ``P`` selects one of four points and the
    last axis holds the x, y, z components. For every leading index the
    result stacks, as rows, ``(P0 - P1) x (P1 - P2)``, a row of zeros,
    and ``(P3 - P2) x (P1 - P2)``.
    """
    # Fancy indexing with [k] and [0, 1, 2] selects point k and drops the
    # point axis, leaving (..., 3).
    p0, p1, p2, p3 = (P[..., [k], [0, 1, 2]] for k in range(4))
    r1, r2, r3 = p0 - p1, p1 - p2, p3 - p2

    normal_12 = np.cross(r1, r2)
    normal_32 = np.cross(r3, r2)
    filler = np.zeros(normal_12.shape)

    # dstack places the three vectors along the last axis; swapping the
    # last two axes makes each one a row.
    stacked = np.dstack((normal_12, filler, normal_32))
    return stacked.squeeze().transpose([0, 2, 1])
def biot_savart_saddle(r, I, dl, l):
    """
    Evaluate the Biot-Savart sum of all coil segments on a grid of points.

    Inputs:

    r : Position we want to evaluate at, NZ x NT x 3
    I : Current in ith coil, length NC
    dl : Vector which has coil segment length and direction, NC x NS x 3
    l : Positions of center of each coil segment, NC x NS x 3

    Returns:
    A NZ x NT x 3 array which is the magnetic field vector on the surface
    points. mu_0 is taken as 1 and no further prefactor is applied.
    """
    mu_0 = 1.
    # Current-weighted segment vectors, NC x NS x 3.
    current_elements = (I * mu_0)[:, np.newaxis, np.newaxis] * dl

    # Vector from every segment to every evaluation point,
    # NC x NZ x NT x NS x 3.
    separation = (r[np.newaxis, :, :, np.newaxis, :]
                  - l[:, np.newaxis, np.newaxis, :, :])

    numerator = np.cross(
        current_elements[:, np.newaxis, np.newaxis, :, :], separation)
    distance_cubed = np.linalg.norm(separation, axis=-1)**3

    # Sum over coils (axis 0) and segments (axis 3).
    integrand = numerator / distance_cubed[:, :, :, :, np.newaxis]
    return np.sum(integrand, axis=(0, 3))
def computeB(I, dl, l, r, zeta, z):
    """
    Evaluate the Biot-Savart sum of all coil filaments at a single point.

    Inputs:

    I : current per coil, length NC
    dl : segment vectors, NC x NS x NNR x NBR x 3
    l : segment positions, NC x NS x NNR x NBR x 3
    r, zeta, z : The coordinates of the point we want the magnetic field
    at. Cylindrical coordinates.

    Outputs:

    B_r, B_zeta, B_z : the magnetic field components at the input
    coordinates created by the currents in the coils. Cylindrical
    coordinates. mu_0 is taken as 1 and no further prefactor is applied.
    """
    cos_zeta = np.cos(zeta)
    sin_zeta = np.sin(zeta)

    # Evaluation point in Cartesian coordinates.
    point = np.asarray([r * cos_zeta, r * sin_zeta, z])

    mu_0 = 1.
    # NC x NS x NNR x NBR x 3
    current_elements = (
        (I * mu_0)[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis] * dl)
    separation = (
        point[np.newaxis, np.newaxis, np.newaxis, np.newaxis, :]
        - l[:, :, :, :, :])

    numerator = np.cross(current_elements, separation)
    distance_cubed = np.linalg.norm(separation, axis=-1)**3
    integrand = numerator / distance_cubed[:, :, :, :, np.newaxis]

    # Collapse every filament axis, leaving the three xyz components.
    B_x, B_y, B_z = np.sum(integrand, axis=(0, 1, 2, 3))

    # Rotate the Cartesian x/y components into the cylindrical basis.
    B_r = B_x * cos_zeta + B_y * sin_zeta
    B_zeta = -B_x * sin_zeta + B_y * cos_zeta
    return B_r, B_zeta, B_z
def biot_savart_oncoil(r_eval, dl, ll, I_arr):
    """
    Calculate the Biot-Savart sum over all coil segments at a point that
    may itself lie on a coil segment.

    Arguments:
    *r_eval*: (length 3 array) the point where the field is to be
        evaluated, in cartesian coordinates.
    *dl*: (n_coils, nsegments, 3)-array of the coil line-segment vectors
    *ll*: (n_coils, nsegments, 3)-array of the position of each coil
        segment
    *I_arr*: array with one current value per coil (broadcast over the
        segments and components of ``dl``)

    Note on algorithm:
        ``None`` indices add new axes so the arrays broadcast against
        each other. A segment located exactly at ``r_eval`` yields 0/0;
        ``nan_to_num`` replaces that NaN with zero so the segment drops
        out of the sum.

    returns:
    *B*: magnetic field at position r_eval (length 3), with no physical
        prefactor applied.
    """
    separation = r_eval[None, None, :] - ll
    numerator = np.cross(dl, separation) * I_arr[:, None, None]
    distance_cubed = np.linalg.norm(separation, axis=-1)**3

    # Sum over all line segments after zeroing the non-finite terms.
    integrand = np.nan_to_num(numerator / distance_cubed[:, :, None])
    return np.sum(integrand, axis=(0, 1))
def angle(p1, p2, p3):
    """
    Return the angle (radians) defined by three points in space, measured
    at the middle point ``p2``.
    """
    arm_a = p1 - p2
    arm_b = p3 - p2
    # |a x b| and a . b are proportional to the sine and cosine of the
    # angle; arctan2 stays well conditioned near 0 and pi.
    sine_like = linalg.norm(np.cross(arm_a, arm_b))
    cosine_like = np.dot(arm_a, arm_b)
    return np.arctan2(sine_like, cosine_like)
def geo(joint):
    """Geometric regularisation loss on four-joint chains.

    Five chains are read from ``joint`` at indices (1..4), (5..8),
    (9..12), (13..16) and (17..20) of the second axis -- presumably the
    joints of five fingers of a 21-joint hand; confirm against the caller.

    For every chain with consecutive differences ``v_ab``, ``v_bc`` and
    ``v_cd``:

    * ``loss_1`` is the mean absolute triple product
      ``(v_ab x v_bc) . v_cd``, which is zero when the chain is planar.
    * ``loss_2`` is the mean of the negative part of
      ``(v_ab x v_bc) . (v_bc x v_cd)``, so only chains whose two normals
      point in opposite directions contribute.

    Args:
        joint: array indexed as ``joint[:, idx, :]``; the comments in the
            code give the per-chain shape as (B, 5, 3).

    Returns:
        Scalar ``10000 * loss_1 + 100000 * loss_2``.
    """
    idx_a = npj.array([1, 5, 9, 13, 17])
    idx_b = npj.array([2, 6, 10, 14, 18])
    idx_c = npj.array([3, 7, 11, 15, 19])
    idx_d = npj.array([4, 8, 12, 16, 20])

    p_a = joint[:, idx_a, :]
    p_b = joint[:, idx_b, :]
    p_c = joint[:, idx_c, :]
    p_d = joint[:, idx_d, :]

    v_ab = p_a - p_b  # (B, 5, 3)
    v_bc = p_b - p_c  # (B, 5, 3)
    v_cd = p_c - p_d  # (B, 5, 3)

    normal_abc = npj.cross(v_ab, v_bc, axis=-1)
    normal_bcd = npj.cross(v_bc, v_cd, axis=-1)

    loss_1 = npj.abs(npj.sum(normal_abc * v_cd, -1)).mean()

    # Reduce over the vector axis only, so each chain is clipped on its
    # own. Summing over every axis first would let chains with a positive
    # alignment cancel the ones that should be penalised.
    alignment = npj.sum(normal_abc * normal_bcd, -1)
    loss_2 = -npj.clip(alignment, -npj.inf, 0).mean()

    return 10000 * loss_1 + 100000 * loss_2
def biot_savart(r, I, dl, l):
    """Evaluate the Biot-Savart sum of all coil segments at one point.

    Args:
        r: evaluation point; broadcast against ``l`` via
            ``r[None, None, :]``.
        I: current per coil, indexed as ``I[:, None, None]``.
        dl: coil segment vectors, same leading axes as ``l``.
        l: coil segment positions.

    Returns:
        The sum over the first two axes of ``I * dl x (r - l) / |r - l|^3``
        with mu_0 taken as 1 and no further prefactor.
    """
    mu_0 = 1.0
    current_elements = (I * mu_0)[:, np.newaxis, np.newaxis] * dl
    separation = r[np.newaxis, np.newaxis, :] - l

    numerator = np.cross(current_elements, separation)
    distance_cubed = np.linalg.norm(separation, axis=-1)**3
    integrand = numerator / distance_cubed[:, :, np.newaxis]
    return np.sum(integrand, axis=(0, 1))
def compute_intersection_point(denom):
    """Return the ray intersection point, or ``[inf, inf]`` on a miss.

    Reads ``v1``, ``v2``, ``v3``, ``ray_origin`` and ``ray_direction``
    from the enclosing scope. ``t1 = cross(v2, v1) / denom`` is used as
    the distance along the ray and ``t2 = (v1 @ v3) / denom`` must lie in
    [0, 1] -- presumably the position along the segment being tested;
    confirm against the enclosing function.
    """
    t1 = np.cross(v2, v1) / denom
    t2 = (v1 @ v3) / denom

    def _miss(_):
        return np.array([np.inf, np.inf])

    def _hit(_):
        return ray_origin + t1 * ray_direction

    # A miss is an intersection behind the ray origin or a t2 outside
    # the unit interval.
    outside = np.logical_or(t1 < 0.0, np.logical_or(t2 < 0.0, t2 > 1.0))
    return jax.lax.cond(
        outside,
        true_fun=_miss,
        false_fun=_hit,
        operand=t1,
    )
def extend(prev_three_coords, point, multi_m):
    """
    Aligns an atom or an entire fragment depending on value of `multi_m`
    with the preceding three atoms.

    :param prev_three_coords: Named tuple storing the last three atom
        coordinates ("a", "b", "c") where "c" is the current end of the
        structure (i.e. closest to the atom/fragment that will be added
        now). Shape NUM_DIHEDRALS x [NUM_FRAGS/0, BATCH_SIZE,
        NUM_DIMENSIONS]. First rank depends on value of `multi_m`.
    :param point: Point describing the atom that is added to the
        structure. Shape [NUM_FRAGS/FRAG_SIZE, BATCH_SIZE, NUM_DIMENSIONS].
        First rank depends on value of `multi_m`.
    :param multi_m: If True, a single atom is added to the chain for
        multiple fragments in parallel. If False, a single fragment is
        added. Note the different parameter dimensions.
    :return: Coordinates of the atom/fragment.
    """
    a = prev_three_coords.a
    b = prev_three_coords.b
    c = prev_three_coords.c

    def _unit_rows(vectors):
        # Row-wise normalisation along the last axis.
        return vectors / onp.linalg.norm(vectors, axis=-1, keepdims=True)

    bc = _unit_rows(c - b)
    n = _unit_rows(onp.cross(b - a, bc, axisa=-1, axisb=-1, axisc=-1))
    frame_axes = onp.stack([bc, onp.cross(n, bc), n])

    if multi_m:
        # Multiple fragments, one atom at a time: move the stacked axis
        # last so the frame vectors become matrix columns.
        m = onp.transpose(frame_axes, (1, 2, 3, 0))
    else:
        # Single fragment, reconstructed entirely at once: repeat the
        # per-batch matrix for every row of `point`.
        target_shape = point.shape + (3, )
        m = onp.transpose(frame_axes, (1, 2, 0))
        m = onp.tile(m, (target_shape[0], 1, 1)).reshape(target_shape)

    in_frame = onp.matmul(m, onp.expand_dims(point, axis=3))
    return onp.squeeze(in_frame, axis=3) + c
def _recov_axis_batch(hand_joints, transf, joints_mapping, up_axis_base):
    """
    Recover three unit axes per joint from joint positions and transforms.

    input: hand_joints[B, 21, 3], transf[B, 16, 4, 4]
    output: b_axis[B, 15, 3], u_axis[B, 15, 3], l_axis[B, 15, 3]

    ``joints_mapping`` lists the joint indices to use; each one is paired
    with the joint at the following index. ``up_axis_base`` is crossed
    with the back axis, so it must broadcast against [B, 15, 3].

    The back axis is the difference between each paired joint, rotated by
    the transpose of the 3x3 block of ``transf[:, 1:]``. The left axis is
    ``b x up_axis_base`` and the up axis is ``l x b``. All three are
    returned normalised, in the order (b_axis, u_axis, l_axis).
    """
    next_joints = [i + 1 for i in joints_mapping]
    b_axis = hand_joints[:, joints_mapping] - hand_joints[:, next_joints]

    # Apply the transposed rotation block of every transform but the first.
    rot_t = np.transpose(transf[:, 1:, :3, :3], (0, 1, 3, 2))
    b_axis = (rot_t @ np.expand_dims(b_axis, -1)).squeeze(-1)

    l_axis = np.cross(b_axis, up_axis_base)
    u_axis = np.cross(l_axis, b_axis)

    def _unit(vectors):
        return vectors / np.linalg.norm(vectors, axis=2, keepdims=True)

    return _unit(b_axis), _unit(u_axis), _unit(l_axis)
def __init__(
    self,
    pinhole: np.ndarray,
    sensor: np.ndarray,
    left: np.ndarray,
    resolution: Tuple[int, int],
):
    """Store the camera geometry and derive unit ``left`` / ``up`` vectors.

    Args:
        pinhole: position vector, stored as given.
        sensor: position vector, stored as given.
        left: direction vector; stored normalised to unit length.
        resolution: pair of ints, stored as given.

    The ``up`` attribute is the unit vector along
    ``(sensor - pinhole) x left``.
    """
    self.pinhole = pinhole
    self.sensor = sensor
    self.left = left / np.linalg.norm(left)

    # The cross product uses ``left`` as passed in; its length does not
    # matter because the result is normalised.
    pinhole_to_sensor = sensor - pinhole
    up_unnormalised = np.cross(pinhole_to_sensor, left)
    self.up = up_unnormalised / np.linalg.norm(up_unnormalised)

    self.resolution = resolution
def dot_cross_product(x):
    """Scalar triple product of three unit vectors from row 0.

    ``x`` is reshaped to ``(self.n_atoms, 3)`` (``self`` comes from the
    enclosing scope). With R_k the unit vector from row k to row 0, the
    result is ``R_1 . (R_2 x R_3)``, i.e.

    (R_N - R_H1) dot [(R_N - R_H2) cross (R_N - R_H3)]
        / (r_NH1 * r_NH2 * r_NH3)
    """
    coords = jnp.reshape(x, (self.n_atoms, 3))
    centre = coords[0, :]

    def _unit_from(row):
        # Unit vector pointing from row `row` to the centre row.
        bond = centre - coords[row, :]
        return bond / jnp.linalg.norm(bond)

    unit_1 = _unit_from(1)
    unit_2 = _unit_from(2)
    unit_3 = _unit_from(3)
    return jnp.dot(unit_1, jnp.cross(unit_2, unit_3))
def coil_force(ll, dl):
    """
    Calculate the summed magnitude of ``dl x B`` over every coil segment.

    The field ``B`` at each segment position is obtained from
    ``biot_savart_oncoil(r_eval, dl, ll, I_arr)``, vectorised over the
    first two axes of ``ll``. ``I_arr`` is read from the enclosing scope.

    Arguments:
    *ll*: array of segment positions; its first two axes are mapped over
        (coils, then segments of each coil).
    *dl*: array of segment vectors with the same leading axes as ``ll``.

    returns:
    Scalar: the sum over the first two axes of ``|dl x B|``. The current
    is not multiplied into the cross product here.
    """
    # The outer vmap strips the first axis of the evaluation points (one
    # coil at a time); the inner vmap strips the next one (one segment at
    # a time), so biot_savart_oncoil receives a single length-3 point.
    field_per_segment = vmap(biot_savart_oncoil, (0, None, None, None), 0)
    field_per_coil = vmap(field_per_segment, (0, None, None, None), 0)

    # Argument order follows biot_savart_oncoil(r_eval, dl, ll, I_arr):
    # evaluation points first, then segment vectors, then positions.
    B_on_segments = field_per_coil(ll, dl, ll, I_arr)

    elementwise_force = np.cross(dl, B_on_segments)
    Total_force = np.sum(np.linalg.norm(elementwise_force, axis=-1),
                         axis=(0, 1))
    return Total_force
def _state_dot(self, s, thrusts):
    """Return the time derivative of the 12-element state ``s``.

    Layout read from ``s``: ``s[3:6]`` linear velocities, ``s[6:9]`` the
    argument of ``self._rotation_matrix``, ``s[9:12]`` angular rates.

    Args:
        s: state vector (at least 12 entries are read).
        thrusts: the four thrust values ``(t1, t2, t3, t4)``.

    Returns:
        Concatenation of (velocities, accelerations, angular rates,
        angular accelerations).
    """
    t1, t2, t3, t4 = thrusts

    # Derivatives of the position and angle blocks are the rates already
    # stored in the state.
    velocity = s[3:6]
    angular_rate = s[9:12]

    R = self._rotation_matrix(s[6:9])

    # NOTE(review): the gravity term is -weight * g and, unlike the
    # thrust term, is not divided by weight -- confirm this is intended.
    gravity_term = np.array([0, 0, -self.weight * self.g])
    total_thrust = np.array([0, 0, np.sum(thrusts)])
    acceleration = gravity_term + np.dot(R, total_thrust) / self.weight

    # Torques from thrust differences between opposing / alternating
    # thrusts.
    tau = np.array([
        self.L * (t1 - t3),
        self.L * (t2 - t4),
        self.b * (t1 - t2 + t3 - t4),
    ])

    # Euler's rotation equation: I * a_dotdot = tau - a_dot x (I * a_dot).
    gyroscopic = np.cross(angular_rate, np.dot(self.I, angular_rate))
    angular_acceleration = np.dot(self.invI, (tau - gyroscopic))

    return np.concatenate(
        (velocity, acceleration, angular_rate, angular_acceleration))
def cross(a, b, axis):
    """Batched vector cross product.

    Parameters
    ----------
    a : array-like
        first array of vectors
    b : array-like
        second array of vectors
    axis : int
        axis along which the vector components are stored, in both
        inputs and in the output

    Returns
    -------
    y : array-like
        y = a x b, delegated to ``jnp.cross``
    """
    product = jnp.cross(a, b, axis=axis)
    return product
def _dihedral_angle(p1, p2, p3, p4):
    """Return the signed dihedral angle (radians) of four points.

    The angle is measured around the line through ``p2`` and ``p3``,
    between the planes (p1, p2, p3) and (p2, p3, p4).
    """
    central = p3 - p2
    plane_a = np.cross(p2 - p1, central)
    plane_b = np.cross(central, p4 - p3)

    # arctan2 of the sine-like and cosine-like terms keeps the sign.
    numerator = np.dot(np.cross(plane_a, plane_b), central)
    denominator = np.dot(plane_a, plane_b) * linalg.norm(central)
    return np.arctan2(numerator, denominator)
def _angle(p1, p2, p3):
    """Return the angle (radians) at ``p2`` between ``p1`` and ``p3``."""
    to_first = p1 - p2
    to_third = p3 - p2
    # arctan2(|a x b|, a . b) yields a value in [0, pi].
    cross_norm = linalg.norm(np.cross(to_first, to_third))
    dot_value = np.dot(to_first, to_third)
    return np.arctan2(cross_norm, dot_value)
def torsion_pure(d1gamma, d2gamma, d3gamma):
    """
    Torsion formula in Python+Jax: for each row,
    ``((d1 x d2) . d3) / |d1 x d2|^2``.

    Each argument stores one 3-vector per row (components along axis 1).
    """
    d1_cross_d2 = jnp.cross(d1gamma, d2gamma, axis=1)
    triple_product = jnp.sum(d1_cross_d2 * d3gamma, axis=1)
    cross_norm_sq = jnp.sum(d1_cross_d2**2, axis=1)
    return triple_product / cross_norm_sq
def kappa_pure(d1gamma, d2gamma):
    """
    Curvature formula in Python+Jax: for each row,
    ``|d1 x d2| / |d1|^3``.

    Both arguments store one vector per row; norms are taken along
    axis 1.
    """
    cross_norm = jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1)
    speed = jnp.linalg.norm(d1gamma, axis=1)
    return cross_norm / speed**3
def __call__(self, rng_0, rng_1, batch, randomized):
    """Generalizable Patch-Based Neural Rendering Model.

    Projects the target rays into the reference views, re-expresses the
    world coordinates, reference cameras and target rays in a per-ray
    canonical frame, and predicts colour from the resulting query / key
    features.

    Note: `batch` is modified in place. `batch.reference_views.rgb` is set
    to None, and the `directions` / `origins` of `batch.target_view.rays`
    are overwritten with their canonical-frame values.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model
        sampling. Passed to `self.projector.epipolar_projection`.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
        Unused here; it is deleted immediately.
      batch: data batch. data_types.Batch
      randomized: bool, use randomized stratified sampling. Forwarded to
        `self._get_avg_features` and `self._predict_color`.

    Returns:
      ret: list, [(rgb_coarse, None, None), (rgb, None, None)].
      When `self.return_attn` is truthy the return value is instead the
      tuple (ret, attn), where attn is a dict with the keys "e_attn",
      "n_attn" and "p_coord".
    """
    del rng_1

    # Get the batch rays
    batch_rays = batch.target_view.rays
    #---------------------------------------------------------------------------------------
    # Get image height and width. Used to clip projections that fall
    # outside the image.
    image_height, image_width = batch.reference_views.rgb.shape[1:3]
    # Take the first entry of the min and max depth arrays.
    min_depth = batch.reference_views.min_depth[0][0]
    max_depth = batch.reference_views.max_depth[0][0]

    projected_coordinates, _, wcoords = self.projector.epipolar_projection(
        rng_0,
        batch_rays,
        batch.reference_views.ref_worldtocamera,
        batch.reference_views.intrinsic_matrix,
        image_height,
        image_width,
        min_depth,
        max_depth,
    )

    #---------------------------------------------------------------------------------------
    # Add a [0, 0, 0, 1] row to the ref cam to world
    bottom = jnp.tile(
        jnp.array([[[0, 0, 0, 1.]]]),
        (batch.reference_views.ref_cameratoworld.shape[0], 1, 1))
    ref_cameratoworld = jnp.concatenate([
        batch.reference_views.ref_cameratoworld[Ellipsis, :3, :4], bottom
    ], -2)

    #--------------------------------------------------------------------------------------
    # Compute the canonical transformation.
    # get the cam to world of the target rays.
    camtoworld = jnp.linalg.inv(
        batch.reference_views.target_worldtocam)  # (1, 4, 4)

    # Get the up vector (second column of the rotation block).
    upv = camtoworld[Ellipsis, :3, 1]
    # Get the ray direction.
    raydir = batch_rays.directions
    # Remove from the up vector its component along the ray direction.
    rdotup = (raydir * upv).sum(-1, keepdims=True)
    orthoup = upv - rdotup * raydir
    # NOTE(review): this divides by the sum of squares (the squared norm,
    # no square root), so the result is unit length only if the vector
    # already was -- confirm this is intended. Same for vec0 below.
    orthoup = orthoup / (orthoup**2).sum(-1, keepdims=True)
    vec0 = jnp.cross(orthoup, raydir)
    vec0 = vec0 / (vec0**2).sum(-1, keepdims=True)
    # Stack it such that the matrix is the transpose (inverse).
    r_relative = jnp.stack([vec0, orthoup, raydir], axis=1)

    #--------------------------------------------------------------------------------------
    # Build the 4x4 canonical transform [R | -R @ origin; 0 0 0 1].
    translation_relative = -jnp.matmul(r_relative,
                                       batch_rays.origins[Ellipsis, None])
    canonical_transform = jnp.concatenate(
        [r_relative, translation_relative], axis=-1)
    bottom = jnp.tile(jnp.array([[[0, 0, 0, 1.]]]),
                      (r_relative.shape[0], 1, 1))
    canonical_transform = jnp.concatenate([canonical_transform, bottom],
                                          axis=-2)

    # Append a homogeneous 1, apply the canonical transform, then drop the
    # homogeneous component again.
    wcoords = jnp.pad(wcoords, ((0, 0), (0, 0), (0, 1)), constant_values=1)
    wcoords = jnp.einsum("nij,nkj->nki", canonical_transform, wcoords)
    wcoords = wcoords[Ellipsis, :3]

    # Express every reference camera-to-world matrix in each canonical
    # frame.
    canonical_ref_cameratoworld = jnp.einsum("bxz, nzy->nbxy",
                                             canonical_transform,
                                             ref_cameratoworld)
    projected_rays = self.projector.get_near_rays_canonical(
        projected_coordinates, canonical_ref_cameratoworld,
        batch.reference_views.intrinsic_matrix)

    # Next we need to get the rgb values and the rays corresponding to these
    # projections.
    ref_images = model_utils.uint2float(batch.reference_views.rgb)
    if self.epipolar_config.normalize_ref_image:
        ref_images = (ref_images - self.mean) / self.std

    projected_rgb_and_feat = self._get_pixel_projection(
        projected_coordinates, ref_images)
    # Drop the reference images from the batch now that they are consumed.
    batch.reference_views.rgb = None

    #----------------------------------------------------------------------------------------
    # Canonicalize the batch rays.
    # Multiply the directions with the inverse of the target camera to world
    # matrix
    batch_rays.directions = (
        r_relative[Ellipsis, :3, :3]
        @ batch_rays.directions[Ellipsis, None])[Ellipsis, 0]
    batch_rays.origins = jnp.zeros_like(batch_rays.origins)

    # Get LF representation of the batch and the projected rays.
    # Below we consider the representation extracted from the batch rays as the
    # query and representation extracted from the projected rays as keys and the
    # projected rgb as the values.
    _, input_q, _ = self._get_query(batch_rays)
    _, input_k, _ = self._get_key(projected_rays, projected_rgb_and_feat,
                                  wcoords)

    #----------------------------------------------------------------------------------------
    # Append the relative orientation between the target and reference cameras.
    rel_camera_rot = einshape(
        "nxy->bnp(xy)",
        jnp.matmul(
            batch.reference_views.target_worldtocam[Ellipsis, :3, :3],
            batch.reference_views.ref_cameratoworld[Ellipsis, :3, :3]),
        b=input_k.shape[0],
        p=input_k.shape[2])
    # Difference of the translation columns of the two matrices.
    rel_camera_tran = einshape(
        "nx->bnpx",
        (batch.reference_views.target_worldtocam[Ellipsis, :3, -1] -
         batch.reference_views.ref_cameratoworld[Ellipsis, :3, -1]),
        b=input_k.shape[0],
        p=input_k.shape[2])
    input_k = jnp.concatenate([input_k, rel_camera_rot, rel_camera_tran],
                              axis=-1)

    # Get the average feature over each epipolar line
    avg_projection_features, e_attn = self._get_avg_features(
        input_q, input_k, randomized=randomized)

    rgb, n_attn = self._predict_color(input_q, avg_projection_features,
                                      randomized)
    rgb_coarse = self._get_reg_prediction(projected_rgb_and_feat, e_attn,
                                          n_attn)

    # The coarse prediction comes first, followed by the final one.
    ret = [(rgb_coarse, None, None)]
    ret.append((rgb, None, None))

    if self.return_attn:
        return ret, {
            "e_attn": e_attn,
            "n_attn": n_attn,
            "p_coord": projected_coordinates.swapaxes(0, 1)
        }
    else:
        return ret
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Cross product that unwraps ``JaxArray`` inputs and re-wraps the result.

    Any argument that is a ``JaxArray`` is replaced by its ``.value``
    before calling ``jnp.cross``; the axis arguments are forwarded
    unchanged. The result is always returned as a ``JaxArray``.
    """
    lhs = a.value if isinstance(a, JaxArray) else a
    rhs = b.value if isinstance(b, JaxArray) else b
    product = jnp.cross(
        lhs, rhs, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis)
    return JaxArray(product)