示例#1
0
def test_differentiator_flags_for_nonsmooth_and_discontinuous():
    """Check that non-smooth math functions are only differentiated on request.

    ``fabs`` must be rejected by default and differentiate to ``sign`` once
    continuous non-smoothness is allowed; ``sign`` must be rejected by default
    and differentiate to ``0`` once discontinuities are allowed.
    """
    import pymbolic.functions as pf
    from pymbolic.mapper.differentiator import differentiate

    var = prim.Variable("x")

    # (expression, permissive allowed_nonsmoothness value, expected derivative)
    cases = [
        (pf.fabs(var), "continuous", pf.sign(var)),
        (pf.sign(var), "discontinuous", 0),
    ]

    for expr, permissive, expected in cases:
        # With the default setting, differentiation must be refused.
        with pytest.raises(ValueError):
            differentiate(expr, var)

        derivative = differentiate(expr, var, allowed_nonsmoothness=permissive)
        assert derivative == expected
示例#2
0
def map_math_functions_by_name(i, func, pars, allowed_nonsmoothness="none"):
    """Return the partial derivative of a ``math.<name>(*pars)`` call.

    :arg i: index into *pars* of the argument with respect to which the
        partial derivative is taken. Only the two-argument ``copysign``
        distinguishes between values of *i*; the single-argument functions
        have only one partial.
    :arg func: the called function. It is matched by equality against
        ``primitives.Lookup(primitives.Variable("math"), name)``.
    :arg pars: the argument expressions of the call.
    :arg allowed_nonsmoothness: ``"none"`` (default), ``"continuous"`` or
        ``"discontinuous"``. Non-smooth functions are only differentiated
        when this setting permits their kind of non-smoothness.

    :returns: an expression for ``d func(*pars) / d pars[i]``. (Presumably
        the caller multiplies this by the derivative of ``pars[i]`` to apply
        the chain rule; that happens outside this function.)
    :raises ValueError: if *func* is not smooth with respect to ``pars[i]``
        and *allowed_nonsmoothness* does not permit that kind of
        non-smoothness.
    :raises RuntimeError: if *func* (with this number of arguments) is not
        recognized.
    """
    def make_f(name):
        return primitives.Lookup(primitives.Variable("math"), name)

    if func == make_f("sin") and len(pars) == 1:
        return make_f("cos")(*pars)
    elif func == make_f("cos") and len(pars) == 1:
        return -make_f("sin")(*pars)
    elif func == make_f("tan") and len(pars) == 1:
        # d/dx tan(x) = sec(x)**2 = tan(x)**2 + 1
        return make_f("tan")(*pars)**2 + 1
    elif func == make_f("log") and len(pars) == 1:
        return primitives.quotient(1, pars[0])
    elif func == make_f("exp") and len(pars) == 1:
        return make_f("exp")(*pars)
    elif func == make_f("sinh") and len(pars) == 1:
        return make_f("cosh")(*pars)
    elif func == make_f("cosh") and len(pars) == 1:
        return make_f("sinh")(*pars)
    elif func == make_f("tanh") and len(pars) == 1:
        return 1 - make_f("tanh")(*pars)**2
    elif func == make_f("expm1") and len(pars) == 1:
        # expm1(x) = exp(x) - 1, so its derivative is exp(x).
        return make_f("exp")(*pars)
    elif func == make_f("fabs") and len(pars) == 1:
        # |x| is continuous but has a kink at 0.
        if allowed_nonsmoothness in ("continuous", "discontinuous"):
            from pymbolic.functions import sign
            return sign(*pars)
        else:
            raise ValueError("fabs is not smooth"
                             ", pass allowed_nonsmoothness='continuous' "
                             "to return sign")
    elif func == make_f("copysign") and len(pars) == 2 and i == 0:
        # copysign(a, b) == |a| * sign(b): continuous but not smooth in a,
        # so its partial in a is sign(a) * sign(b), exactly like fabs.
        if allowed_nonsmoothness in ("continuous", "discontinuous"):
            from pymbolic.functions import sign
            return sign(pars[0]) * sign(pars[1])
        else:
            raise ValueError("copysign is not smooth in its first argument"
                             ", pass allowed_nonsmoothness='continuous' "
                             "to return sign(a)*sign(b)")
    elif func == make_f("copysign") and len(pars) == 2 and i == 1:
        # As a function of b, copysign(a, b) only flips sign at b == 0:
        # flat everywhere else, discontinuous there.
        if allowed_nonsmoothness == "discontinuous":
            return 0
        else:
            raise ValueError("sign is discontinuous"
                             ", pass allowed_nonsmoothness='discontinuous' "
                             "to return 0")
    else:
        raise RuntimeError("unrecognized function, cannot differentiate")