Ejemplo n.º 1
0
def xlog1py(x, y):
    """Return ``x * log1p(y)`` computed elementwise.

    Both inputs are first passed through ``_promote_args_like`` together with
    ``osp_special.xlog1py``. Presumably this promotes them the way SciPy's
    ``xlog1py`` would; confirm against the helper's definition.

    The product uses ``lax._safe_mul`` rather than a plain ``*``.
    NOTE(review): presumably this is so that ``x == 0`` yields 0 even where
    ``log1p(y)`` is infinite (e.g. ``y == -1``), matching SciPy's convention.
    Confirm against the helper.
    """
    promoted_x, promoted_y = _promote_args_like(osp_special.xlog1py, x, y)
    log_term = np.log1p(promoted_y)
    return lax._safe_mul(promoted_x, log_term)
Ejemplo n.º 2
0
def _xlog1py_jvp_rhs(g, ans, x, y):
    """Tangent contribution of ``y`` to ``xlog1py(x, y)``.

    Since ``d/dy [x * log1p(y)] = x / (1 + y)``, this returns
    ``g * x / (1 + y)``. The ``x * 1/(1 + y)`` part goes through
    ``lax._safe_mul``.

    Args:
      g: incoming tangent.
      ans: primal output; accepted for the rule's signature but unused here.
      x, y: primal inputs.
    """
    # Bring the tangent and x to a shared broadcast shape before promotion.
    target_shape = lax.broadcast_shapes(np.shape(g), np.shape(x))
    g_full = np.broadcast_to(g, target_shape)
    x_full = np.broadcast_to(x, target_shape)

    x_p, y_p = _promote_args_like(osp_special.xlog1py, x_full, y)
    inv_one_plus_y = np.reciprocal(1 + y_p)
    return g_full * lax._safe_mul(x_p, inv_one_plus_y)
Ejemplo n.º 3
0
def xlog1py_jvp_rhs(g, x, y, jaxpr, aval, consts):
    """Tangent contribution of ``y`` to ``xlog1py(x, y)`` (lax-level variant).

    Builds the partial derivative ``x / (1 + y)`` via ``lax._safe_mul`` and
    multiplies it by the tangent ``g``.

    ``jaxpr``, ``aval`` and ``consts`` are accepted for the rule's signature
    but are not used in this body.
    """
    # Promote (x, y) first, then (g, x); the second step may re-promote x.
    x_p, y_p = _promote_args_like(osp_special.xlog1py, x, y)
    g_p, x_p = _promote_args_like(osp_special.xlog1py, g, x_p)

    # NOTE(review): lax._brcast presumably lifts its first argument to the
    # rank of the others so the elementwise lax ops see matching ranks.
    x_lifted = lax._brcast(x_p, y_p)
    inv_lifted = lax._brcast(lax.reciprocal(1 + y_p), x_p)
    jac = lax._safe_mul(x_lifted, inv_lifted)

    return lax.mul(lax._brcast(g_p, jac), jac)
Ejemplo n.º 4
0
def _xlog1py_jvp_lhs(g, ans, x, y):
    """Tangent contribution of ``x`` to ``xlog1py(x, y)``.

    Since ``d/dx [x * log1p(y)] = log1p(y)``, this returns
    ``g * log1p(y)``, combined with ``lax._safe_mul``.

    Args:
      g: incoming tangent.
      ans: primal output; accepted for the rule's signature but unused here.
      x: primal input; not used by this body.
      y: primal input.
    """
    # The tangent and y are broadcast to a common shape, then dtype-promoted.
    common_shape = lax.broadcast_shapes(np.shape(g), np.shape(y))
    g_full = np.broadcast_to(g, common_shape)
    y_full = np.broadcast_to(y, common_shape)

    g_p, y_p = _promote_args_like(osp_special.xlog1py, g_full, y_full)
    log_term = np.log1p(y_p)
    return lax._safe_mul(g_p, log_term)
Ejemplo n.º 5
0
def xlog1py_jvp_lhs(g, x, y, jaxpr, aval, consts):
    """Tangent contribution of ``x`` to ``xlog1py(x, y)`` (lax-level variant).

    Returns ``g * log1p(y)`` via ``lax._safe_mul``.

    ``jaxpr``, ``aval`` and ``consts`` are accepted for the rule's signature
    but are not used in this body.
    """
    # x participates only so that y is promoted jointly with it; the promoted
    # x itself is discarded.
    _, y_p = _promote_args_like(osp_special.xlog1py, x, y)
    g_p, y_p = _promote_args_like(osp_special.xlog1py, g, y_p)

    g_lifted = lax._brcast(g_p, y_p)
    log_lifted = lax._brcast(lax.log1p(y_p), g_p)
    return lax._safe_mul(g_lifted, log_lifted)