Ejemplo n.º 1
0
def test_temporary_simplifications():
    """Check that ``simplifications`` can be edited locally and is restored.

    ``scalars_to_op`` is removed from ``OperatorPlus.simplifications`` inside
    a ``temporary_rules(OperatorPlus)`` block; once the block is left, it must
    be listed again.
    """

    def scalars_to_op_listed():
        # Re-read the class attribute on every call, exactly as the
        # individual assertions need it (never cache the list object).
        return scalars_to_op in OperatorPlus.simplifications

    assert scalars_to_op_listed()
    with temporary_rules(OperatorPlus):
        OperatorPlus.simplifications.remove(scalars_to_op)
        assert not scalars_to_op_listed()
    assert scalars_to_op_listed()
Ejemplo n.º 2
0
def test_exception_teardown():
    """Test that teardown works when breaking out due to an exception.

    Inside a ``temporary_rules(ScalarTimesOperator, OperatorPlus)`` block, an
    extra rule is added to ``ScalarTimesOperator`` and ``scalars_to_op`` is
    removed from ``OperatorPlus.simplifications``. The block is then left by
    raising an exception; when that exception is caught, both modifications
    must already have been undone.
    """
    class TemporaryRulesException(Exception):
        # Private exception type so that only the deliberate ``raise`` below
        # is caught by the ``except`` clause.
        pass

    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    rule_name = 'extra'
    rule = (pattern_head(6, a), lambda: b)
    # Keep a reference to the list in place before the test; it is written
    # back unconditionally in the ``finally`` clause.
    simplifications = OperatorPlus.simplifications
    try:
        with temporary_rules(ScalarTimesOperator, OperatorPlus):
            ScalarTimesOperator.add_rule(rule_name, rule[0], rule[1])
            OperatorPlus.simplifications.remove(scalars_to_op)
            raise TemporaryRulesException
    except TemporaryRulesException:
        # The context manager has exited at this point: the added rule must
        # be gone and the removed simplification must be back.
        assert rule not in ScalarTimesOperator._rules.values()
        assert scalars_to_op in OperatorPlus.simplifications
    finally:
        # Even if this failed we don't want to make a mess for other tests
        try:
            ScalarTimesOperator.del_rules(rule_name)
        except KeyError:
            # Rule already removed -- nothing left to clean up.
            pass
        OperatorPlus.simplifications = simplifications
Ejemplo n.º 3
0
def test_simplify():
    """Simplify an expression using manually supplied rules.

    Two steps are checked, each with an extra ``OperatorTimes`` rule that
    replaces the product ``b * c`` by ``d`` temporarily installed:

    1. After ``expr.rebuild()``, the term of the result matches a
       commutator-shaped pattern ``A * B - B * A``.
    2. ``_apply_rules`` with the commutator rule yields the expected
       ``srepr`` string containing the symbol ``CommutAD``.
    """
    hs = LocalSpace("h1")
    a = OperatorSymbol("a", hs=hs)
    b = OperatorSymbol("b", hs=hs)
    c = OperatorSymbol("c", hs=hs)
    d = OperatorSymbol("d", hs=hs)

    expr = 2 * (a * b * c - b * c * a)

    wc_a = wc('A', head=Operator)
    wc_b = wc('B', head=Operator)
    wc_c = wc('C', head=Operator)

    def b_times_c_equal_d(B, C):
        # Parameters are named like the wildcards 'B' and 'C'; keep them in
        # sync (presumably the matcher passes bindings by name).
        if not (B.label == 'b' and C.label == 'c'):
            raise CannotSimplify
        return d

    def with_extra_rule(action):
        # Run `action` while the extra b*c -> d rule is installed on
        # OperatorTimes, and hand back its result.
        with temporary_rules(OperatorTimes):
            OperatorTimes.add_rule('extra', pattern_head(wc_b, wc_c),
                                   b_times_c_equal_d)
            return action()

    new_expr = with_extra_rule(expr.rebuild)

    def commutator_symbol(A, B):
        # Build a symbol such as "CommutAD" from the upper-cased labels.
        name = "Commut%s%s" % (A.label.upper(), B.label.upper())
        return OperatorSymbol(name, hs=A.space)

    commutator_pattern = pattern(
        OperatorPlus,
        pattern(OperatorTimes, wc_a, wc_b),
        pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, wc_b, wc_a)),
    )
    commutator_rule = (commutator_pattern, commutator_symbol)
    assert commutator_rule[0].match(new_expr.term)

    new_expr = with_extra_rule(
        lambda: _apply_rules(expr, [commutator_rule]))
    expected = (
        "ScalarTimesOperator(ScalarValue(2), OperatorSymbol('CommutAD', "
        "hs=LocalSpace('h1')))")
    assert srepr(new_expr) == expected
Ejemplo n.º 4
0
def rewrite_with_operator_pm_cc(expr):
    """Try to rewrite expr using :class:`.OperatorPlusMinusCC`.

    `expr` is rebuilt inside a ``temporary_rules(OperatorPlus, clear=True)``
    block in which three rules are registered on :class:`OperatorPlus`:

    * ``PM1``: two operators ``A``, ``B``,
    * ``PM2``: ``-1 * B`` and ``A``,
    * ``PM3``: ``c * A`` and ``d * B`` with scalars ``c``, ``d``.

    Each rule only applies when ``B.adjoint() == A`` (and, for ``PM3``, when
    ``c == d`` or ``c == -d``); otherwise its callback raises
    ``CannotSimplify``.

    Returns the result of ``expr.rebuild()``.

    Example::

        >>> A = OperatorSymbol('A', hs=1)
        >>> sum = A + A.dag()
        >>> sum2 = rewrite_with_operator_pm_cc(sum)
        >>> print(ascii(sum2))
        A^(1) + c.c.
    """
    # TODO: move this to the toolbox
    from qalgebra.toolbox.core import temporary_rules

    # NOTE: the callback parameter names below mirror the wildcard names
    # ("A", "B", "c", "d"); keep them in sync.

    def _combine_operator_p_cc(A, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        return OperatorPlusMinusCC(A, sign=+1)

    def _combine_operator_m_cc(A, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        return OperatorPlusMinusCC(A, sign=-1)

    def _scal_combine_operator_pm_cc(c, A, d, B):
        if not (B.adjoint() == A):
            raise CannotSimplify
        if c == d:
            return c * OperatorPlusMinusCC(A, sign=+1)
        if c == -d:
            return c * OperatorPlusMinusCC(A, sign=-1)
        raise CannotSimplify

    op_a = wc("A", head=Operator)
    op_b = wc("B", head=Operator)
    coeff_c = wc("c", head=Scalar)
    coeff_d = wc("d", head=Scalar)

    with temporary_rules(OperatorPlus, clear=True):
        plain_pair = pattern_head(op_a, op_b)
        OperatorPlus.add_rule('PM1', plain_pair, _combine_operator_p_cc)

        negated_b = pattern(ScalarTimesOperator, -1, op_b)
        OperatorPlus.add_rule(
            'PM2', pattern_head(negated_b, op_a), _combine_operator_m_cc
        )

        scaled_a = pattern(ScalarTimesOperator, coeff_c, op_a)
        scaled_b = pattern(ScalarTimesOperator, coeff_d, op_b)
        OperatorPlus.add_rule(
            'PM3',
            pattern_head(scaled_a, scaled_b),
            _scal_combine_operator_pm_cc,
        )

        rebuilt = expr.rebuild()
        return rebuilt
Ejemplo n.º 5
0
def test_no_rules():
    """Test creation of an expression while rules are cleared.

    ``Commutator.create(2 * A, 2 * (3 * B))`` is rendered with ``srepr``
    under ``temporary_rules(..., clear=True)`` for ``ScalarTimesOperator``,
    for ``Commutator``, and for both together; before and after, the
    fully simplified form is expected.
    """
    A = OperatorSymbol('A', hs=0)
    B = OperatorSymbol('B', hs=0)

    def build():
        return Commutator.create(2 * A, 2 * (3 * B))

    def render():
        # A fresh cache dict is passed on every call.
        return srepr(build(), cache={A: 'A', B: 'B'})

    baseline = 'ScalarTimesOperator(ScalarValue(12), Commutator(A, B))'

    assert render() == baseline

    with temporary_rules(ScalarTimesOperator, clear=True):
        expected = (
            'ScalarTimesOperator(ScalarValue(4), '
            'ScalarTimesOperator(ScalarValue(3), Commutator(A, B)))')
        assert render() == expected

    with temporary_rules(Commutator, clear=True):
        expected = (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(6), B))')
        assert render() == expected

    with temporary_rules(Commutator, ScalarTimesOperator, clear=True):
        expected = (
            'Commutator(ScalarTimesOperator(ScalarValue(2), A), '
            'ScalarTimesOperator(ScalarValue(2), '
            'ScalarTimesOperator(ScalarValue(3), B)))')
        assert render() == expected

    # Outside all temporary blocks the baseline form must be produced again.
    assert render() == baseline
Ejemplo n.º 6
0
def test_extra_rules():
    """Test creation of an expression with an extra, temporary rule.

    While a rule with ``pattern_head(6, a)`` and a callback returning ``b``
    is registered on ``ScalarTimesOperator``, ``2 * a * 3 + 3 * (2 * a * 3)``
    must equal ``4 * b``. After the ``temporary_rules`` block the rule must
    be gone and the same expression must render as ``24 * a``.
    """
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    hs_repr = "LocalSpace('h1')"

    def replace_with_b():
        return b

    six_times_a = pattern_head(6, a)
    rule = (six_times_a, replace_with_b)

    def build():
        return 2 * a * 3 + 3 * (2 * a * 3)

    with temporary_rules(ScalarTimesOperator):
        ScalarTimesOperator.add_rule('extra', six_times_a, replace_with_b)
        assert ('extra', rule) in ScalarTimesOperator._rules.items()
        expr = build()
        assert expr == 4 * b

    assert rule not in ScalarTimesOperator._rules.values()
    rendered = srepr(build())
    assert rendered == ("ScalarTimesOperator(ScalarValue(24), "
                        "OperatorSymbol('a', hs=" + hs_repr + "))")
Ejemplo n.º 7
0
def test_extra_binary_rules():
    """Test creation of an expression with an extra, temporary binary rule.

    A rule with the pattern ``A * B`` followed by ``-1 * (B * A)`` and a
    callback returning ``c`` is added to ``OperatorPlus``. Inside the
    ``temporary_rules`` block, ``2 * (a * b - b * a + IdentityOperator)``
    must equal ``2 * (c + IdentityOperator)``; afterwards the rule must no
    longer be in ``OperatorPlus._binary_rules``.
    """
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    c = OperatorSymbol("c", hs=h1)
    wc_a = wc('A', head=Operator)
    wc_b = wc('B', head=Operator)

    def replace_with_c(A, B):
        # Parameters are named like the wildcards 'A' and 'B'.
        return c

    product_ab = pattern(OperatorTimes, wc_a, wc_b)
    minus_product_ba = pattern(
        ScalarTimesOperator, -1, pattern(OperatorTimes, wc_b, wc_a))
    rule = (pattern_head(product_ab, minus_product_ba), replace_with_c)

    with temporary_rules(OperatorPlus):
        OperatorPlus.add_rule('extra', *rule)
        assert ('extra', rule) in OperatorPlus._binary_rules.items()
        simplified = 2 * (a * b - b * a + IdentityOperator)
        assert simplified == 2 * (c + IdentityOperator)

    assert rule not in OperatorPlus._binary_rules.values()