def xlog1py(x, y):
    """Return ``x * log1p(y)`` with scipy-style argument promotion.

    Both inputs are first promoted (dtype/shape) the same way
    ``scipy.special.xlog1py`` would treat them, then multiplied via
    ``lax._safe_mul``.
    NOTE(review): ``_safe_mul`` is presumably a multiply that yields 0 when
    ``x`` is 0 even if ``log1p(y)`` is infinite (scipy's convention for
    xlog1py) -- confirm against its definition.
    """
    promoted_x, promoted_y = _promote_args_like(osp_special.xlog1py, x, y)
    log_term = np.log1p(promoted_y)
    return lax._safe_mul(promoted_x, log_term)
def _xlog1py_jvp_rhs(g, ans, x, y):
    """Tangent rule for ``xlog1py`` with respect to ``y`` (second argument).

    Since d/dy [x * log1p(y)] = x / (1 + y), this returns
    ``g * x / (1 + y)``. The tangent ``g`` and ``x`` are broadcast to a
    common shape before promotion; ``ans`` (the primal output) is accepted
    but unused here -- presumably required by the rule-registration
    calling convention (TODO confirm).
    """
    target_shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    tangent = np.broadcast_to(g, target_shape)
    wide_x = np.broadcast_to(x, target_shape)
    coeff, y_promoted = _promote_args_like(osp_special.xlog1py, wide_x, y)
    # Safe multiply keeps the result finite where x == 0 (hedged; see xlog1py).
    partial_wrt_y = lax._safe_mul(coeff, np.reciprocal(1 + y_promoted))
    return tangent * partial_wrt_y
def xlog1py_jvp_rhs(g, x, y, jaxpr, aval, consts):
    """JVP rule for ``xlog1py`` with respect to ``y`` (second argument).

    Builds the partial derivative ``x / (1 + y)`` with explicit broadcasting
    via ``lax._brcast`` and scales the tangent ``g`` by it. ``jaxpr``,
    ``aval`` and ``consts`` are unused here; they are presumably part of the
    JVP-rule calling convention (TODO confirm).
    """
    x_p, y_p = _promote_args_like(osp_special.xlog1py, x, y)
    tangent, x_p = _promote_args_like(osp_special.xlog1py, g, x_p)
    # Make the x/(1+y) broadcast explicit on both factors.
    x_wide = lax._brcast(x_p, y_p)
    inv_one_plus_y = lax.reciprocal(1 + y_p)
    jacobian = lax._safe_mul(x_wide, lax._brcast(inv_one_plus_y, x_p))
    return lax.mul(lax._brcast(tangent, jacobian), jacobian)
def _xlog1py_jvp_lhs(g, ans, x, y):
    """Tangent rule for ``xlog1py`` with respect to ``x`` (first argument).

    Since d/dx [x * log1p(y)] = log1p(y), this returns ``g * log1p(y)``
    using ``lax._safe_mul``. The tangent ``g`` and ``y`` are broadcast to a
    common shape before promotion. ``ans`` and ``x`` are accepted but unused
    here -- presumably required by the rule-registration calling convention
    (TODO confirm).
    """
    target_shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
    tangent, log_arg = _promote_args_like(
        osp_special.xlog1py,
        np.broadcast_to(g, target_shape),
        np.broadcast_to(y, target_shape),
    )
    return lax._safe_mul(tangent, np.log1p(log_arg))
def xlog1py_jvp_lhs(g, x, y, jaxpr, aval, consts):
    """JVP rule for ``xlog1py`` with respect to ``x`` (first argument).

    Returns ``g * log1p(y)``, broadcasting each factor against the other via
    ``lax._brcast`` and multiplying with ``lax._safe_mul``. ``x`` only
    influences the result through dtype promotion of ``y``; ``jaxpr``,
    ``aval`` and ``consts`` are unused here and presumably part of the
    JVP-rule calling convention (TODO confirm).
    """
    # Promote y alongside x first (x itself is not needed afterwards).
    _, y_p = _promote_args_like(osp_special.xlog1py, x, y)
    tangent, y_p = _promote_args_like(osp_special.xlog1py, g, y_p)
    left = lax._brcast(tangent, y_p)
    right = lax._brcast(lax.log1p(y_p), tangent)
    return lax._safe_mul(left, right)