def _radial_and_tangential_undistort(
    xd: jnp.ndarray,
    yd: jnp.ndarray,
    k1: float = 0,
    k2: float = 0,
    k3: float = 0,
    p1: float = 0,
    p2: float = 0,
    eps: float = 1e-9,
    max_iterations: int = 10,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Computes undistorted (x, y) from (xd, yd).

    Inverts the distortion model by running a fixed number of Newton
    iterations, starting from the distorted point itself. The residual and
    its Jacobian with respect to (x, y) come from
    `_compute_residual_and_jacobian`.

    Args:
        xd: Distorted x coordinates.
        yd: Distorted y coordinates.
        k1, k2, k3, p1, p2: Distortion coefficients, forwarded unchanged to
            `_compute_residual_and_jacobian` (presumably k* radial and
            p* tangential -- confirm against that helper).
        eps: Wherever the Jacobian determinant has magnitude <= eps, that
            iteration takes a zero step instead of dividing by ~0.
        max_iterations: Number of Newton updates; there is no early exit.

    Returns:
        The (x, y) estimate after `max_iterations` updates.
    """
    # Initialize from the distorted point.
    x = xd.copy()
    y = yd.copy()

    for _ in range(max_iterations):
        fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
            x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2)
        # Negated determinant of J = [[fx_x, fx_y], [fy_x, fy_y]]; with this
        # sign, numerator / denominator is exactly the Newton step -J^{-1} f.
        denominator = fy_x * fx_y - fx_x * fy_y
        x_numerator = fx * fy_y - fy * fx_y
        y_numerator = fy * fx_x - fx * fy_x

        valid = jnp.abs(denominator) > eps
        # jnp.where evaluates both branches. Divide by a harmless value where
        # the step is discarded so the inf/NaN from `num / 0` cannot leak
        # into gradients. Forward values are unaffected.
        safe_denominator = jnp.where(
            valid, denominator, jnp.ones_like(denominator))
        zeros = jnp.zeros_like(denominator)
        step_x = jnp.where(valid, x_numerator / safe_denominator, zeros)
        step_y = jnp.where(valid, y_numerator / safe_denominator, zeros)

        x = x + step_x
        y = y + step_y

    return x, y
def generate_positions(rng: rjax.PRNGKey, genpcls: dict, pos0: np.ndarray,
                       pcl: GenParticle = None):
    """Generates positions according to the momentum direction and lifetime.

    Traverses the decay tree recursively. When ``pcl`` is None, traversal
    starts from ``genpcls['root']``. Every visited node's entry in
    ``genpcls`` gets a ``'pos'`` built from ``pos0``. Each child is placed
    according to its lifetime:

    * If the lifetime lies in the open interval (0.0001, 1), the child is
      displaced from ``pos0`` by its velocity times a time drawn from an
      exponential distribution scaled by the lifetime.
    * Otherwise it starts at a copy of ``pos0``.

    Args:
        rng: JAX PRNG key consumed by this call and the whole subtree below.
        genpcls: Mapping from particle name to a per-particle dict. The root
            lives under ``'root'`` and provides ``'gpcl'`` (the tree node).
            Children provide ``'pcl'`` (with ``lifetime``/``mass``) and
            ``'mom'``. This function writes ``'pos'`` into each entry.
        pos0: Position of ``pcl`` (or of the root when ``pcl`` is None).
        pcl: Tree node to process; None means start at the root.

    Returns:
        None; results are written into ``genpcls`` in place.
    """
    if pcl is None:
        genpcls['root']['pos'] = Position.from_ndarray(pos0)
        pcl = genpcls['root']['gpcl']
    else:
        genpcls[pcl.name]['pos'] = Position.from_ndarray(pos0)

    for ch in pcl.children:
        particle = genpcls[ch.name]['pcl']
        # Split off one key for this child's decay-time draw and a separate
        # key for its subtree. Passing `rng` itself down the recursion (and
        # then splitting it again for the next sibling) would make the
        # subtree and that sibling draw identical random numbers.
        rng, key, subtree_rng = rjax.split(rng, 3)
        # NOTE(review): the (0.0001, 1) lifetime window is a magic range
        # whose units are not visible here -- confirm with the generator.
        if 0.0001 < particle.lifetime < 1:
            mom = genpcls[ch.name]['mom']
            nevt = mom.size
            # One time per event, shaped (nevt, 1) to broadcast against
            # the velocity.
            time = particle.lifetime * rjax.exponential(key, (nevt, 1))
            # TODO: add gamma factor multiplier here (relativistic correction)
            chpos = pos0 + mom.velocity(particle.mass) * time
        else:
            chpos = pos0.copy()
        generate_positions(subtree_rng, genpcls, chpos, ch)