Exemple #1
def test_IndexedBase_assumptions():
    """Assumptions given to IndexedBase hold for the base and its elements."""
    idx = Symbol('i', integer=True)
    label = Symbol('a')
    base = IndexedBase(label, positive=True)
    for obj in (base, base[idx]):
        for prop in ('is_real', 'is_complex'):
            assert getattr(obj, prop)
        assert not obj.is_imaginary
        for prop in ('is_nonnegative', 'is_nonzero', 'is_commutative'):
            assert getattr(obj, prop)
        # Positivity lets log/exp cancel.
        assert log(exp(obj)) == obj

    # Assumptions take part in equality.
    assert base != IndexedBase(label)
    assert base == IndexedBase(label, positive=True, real=True)
    assert base[idx] != Indexed(label, idx)
    def __getitem__(self, time):
        """For indexing discrete time stochastic processes.

        Returns a ``RandomIndexedSymbol`` wrapping ``Indexed(self.symbol, time)``.
        It lives in a ``StochasticPSpace`` built from this process and its
        ``distribution`` attribute, which is ``None`` when the process has none.

        Raises IndexError if ``time`` is not in ``self.index_set``.
        """
        if time not in self.index_set:
            raise IndexError("%s is not in the index set of %s" %
                             (time, self.symbol))
        idx_obj = Indexed(self.symbol, time)
        distribution = getattr(self, "distribution", None)
        pspace_obj = StochasticPSpace(self.symbol, self, distribution)
        return RandomIndexedSymbol(idx_obj, pspace_obj)
Exemple #3
 def pdf(self, *x):
     """Evaluate the marginal density of ``self.args[0]`` at the point ``x``.

     Random symbols of ``self.args[0]`` that are not in ``self.args[1]`` are
     passed to ``self.compute_pdf`` as the symbols to marginalise out. The
     result is wrapped in a ``Lambda`` over the density variables and
     applied to ``*x``.
     """
     expr, rvs = self.args[0], self.args[1]
     marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
     if isinstance(expr, CompoundDistribution):
         syms = Dummy('x', real=True)
         expr = expr.args[0].pdf(syms)
     elif isinstance(expr, JointDistribution):
         count = len(expr.domain.args)
         # Use a separate name: rebinding ``x`` would lose the caller's
         # evaluation point needed by the final ``(*x)`` call.
         base = Dummy('x', real=True, finite=True)
         syms = tuple(Indexed(base, i) for i in range(count))
         expr = expr.pdf(syms)
     else:
         syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol)
                      else rv.args[0] for rv in rvs)
     return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)
Exemple #4
def test_issue_12283():
    """Random symbols with no real probability space stay unevaluated."""
    sym_x = symbols('x')
    X = RandomSymbol(sym_x)
    Y = RandomSymbol('Y')
    Z = RandomMatrixSymbol('Z', 2, 1)
    W = RandomMatrixSymbol('W', 2, 1)
    RI = RandomIndexedSymbol(Indexed('RI', 3))
    # Each of these compares equal to a bare PSpace().
    for rv in (Z, RI, X):
        assert pspace(rv) == PSpace()
    assert E(X) == Expectation(X)
    assert P(Y > 3) == Probability(Y > 3)
    for rv in (X, RI):
        assert variance(rv) == Variance(rv)
    assert covariance(X, Y) == Covariance(X, Y)
    assert covariance(W, Z) == Covariance(W, Z)
Exemple #5
def shift_grid(expr):
    """Recursively shift selected indices of Indexed objects by ``hf``.

    For an Indexed whose base label name is in ``shift_x``, index position 1
    is reduced by ``hf``. If the name is in ``shift_y``, index position 2 is
    reduced. Symbols and numbers are returned unchanged. Any other
    expression is rebuilt from its recursively shifted arguments.
    """
    if expr.is_Symbol or expr.is_Number:
        return expr
    if not isinstance(expr, Indexed):
        return expr.func(*(shift_grid(sub) for sub in expr.args))
    base = expr.base
    indices = list(expr.indices)
    name = base.label.name
    if name in shift_x:
        indices[1] -= hf
    if name in shift_y:
        indices[2] -= hf
    return Indexed(base, *indices)
Exemple #6
 def pdf(self, *x):
     """Evaluate the marginal density of ``self.args[0]`` at the point ``x``.

     The following are passed to ``self.compute_pdf`` as the symbols to
     marginalise out:

     * random symbols of ``self.args[0]`` not listed in ``self.args[1]``;
     * ``Indexed`` components of random symbols that are not listed.

     The result is wrapped in a ``Lambda`` over the density variables and
     applied to ``*x``.
     """
     expr, rvs = self.args[0], self.args[1]
     marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
     # A tuple signature avoids the deprecated list form of Lambda.
     syms = tuple(i.pspace.symbol for i in rvs)
     # Indexed components of random symbols (e.g. ``X[0]``) that are not
     # kept must be marginalised out as well.
     # NOTE(review): this loop body was missing from the extracted source;
     # ``marginalise_out.append(i)`` is a reconstruction -- confirm upstream.
     for i in expr.atoms(Indexed):
         if isinstance(i.base, RandomSymbol) and i not in rvs:
             marginalise_out.append(i)
     if isinstance(expr, CompoundDistribution):
         syms = Dummy('x', real=True)
         expr = expr.args[0].pdf(syms)
     elif isinstance(expr, JointDistribution):
         count = len(expr.domain.args)
         # Use a separate name: rebinding ``x`` would lose the caller's
         # evaluation point needed by the final ``(*x)`` call.
         base = Dummy('x', real=True, finite=True)
         syms = tuple(Indexed(base, i) for i in range(count))
         expr = expr.pdf(syms)
     return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)
Exemple #7
def shift_grid(expr):
    """Shift all field indices in ``expr`` to whole grid points.

    An index is reduced by ``hf`` when ``is_half`` is true for it, i.e. the
    field is staggered and the index lies in the middle of the grid. The
    resulting expression can then be converted to C code.

    Symbols and numbers are returned unchanged. Any other expression is
    rebuilt from its recursively shifted arguments.
    """
    if expr.is_Symbol or expr.is_Number:
        return expr
    if isinstance(expr, Indexed):
        idx = [x - hf if is_half(x) else x for x in expr.indices]
        return Indexed(expr.base, *idx)
    # Recursive case: rebuild the node from its shifted arguments.
    args = tuple(shift_grid(arg) for arg in expr.args)
    return expr.func(*args)
Exemple #8
def test_sample_seed():
    """Sampling a joint RV with the same seed is reproducible.

    Each backend in ``libraries`` is exercised only if it can be imported.
    Backends that raise NotImplementedError are skipped.
    """
    x1, x2 = (Indexed('x', i) for i in (1, 2))
    pdf = exp(-x1**2 / 2 + x1 - x2**2 / 2 - S.Half) / (2 * pi)
    X = JointRV('x', pdf)

    libraries = ['scipy', 'numpy', 'pymc3']
    for lib in libraries:
        try:
            imported_lib = import_module(lib)
            if imported_lib:
                s0 = list(sample(X, numsamples=10, library=lib, seed=0))
                s1 = list(sample(X, numsamples=10, library=lib, seed=0))
                s2 = list(sample(X, numsamples=10, library=lib, seed=1))
                assert s0 == s1
                assert s1 != s2
        except NotImplementedError:
            # This backend cannot sample this distribution; try the next one.
            continue
Exemple #9
 def kappac(self, n):
     """Build the order-``n`` sum of ``fn[k-1] * bell(n, k, r)`` and simplify it.

     The term for k = 1 is always present; terms for k = 2 .. n follow. The
     terms are summed by lambdifying ``Sum(y[j], (j, 0, n - 1))`` over
     ``self.y``. Afterwards ``W(z)``, then ``S(z)`` and ``Q(z)``, are
     substituted using the ``NBc``/``NBbc`` attributes.
     """
     fn, r, z, y = self.fn, self.r, self.z, self.y
     W, Q, S = self.W, self.Q, self.S
     NBc, NBbc = self.NBc, self.NBbc
     j = self.j
     terms = [fn[0] * bell(n, 1, r)]
     terms += [fn[k - 1] * bell(n, k, r) for k in range(2, n + 1)]
     summer = lambdify(y, Sum(Indexed('y', j), (j, 0, n - 1)))
     summed = summer(terms)
     partial = summed.subs({W(z): Q(z) * S(z) - NBc * NBbc})
     return partial.subs({S(z): NBc + NBbc, Q(z): z * z - NBc * NBbc})
Exemple #10
def test_Lambda():
    """Check construction, equality, calling and arity of ``Lambda``."""
    square = Lambda(x, x**2)
    for arg, expected in ((4, 16), (x, x**2), (y, y**2)):
        assert square(arg) == expected

    # Nullary lambdas
    forty_two = Lambda((), 42)
    assert forty_two() == 42
    assert unchanged(Lambda, (), 42)
    assert forty_two != Lambda((), 43)
    assert Lambda((), f(x))() == f(x)
    assert forty_two.nargs == FiniteSet(0)

    # Equality is invariant under renaming of bound variables
    assert unchanged(Lambda, (x,), x**2)
    assert Lambda(x, x**2) == Lambda((x,), x**2)
    assert Lambda(x, x**2) == Lambda(y, y**2)
    assert Lambda(x, x**2) != Lambda(y, y**2 + 1)
    assert Lambda((x, y), x**y) == Lambda((y, x), y**x)
    assert Lambda((x, y), x**y) != Lambda((x, y), y**x)

    power = Lambda((x, y), x**y)
    assert power(x, y) == x**y
    assert power(3, 3) == 3**3
    assert power(x, 3) == x**3
    assert power(3, y) == 3**y
    assert Lambda(x, f(x))(x) == f(x)
    assert Lambda(x, x**2)(square(x)) == x**4
    assert square(square(x)) == x**4

    # Indexed objects may serve as bound variables
    x1, x2 = (Indexed('x', k) for k in (1, 2))
    assert Lambda((x1, x2), x1 + x2)(x, y) == x + y

    assert Lambda((x, y), x + y).nargs == FiniteSet(2)

    syms = x, y, z, t
    assert Lambda(syms, t*(x + y + z))(*syms) == t * (x + y + z)

    doubled = Lambda(x, 2*x)
    assert doubled + Lambda(y, 2*y) == 2*doubled
    assert doubled not in [Lambda(x, x)]
    raises(TypeError, lambda: Lambda(1, x))
    assert Lambda(x, 1)(1) is S.One

    raises(SyntaxError, lambda: Lambda((x, x), x + 2))
Exemple #11
def shift_index(expr, k, s):
    """Shift the ``k``-th index of all fields in ``expr`` by ``s``.

    :param expr: input expression
    :param k: position of the index to shift (0-based)
    :param s: amount added to that index
    :return: the shifted expression

    e.g. k=1, s=1: U[x,y,z] -> U[x,y+1,z]

    Symbols and numbers are returned unchanged. Any other expression is
    rebuilt from its recursively shifted arguments. An Indexed with fewer
    than ``k + 1`` indices raises IndexError.
    """
    if expr.is_Symbol or expr.is_Number:
        return expr
    if isinstance(expr, Indexed):
        idx = list(expr.indices)
        idx[k] += s
        return Indexed(expr.base, *idx)
    # Recursive case: rebuild the node from its shifted arguments.
    args = tuple(shift_index(arg, k, s) for arg in expr.args)
    return expr.func(*args)
Exemple #12
def test_indexed_idx_sum():
    """Sum and Product over an Idx expand correctly and respect Idx bounds."""
    i = symbols('i', cls=Idx)
    r = Indexed('r', i)

    def terms(start, count):
        # r with i replaced by start, start + 1, ..., start + count - 1
        return [r.xreplace({i: start + n}) for n in range(count)]

    assert Sum(r, (i, 0, 3)).doit() == sum(terms(0, 4))
    assert Product(r, (i, 0, 3)).doit() == prod(terms(0, 4))

    j = symbols('j', integer=True)
    assert Sum(r, (i, j, j+2)).doit() == sum(terms(j, 3))
    assert Product(r, (i, j, j+2)).doit() == prod(terms(j, 3))

    k = Idx('k', range=(1, 3))
    A = IndexedBase('A')
    bounded = [A[Idx(n, (1, 3))] for n in range(1, 4)]
    assert Sum(A[k], k).doit() == sum(bounded)
    assert Product(A[k], k).doit() == prod(bounded)

    # Limits outside the declared Idx range are rejected.
    for op in (Sum, Product):
        for bad in ((k, 1, 4), (k, 0, 3), (k, 2, oo)):
            raises(ValueError, lambda: op(A[k], bad))
Exemple #13
 def pdf(self):
     """Call ``self.distribution`` with one component symbol per component.

     Component ``n`` is ``Indexed(self.symbol, n)`` for
     ``n`` in ``range(self.component_count)``.
     """
     count = self.component_count
     components = tuple(Indexed(self.symbol, n) for n in range(count))
     return self.distribution(*components)
Exemple #14
 def __getitem__(self, key):
     """Return ``Indexed(self, key)`` when backed by a ``JointPSpace``.

     Raises ValueError when ``key`` is known to be at least the component
     count. Returns None for any other probability space.
     """
     space = self.pspace
     if not isinstance(space, JointPSpace):
         return None
     # ``== True`` so an undecidable symbolic comparison does not raise.
     if (space.component_count <= key) == True:
         raise ValueError("Index keys for %s can only up to %s." %
             (self.name, space.component_count - 1))
     return Indexed(self, key)
Exemple #15
    # NOTE(review): this method was damaged during extraction and cannot run
    # as shown. The following pieces are missing and must be restored from
    # the upstream source:
    #   * the end of the signature (after ``freeDummies,``);
    #   * the arguments of both ``return expand(`` calls;
    #   * the closing ``]`` of ``duplicatePos``;
    #   * the statement that the dangling
    #     ``(duplicateTensor, *contractArgs[pos][1:]))`` line belongs to;
    #   * the code that fills ``expanded``. Nothing is appended to it here,
    #     so ``expanded[0]`` would raise IndexError;
    #   * the definition of ``coupling``;
    #   * an apparent ``else:`` between the antisymmetric and symmetric
    #     assumption assignments.
    def handleSymmetricYukawa(self, fermions, contractArgs, coeff, freeDummies,
        # Every occurrence of a fermion that appears exactly twice. A length
        # of 2 therefore means exactly one fermion is duplicated.
        duplicateFermions = [el for el in fermions if fermions.count(el) == 2]

        if len(duplicateFermions) != 2:
            return expand(

        expanded = []

        duplicateFermion = duplicateFermions[0]

        # Positions in contractArgs whose tensor symbol is the duplicated fermion.
        symbs = [el[0].symbol for el in contractArgs]
        duplicatePos = [
            i for i, el in enumerate(symbs) if el == duplicateFermion

        for pos in duplicatePos:
            duplicateTensor = copy.deepcopy(contractArgs[pos][0])

            # Rename the copy's Symbol/Indexed entries with a 'df$' prefix
            # (presumably so the two occurrences stay distinguishable).
            # NOTE(review): as written, entries of any other type are dropped
            # from newDic.
            newDic = {}
            for k, v in duplicateTensor.dic.items():
                if isinstance(v, Symbol):
                    newDic[k] = Symbol('df$' + str(v))
                elif isinstance(v, Indexed):
                    newDic[k] = Indexed('df$' + str(v.base), *v.indices)

            duplicateTensor.dic = newDic
            newcArgs = []
            for i, el in enumerate(contractArgs):
                if i != pos:
                    if len(contractArgs[pos]) > 1:
                            (duplicateTensor, *contractArgs[pos][1:]))
                        newcArgs.append((duplicateTensor, ))


        if expand(expanded[0] + expanded[1]) == 0:
            # Antisymmetric Yukawa matrix

            # For one generation, the Yukawa operator simply vanishes.
            # In this case an error should be raised eventually
            if self.model.Fermions[str(duplicateFermion)].gen == 1:
                return expand(

            self.model.assumptions[coupling]['antisymmetric'] = True
            # Symmetric Yukawa matrix
            self.model.assumptions[coupling]['symmetric'] = True

        return expanded[0]
Exemple #16
 def class_key(cls):
     """Delegate the sort key to ``Indexed.class_key()``.

     NOTE(review): presumably a ``@classmethod`` whose decorator is not
     visible here.
     """
     key = Indexed.class_key()
     return key
Exemple #17
def test_issue_9594():
    """Summing an Indexed over an Idx expands to the explicit terms.

    The original statement had no ``assert``, and it compared the sum with a
    list, which could never be equal. Compare against the summed terms instead.
    """
    i = symbols('i', cls=Idx)
    r = Indexed('r', i)
    assert Sum(r, (i, 0, 3)).doit() == sum([r.subs(i, j) for j in range(4)])
Exemple #18
def dse_cse(expr):
    """Perform common subexpression elimination on sympy expressions.

    Some temporaries produced by ``cse`` are substituted back into the
    equations:

    * bare ``IndexedBase``/``Indexed`` objects;
    * ``Add`` nodes that involve one of the grid symbols ``t, x, y, z``
      (simple index arithmetic such as ``t - 1``).

    All other temporaries are kept as separate assignments.

    :param expr: sympy equation or list of equations on which CSE is performed.

    :return: A list of the resulting equations after performing CSE. The kept
        temporary assignments come first, then the stencils. Each stencil is
        followed by any kept temporary whose RHS reads that stencil's LHS.
    """
    # NOTE(review): several lines of this function were lost in extraction
    # (an ``else:``, the index-substitution bodies, the closing arguments of
    # two ``Indexed`` calls and the ``new_stencils`` appends). They are
    # reconstructed from the surrounding code; confirm against upstream.
    expr = expr if isinstance(expr, list) else [expr]

    temps, stencils = cse(expr, numbered_symbols("temp"))

    # Restores the LHS
    for i, eq in enumerate(expr):
        stencils[i] = Eq(eq.lhs, stencils[i].rhs)

    to_revert = {}
    to_keep = []
    grid_symbols = {t, x, y, z}

    # Restores IndexedBases if they are collected by CSE and
    # reverts changes to simple index operations (eg: t - 1)
    for temp, value in temps:
        if isinstance(value, (IndexedBase, Indexed)):
            to_revert[temp] = value
        elif isinstance(value, Add) and not grid_symbols.isdisjoint(value.args):
            to_revert[temp] = value
        else:
            # A genuine common subexpression: emit it as its own assignment.
            to_keep.append((temp, value))

    # Restores the IndexedBases and the Indexes in the assignments to revert.
    # Only existing keys are reassigned, so iterating items() stays valid.
    for temp, value in to_revert.items():
        s_dict = {}
        for arg in preorder_traversal(value):
            if isinstance(arg, Indexed):
                new_indices = [to_revert.get(index, index)
                               for index in arg.indices]
                if arg.base.label in to_revert:
                    # Look up ``arg``'s label; ``value`` itself need not be
                    # an Indexed and may have no ``base`` at all.
                    s_dict[arg] = Indexed(to_revert[arg.base.label],
                                          *new_indices)
        to_revert[temp] = value.xreplace(s_dict)

    subs_dict = {}

    # Builds a dictionary of the replacements
    for eq in stencils + [assign for _, assign in to_keep]:
        for arg in preorder_traversal(eq):
            if isinstance(arg, Indexed):
                new_indices = [to_revert.get(index, index)
                               for index in arg.indices]
                if arg.base.label in to_revert:
                    subs_dict[arg] = Indexed(to_revert[arg.base.label],
                                             *new_indices)
                elif tuple(new_indices) != arg.indices:
                    subs_dict[arg] = Indexed(arg.base, *new_indices)
            if arg in to_revert:
                subs_dict[arg] = to_revert[arg]

    stencils = [stencil.xreplace(subs_dict) for stencil in stencils]

    to_keep = [Eq(temp, value.xreplace(subs_dict)) for temp, value in to_keep]

    # If the RHS of a temporary variable is the LHS of a stencil,
    # update the value of the temporary variable after the stencil
    new_stencils = []
    for stencil in stencils:
        new_stencils.append(stencil)
        for temp in to_keep:
            if stencil.lhs in preorder_traversal(temp.rhs):
                new_stencils.append(temp)

    return to_keep + new_stencils
Exemple #19
def test_Indexed_func_args():
    """Rebuilding an Indexed from ``func`` and ``args`` gives an equal object."""
    i, j = symbols('i j', integer=True)
    indexed = Indexed(symbols('a'), i, j)
    rebuilt = indexed.func(*indexed.args)
    assert indexed == rebuilt
Exemple #20
def test_not_interable():
    """An Indexed with compound indices is not treated as iterable."""
    i, j = symbols('i j', integer=True)
    assert not iterable(Indexed('A', i, i + j))
Exemple #21
def test_MarginalDistribution():
    """Marginalising a MultivariateBeta over C[0] gives the expected density."""
    a1, p1, p2 = symbols('a1 p1 p2', positive=True)
    C = Multinomial('C', 2, p1, p2)
    B = MultivariateBeta('B', a1, C[0])
    MGR = MarginalDistribution(B, (C[0],))

    # Reusable pieces of the expected expression.
    a1_pos = Symbol('a1', positive=True)
    C_base = IndexedBase(Symbol('C'))
    B_base = IndexedBase(Symbol('B'))
    c0 = Indexed(C_base, Integer(0))
    c1 = Indexed(C_base, Integer(1))
    minus_one = Integer(-1)

    multinomial_pmf = Mul(Integer(2),
                          Pow(Symbol('p1', positive=True), c0),
                          Pow(Symbol('p2', positive=True), c1),
                          Pow(factorial(c0), minus_one),
                          Pow(factorial(c1), minus_one))
    on_support = Eq(Add(c0, c1), Integer(2))

    mgrc = Mul(Symbol('B'),
               Piecewise(ExprCondPair(multinomial_pmf, on_support),
                         ExprCondPair(Integer(0), True)),
               Pow(gamma(a1_pos), minus_one),
               gamma(Add(a1_pos, c0)),
               Pow(gamma(c0), minus_one),
               Pow(Indexed(B_base, Integer(0)), Add(a1_pos, minus_one)),
               Pow(Indexed(B_base, Integer(1)), Add(c0, minus_one)))
    assert MGR(C) == mgrc
    # NOTE(review): this method was damaged during extraction. The following
    # are missing:
    #   * the rest of the signature after ``self`` (the body reads
    #     ``condition``, ``given_condition`` and ``evaluate``);
    #   * the docstring's opening/closing quotes and section headings;
    #   * the closing of the IndexError argument tuple further down.
    # Restore them from the upstream source before running this code.
    def probability(self,
        Handles probability queries for discrete Markov chains.


        condition: Relational
        given_condition: Relational/And


            If the transition probabilities are not available
            If the transition probabilities is MatrixSymbol or Matrix


        Any information passed at the time of query overrides
        any information passed at the time of object creation like
        transition probabilities, state space.

        Pass the transition matrix using TransitionMatrixOf and state space
        using StochasticStateSpaceOf in given_condition using & or And.

        check, trans_probs, state_space, given_condition = \
            self._preprocess(given_condition, evaluate)

        # When _preprocess sets ``check``, return a Probability object
        # instead of evaluating the query.
        if check:
            return Probability(condition, given_condition)

        if isinstance(condition, Eq) and \
            isinstance(given_condition, Eq) and \
            len(given_condition.atoms(RandomSymbol)) == 1:
            # handles simple queries like P(Eq(X[i], dest_state), Eq(X[i], init_state))
            lhsc, rhsc = condition.lhs, condition.rhs
            lhsg, rhsg = given_condition.lhs, given_condition.rhs
            # Normalise both equalities so the RandomIndexedSymbol is on the left.
            if not isinstance(lhsc, RandomIndexedSymbol):
                lhsc, rhsc = (rhsc, lhsc)
            if not isinstance(lhsg, RandomIndexedSymbol):
                lhsg, rhsg = (rhsg, lhsg)
            keyc, statec, keyg, stateg = (lhsc.key, rhsc, lhsg.key, rhsg)
            # Reject states outside the (zero-indexed) transition matrix.
            if Lt(stateg, trans_probs.shape[0]) == False or Lt(
                    statec, trans_probs.shape[1]) == False:
                raise IndexError(
                    "No information is avaliable for (%s, %s) in "
                    "transition probabilities of shape, (%s, %s). "
                    "State space is zero indexed." %
                    (stateg, statec, trans_probs.shape[0],
            # NOTE(review): the fourth format argument (presumably
            # ``trans_probs.shape[1]``) and the closing parentheses of this
            # raise were lost in extraction. The message also misspells
            # "available".
            if keyc < keyg:
                raise ValueError(
                    "Incorrect given condition is given, probability "
                    "of past state cannot be computed from future state.")
            # n-step transition probabilities: the matrix power over the
            # number of steps from the given time ``keyg`` to ``keyc``.
            nsteptp = trans_probs**(keyc - keyg)
            # Explicit matrices support item access; otherwise fall back to
            # an unevaluated Indexed (presumably for symbolic matrices).
            if hasattr(nsteptp, "__getitem__"):
                return nsteptp.__getitem__((stateg, statec))
            return Indexed(nsteptp, stateg, statec)

        if isinstance(condition, And):
            # handle queries like,
            # P(Eq(X[i+k], s1) & Eq(X[i+m], s2) . . . & Eq(X[i], sn), Eq(P(X[i]), prob))
            conds = condition.args
            i, result = -1, 1
            # Chain rule under the Markov property: multiply
            # P(conds[i] | conds[i-1]) for i = -1, -2, ..., -(len(conds) - 1).
            while i > -len(conds):
                result *= self.probability(conds[i], conds[i-1] & \
                            TransitionMatrixOf(self, trans_probs) & \
                            StochasticStateSpaceOf(self, state_space))
                i -= 1
            # After the loop conds[i] is conds[0]. Its own probability is
            # still needed, either unevaluated or taken from given_condition.
            if isinstance(given_condition,
                          (TransitionMatrixOf, StochasticStateSpaceOf)):
                return result * Probability(conds[i])
            if isinstance(given_condition, Eq):
                if not isinstance(given_condition.lhs, Probability) or \
                    given_condition.lhs.args[0] != conds[i]:
                    # NOTE(review): the message and ``conds[i]`` are passed as
                    # two separate arguments, so "%s" is never filled in.
                    # ``%`` formatting was probably intended.
                    raise ValueError("Probability for %s needed", conds[i])
                return result * given_condition.rhs

        raise NotImplementedError(
            "Mechanism for handling (%s, %s) queries hasn't been "
            "implemented yet." % (condition, given_condition))
Exemple #23
 def __getitem__(self, key):
     """Return ``Indexed(self, key)`` if the probability space is a JointPSpace.

     Any other probability space yields None.
     """
     from sympy.stats.joint_rv import JointPSpace
     return Indexed(self, key) if isinstance(self.pspace, JointPSpace) else None
def get_sum(min_limit, max_limit, reg):
    """Return ``Sum(x[i], (i, min_limit, max_limit)).doit()``.

    ``x`` is the IndexedBase named ``'x'``. ``reg`` is accepted for interface
    compatibility but is not used.
    """
    # Only the summation index is needed; the base is given by name.
    i = symbols('i')
    return Sum(Indexed('x', i), (i, min_limit, max_limit)).doit()
Exemple #25
def test_Idx_limits():
    """Sequences over an Idx variable honour the given limits."""
    i = symbols('i', cls=Idx)
    r = Indexed('r', i)
    expected = [r.subs(i, n) for n in range(6)]
    assert SeqFormula(r, (i, 0, 5))[:] == expected
    assert SeqPer((1, 2), (i, 0, 5))[:] == [1, 2] * 3
Exemple #26
 def __getitem__(self, indices, **kw_args):
     """Return ``Indexed(self, ...)`` for one index or a sequence of indices.

     A sequence key (e.g. from ``M[i, j]``) is unpacked into separate indices.
     Any other key is used as a single index. ``kw_args`` are forwarded to
     ``Indexed``.
     """
     if is_sequence(indices):
         # Special case needed because M[*my_tuple] is a syntax error.
         return Indexed(self, *indices, **kw_args)
     # Non-sequence key: previously unreachable, so such keys returned None.
     return Indexed(self, indices, **kw_args)
Exemple #27
def test_Indexed_func_args():
    """``Indexed`` round-trips through ``func(*args)``."""
    i, j = symbols('i j', integer=True)
    label = symbols('a')
    expr = Indexed(label, i, j)
    assert expr == expr.func(*expr.args)
Exemple #28
def test_Idx_limits():
    """SeqFormula and SeqPer over an Idx produce the expected terms."""
    i = symbols('i', cls=Idx)
    r = Indexed('r', i)
    limits = (i, 0, 5)
    formula_terms = SeqFormula(r, limits)[:]
    assert formula_terms == [r.subs(i, n) for n in range(6)]
    periodic_terms = SeqPer((1, 2), limits)[:]
    assert periodic_terms == [1, 2, 1, 2, 1, 2]
Exemple #29
def test_not_interable():
    """``iterable`` reports False for an Indexed with compound indices."""
    i, j = symbols("i j", integer=True)
    indexed = Indexed("A", i, i + j)
    is_iter = iterable(indexed)
    assert not is_iter
Exemple #30
def test_complex_indices():
    """Indexed accepts compound index expressions and reports them back."""
    i, j = symbols('i j', integer=True)
    expected_indices = (i, i + j)
    A = Indexed('A', *expected_indices)
    assert A.rank == len(expected_indices)
    assert A.indices == expected_indices
Exemple #31
def generate_coupling_expansions(obj, verbose=True):
    """Generate epsilon expansions of the coupling terms for each variable.

    For every ``key`` in ``obj.var_names`` this builds truncated series in
    ``obj.psi``:

    * the g-terms use coefficients ``Indexed('g' + key + 'A'/'B', i)`` for
      i = 1 .. ``obj.miter``;
    * the i-terms use coefficients ``Indexed('i' + key + 'A'/'B', i)`` for
      i = 0 .. ``obj.miter``.

    Each series then goes through these steps:

    1. the symbol named ``'psi'`` is substituted by ``obj.pA['expand']`` or
       ``obj.pB['expand']``;
    2. the result is expanded;
    3. powers ``eps**k`` for ``k`` in ``[obj.miter, obj.miter + 200)`` are
       dropped;
    4. terms are collected in ``eps``.

    The coefficients of ``eps**0 .. eps**(obj.miter - 1)`` are summed into
    ``obj.g[key + '_epsA']``, ``obj.g[key + '_epsB']``,
    ``obj.i[key + '_epsA']`` and ``obj.i[key + '_epsB']``.

    :param obj: model object. ``obj.g`` and ``obj.i`` are updated in place.
    :param verbose: print progress messages when true.
    :return: the tuple ``(obj.g, obj.i)``.
    """
    if verbose:
        print('* Generating coupling expansions...')
    i = sym.symbols('i_sym')  # summation index
    psi = obj.psi
    eps = obj.eps

    ruleA = {'psi': obj.pA['expand']}
    ruleB = {'psi': obj.pB['expand']}

    # Map eps**k -> 0 for orders obj.miter .. obj.miter + 199 to truncate.
    rule_trunc = {}
    for k in range(obj.miter, obj.miter + 200):
        rule_trunc.update({obj.eps**k: 0})

    for key in obj.var_names:
        if verbose:
            print('key in coupling expand', key)
        gA = Sum(psi**i * Indexed('g' + key + 'A', i),
                 (i, 1, obj.miter)).doit()
        gB = Sum(psi**i * Indexed('g' + key + 'B', i),
                 (i, 1, obj.miter)).doit()

        iA = Sum(psi**i * Indexed('i' + key + 'A', i),
                 (i, 0, obj.miter)).doit()
        iB = Sum(psi**i * Indexed('i' + key + 'B', i),
                 (i, 0, obj.miter)).doit()

        # NOTE(review): in each of the four stages below, the remaining
        # arguments and the closing ')' of the ``sym.expand(tmp,`` call were
        # lost in extraction; restore them from upstream before running. Once
        # restored, the four identical stages could share one helper.
        tmp = gA.subs(ruleA)
        tmp = sym.expand(tmp,
        tmp = tmp.subs(rule_trunc)
        gA_collected = collect(expand(tmp).subs(rule_trunc), eps)
        if verbose:
            print('completed gA collected')

        tmp = gB.subs(ruleB)
        tmp = sym.expand(tmp,
        tmp = tmp.subs(rule_trunc)
        gB_collected = collect(expand(tmp).subs(rule_trunc), eps)
        if verbose:
            print('completed gB collected')

        tmp = iA.subs(ruleA)
        tmp = sym.expand(tmp,
        tmp = tmp.subs(rule_trunc)
        iA_collected = collect(expand(tmp).subs(rule_trunc), eps)
        if verbose:
            print('completed iA collected')

        tmp = iB.subs(ruleB)
        tmp = sym.expand(tmp,
        tmp = tmp.subs(rule_trunc)
        iB_collected = collect(expand(tmp).subs(rule_trunc), eps)
        if verbose:
            print('completed iB collected')

        obj.g[key + '_epsA'] = 0
        obj.g[key + '_epsB'] = 0

        obj.i[key + '_epsA'] = 0
        obj.i[key + '_epsB'] = 0

        # Rebuild each series from its eps**j coefficients, j < obj.miter.
        for j in range(obj.miter):
            obj.g[key + '_epsA'] += eps**j * gA_collected.coeff(eps, j)
            obj.g[key + '_epsB'] += eps**j * gB_collected.coeff(eps, j)

            obj.i[key + '_epsA'] += eps**j * iA_collected.coeff(eps, j)
            obj.i[key + '_epsB'] += eps**j * iB_collected.coeff(eps, j)

    return obj.g, obj.i
Exemple #32
def test_Lambda():
    """Check construction, equality, calling, arity and signatures of Lambda."""
    square = Lambda(x, x**2)
    assert square(4) == 16
    assert square(x) == x**2
    assert square(y) == y**2

    # Nullary lambdas
    assert Lambda((), 42)() == 42
    assert unchanged(Lambda, (), 42)
    assert Lambda((), 42) != Lambda((), 43)
    assert Lambda((), f(x))() == f(x)
    assert Lambda((), 42).nargs == FiniteSet(0)

    # Equality is invariant under renaming of bound variables
    assert unchanged(Lambda, (x, ), x**2)
    assert Lambda(x, x**2) == Lambda((x, ), x**2)
    assert Lambda(x, x**2) == Lambda(y, y**2)
    assert Lambda(x, x**2) != Lambda(y, y**2 + 1)
    assert Lambda((x, y), x**y) == Lambda((y, x), y**x)
    assert Lambda((x, y), x**y) != Lambda((x, y), y**x)

    power = Lambda((x, y), x**y)
    assert power(x, y) == x**y
    assert power(3, 3) == 3**3
    assert power(x, 3) == x**3
    assert power(3, y) == 3**y
    assert Lambda(x, f(x))(x) == f(x)
    assert Lambda(x, x**2)(square(x)) == x**4
    assert square(square(x)) == x**4

    # Indexed objects may serve as bound variables
    x1, x2 = (Indexed('x', k) for k in (1, 2))
    assert Lambda((x1, x2), x1 + x2)(x, y) == x + y

    assert Lambda((x, y), x + y).nargs == FiniteSet(2)

    syms = x, y, z, t
    assert Lambda(syms, t * (x + y + z))(*syms) == t * (x + y + z)

    doubled = Lambda(x, 2 * x)
    assert doubled + Lambda(y, 2 * y) == 2 * doubled
    assert doubled not in [Lambda(x, x)]
    raises(BadSignatureError, lambda: Lambda(1, x))
    assert Lambda(x, 1)(1) is S.One

    # Malformed signatures are rejected
    bad_signatures = [((x, x), x + 2), (((x, x), y), x),
                      (((y, x), x), x), (((y, 1), 2), x)]
    for signature, body in bad_signatures:
        raises(BadSignatureError, lambda: Lambda(signature, body))

    with warns_deprecated_sympy():
        assert Lambda([x, y], x + y) == Lambda((x, y), x + y)

    # Nested (tuple-unpacking) signatures
    packed = Lambda(((x, y), ), x + y)
    assert packed((2, 3)) == 5
    packed = Lambda(((x, y), z), x + y + z)
    assert packed((2, 3), 1) == 6
    packed = Lambda((((x, y), z), ), x + y + z)
    assert packed(((2, 3), 1)) == 6
    raises(BadArgumentsError, lambda: packed(1, 2, 3))
    packed = Lambda((x, ), (x, x))
    assert packed(1, ) == (1, 1)
    assert packed((1, )) == ((1, ), (1, ))
    packed = Lambda(((x, ), ), (x, x))
    raises(BadArgumentsError, lambda: packed(1))
    assert packed((1, )) == (1, 1)

    # Previously TypeError was raised so this is potentially needed for
    # backwards compatibility.
    for exc in (BadSignatureError, BadArgumentsError):
        assert issubclass(exc, TypeError)

    # These are tested to see they don't raise:
    hash(Lambda(x, 2 * x))
    hash(Lambda(x, x))  # IdentityFunction subclass
Exemple #33
def load_coupling_expansions(obj, fn='gA', recompute=False):
    """Build the eps-expansions of pA/pB and dump or load coupling expansions.

    Steps performed:

    1. Set ``obj.pA['expand']`` / ``obj.pB['expand']`` to
       ``Sum(eps**i * pA[i], (i, 1, obj.miter)).doit()`` (resp. ``pB``).
    2. Reset the per-variable entries of ``obj.g`` / ``obj.i`` to empty lists.
    3. Check whether the cache files named by the ``*_fname`` entries exist.
    4. Dump or load those entries with ``dill``.
    5. Stack ``obj.i[key + '_epsA']`` / ``obj.i[key + '_epsB']`` into the
       ``obj.dim x 1`` matrices ``obj.i['vecA']`` / ``obj.i['vecB']``.

    NOTE(review): ``fn`` is not used in the visible body.
    """

    i = sym.symbols('i_sym')  # summation index
    #psi = obj.psi
    eps = obj.eps

    # for solution of isostables in terms of theta.
    obj.pA['expand'] = Sum(eps**i * Indexed('pA', i), (i, 1, obj.miter)).doit()
    obj.pB['expand'] = Sum(eps**i * Indexed('pB', i), (i, 1, obj.miter)).doit()

    #fname = obj.hodd['dat_fnames'][i]
    #file_does_not_exist = not(os.path.exists(fname))
    for key in obj.var_names:
        obj.g[key + '_epsA'] = []
        obj.g[key + '_epsB'] = []

        obj.i[key + '_epsA'] = []
        obj.i[key + '_epsB'] = []

        # NOTE(review): this existence check is nested inside the loop above,
        # and its inner loop reuses ``key``. The same ``val`` is recomputed
        # once per variable; it was probably meant to be dedented.
        # check that files exist
        val = 0
        for key in obj.var_names:
            val += not (os.path.isfile(obj.g[key + '_epsA_fname']))
            val += not (os.path.isfile(obj.g[key + '_epsB_fname']))
            val += not (os.path.isfile(obj.i[key + '_epsA_fname']))
            val += not (os.path.isfile(obj.i[key + '_epsB_fname']))


        # NOTE(review): an ``else:`` appears to have been lost here. As
        # written, ``files_do_not_exist`` always ends up False. When
        # ``obj.var_names`` is empty it is never bound at all, giving a
        # NameError below unless ``recompute`` is true.
        if val != 0:
            files_do_not_exist = True
            files_do_not_exist = False

    # NOTE(review): the dill.dump calls below are missing their trailing
    # arguments and closing parentheses (lost in extraction). The handles
    # from open() are never closed; prefer ``with open(...) as f``. As shown,
    # this would dump the empty lists set above, so a step that regenerates
    # the expansions may also be missing.
    if recompute or files_do_not_exist:

        for key in obj.var_names:
            # dump
            dill.dump(obj.g[key + '_epsA'],
                      open(obj.g[key + '_epsA_fname'], 'wb'),
            dill.dump(obj.g[key + '_epsB'],
                      open(obj.g[key + '_epsB_fname'], 'wb'),

            dill.dump(obj.i[key + '_epsA'],
                      open(obj.i[key + '_epsA_fname'], 'wb'),
            dill.dump(obj.i[key + '_epsB'],
                      open(obj.i[key + '_epsB_fname'], 'wb'),


        # NOTE(review): loading presumably belonged to an ``else:`` branch
        # (also lost). As indented here it only runs right after dumping, so
        # nothing is loaded when the files already exist and ``recompute`` is
        # false.
        for key in obj.var_names:
            obj.g[key + '_epsA'] = dill.load(
                open(obj.g[key + '_epsA_fname'], 'rb'))
            obj.g[key + '_epsB'] = dill.load(
                open(obj.g[key + '_epsB_fname'], 'rb'))

            obj.i[key + '_epsA'] = dill.load(
                open(obj.i[key + '_epsA_fname'], 'rb'))
            obj.i[key + '_epsB'] = dill.load(
                open(obj.i[key + '_epsB_fname'], 'rb'))

    # vector of i expansion
    obj.i['vecA'] = sym.zeros(obj.dim, 1)
    obj.i['vecB'] = sym.zeros(obj.dim, 1)

    for i, key in enumerate(obj.var_names):
        obj.i['vecA'][i] = obj.i[key + '_epsA']
        obj.i['vecB'][i] = obj.i[key + '_epsB']
