コード例 #1
0
ファイル: test_syntax_sugar.py プロジェクト: jmv2009/tsfc
def test_expressions():
    """Exercise the operator overloads available on GEM expressions.

    Covers indexing, the four arithmetic operators, mixing with plain
    Python numbers, the shapes produced by ``+``, ``@`` and ``.T``, and
    the ``ValueError`` raised for two unsupported operand combinations.
    """
    matrix = gem.Variable("x", (3, 4))
    vector = gem.Variable("y", (4, ))
    row, col = gem.indices(2)

    entry = matrix[row, col]
    component = vector[col]

    # Subscripting a variable builds an Indexed node.
    assert entry == gem.Indexed(matrix, (row, col))
    assert component == gem.Indexed(vector, (col, ))

    # Binary arithmetic maps onto the corresponding GEM nodes; subtraction
    # is expressed as adding the operand multiplied by Literal(-1).
    assert entry + component == gem.Sum(entry, component)
    assert entry * component == gem.Product(entry, component)
    assert entry - component == gem.Sum(
        entry, gem.Product(gem.Literal(-1), component))
    assert entry / component == gem.Division(entry, component)

    # Plain Python numbers are wrapped in Literal on either side.
    assert entry + 1 == gem.Sum(entry, gem.Literal(1))
    assert 1 + entry == gem.Sum(gem.Literal(1), entry)

    # Resulting shapes for addition with the whole of ``y``, for the
    # matrix-vector product, and for the transpose.
    assert (entry + vector).shape == (4, )

    assert (matrix @ vector).shape == (3, )

    assert matrix.T.shape == (4, 3)

    # Both of these operand combinations are expected to be rejected.
    with pytest.raises(ValueError):
        entry.T @ vector

    with pytest.raises(ValueError):
        entry + "foo"
コード例 #2
0
def compile_to_gem(expr, translator):
    """Translate one pointwise assignment into preprocessed GEM.

    :arg expr: The assignment expression to compile; anything that is not
        an :class:`Assign` instance is rejected.
    :arg translator: a :class:`Translator` instance, applied to the lvalue
        and rvalue through ``map_expr_dags``.
    :returns: The result of ``preprocess_gem`` applied to the
        ``[lvalue, rvalue]`` pair.
    :raises ValueError: if ``expr`` is not an :class:`Assign`; if a
        coefficient's function space has an element whose type is exactly
        ``ufl.MixedElement``; if the coefficients with a function space do
        not share exactly one element; or if the lvalue and rvalue shapes
        cannot be matched."""
    # The same message is reported from two places further down.
    shape_mismatch_msg = ("Mismatching shapes in pointwise assignment. "
                          "For intrinsically vector-/tensor-valued spaces make "
                          "sure you're not using shaped Constants or literals.")

    if not isinstance(expr, Assign):
        raise ValueError(
            f"Don't know how to assign expression of type {type(expr)}")

    function_spaces = tuple(
        coeff.function_space() for coeff in expr.coefficients)
    for space in function_spaces:
        # Exact type comparison: only a plain MixedElement is rejected here.
        if space is not None and type(space.ufl_element()) is ufl.MixedElement:
            raise ValueError("Not expecting a mixed space at this point, "
                             "did you forget to index a function with .sub(...)?")
    distinct_elements = {
        space.ufl_element() for space in function_spaces if space is not None
    }
    if len(distinct_elements) != 1:
        raise ValueError("All coefficients must be defined on the same space")

    lhs = expr.lvalue
    rhs = expr.rvalue

    # Broadcasting applies when every right-hand coefficient is a Constant
    # and the right-hand side has an empty UFL shape.
    only_constants = all(
        isinstance(coeff, firedrake.Constant) for coeff in expr.rcoefficients)
    is_broadcast = only_constants and rhs.ufl_shape == ()

    if not is_broadcast and lhs.ufl_shape != rhs.ufl_shape:
        try:
            rhs = reshape(rhs, lhs.ufl_shape)
        except ValueError:
            raise ValueError(
                "Mismatching shapes between lvalue and rvalue in pointwise assignment"
            )

    [rhs] = map_expr_dags(LowerCompoundAlgebra(), [rhs])
    try:
        lhs, rhs = map_expr_dags(translator, [lhs, rhs])
    except (AssertionError, ValueError):
        raise ValueError(shape_mismatch_msg)

    # One index per lvalue axis, shared by both sides unless broadcasting.
    free_indices = gem.indices(len(lhs.shape))
    if not is_broadcast:
        if rhs.shape != lhs.shape:
            raise ValueError(shape_mismatch_msg)
        rhs = gem.Indexed(rhs, free_indices)
    lhs = gem.Indexed(lhs, free_indices)

    # Augmented assignments fold the lvalue into the rvalue; the first
    # matching type wins, and any other expression leaves rhs untouched.
    augmented_forms = (
        (IAdd, gem.Sum),
        (ISub, lambda target, value: gem.Sum(
            target, gem.Product(gem.Literal(-1), value))),
        (IMul, gem.Product),
        (IDiv, gem.Division),
    )
    for assign_type, combine in augmented_forms:
        if isinstance(expr, assign_type):
            rhs = combine(lhs, rhs)
            break
    return preprocess_gem([lhs, rhs])
コード例 #3
0
ファイル: sympy2gem.py プロジェクト: jmv2009/FInAT
def sympy2gem_rational(node, self):
    """Convert ``node`` into a GEM division of its numerator by its denominator.

    :arg node: object providing ``as_numer_denom()``.
    :arg self: callable applied to each part returned by
        ``node.as_numer_denom()`` to convert it.
    :returns: ``gem.Division`` built from the converted parts."""
    converted_parts = [self(part) for part in node.as_numer_denom()]
    return gem.Division(*converted_parts)