def extract_free_variables(e):
    """Return the set of free variables that expression ``e`` depends on.

    - A ``Symbol`` is its own free variable.
    - A function application (matched via ``m.type(FunctionClass)``)
      contributes its function class ``type(e)`` together with the free
      variables of its arguments, so ``f(x)`` yields ``{f, x}``.
    - A ``Sum`` contributes the free variables of its summand and of the
      bounds of every limit, minus the summation index of each limit.
    - Anything else contributes the union over its ``args``; expressions
      with no args contribute nothing.

    A new set is built on every call, so callers may mutate the result.
    """
    m = Match(e)

    if m(Sum):
        free_vars = extract_free_variables(e.function)
        # Handle every limit, not just the first, so that all indices of a
        # multi-index Sum are treated as bound. Within one limit the bounds
        # are added before that limit's index is removed. Processing limits
        # in stored order means an index that appears in an earlier limit's
        # bound is removed once its own limit is reached.
        for limit in e.limits:
            index = limit[0]
            for bound in limit[1:]:
                free_vars.update(extract_free_variables(bound))
            free_vars.discard(index)
        return free_vars

    if m(Symbol):
        return set([e])

    if m.type(FunctionClass):
        # The function itself is a dependency, and so is anything its
        # arguments depend on (e.g. the ``x`` in ``f(x)``).
        free_vars = set([type(e)])
        for arg in e.args:
            free_vars.update(extract_free_variables(arg))
        return free_vars

    # Default: union over the sub-expressions; empty when there are none.
    free_vars = set()
    for arg in e.args:
        free_vars.update(extract_free_variables(arg))
    return free_vars
示例#2
0
def extract_free_variables(e):
    """Return the set of free variables that expression ``e`` depends on.

    - A ``Symbol`` is its own free variable.
    - A function application (matched via ``m.type(FunctionClass)``)
      contributes its function class ``type(e)`` together with the free
      variables of its arguments, so ``f(x)`` yields ``{f, x}``.
    - A ``Sum`` contributes the free variables of its summand and of the
      bounds of every limit, minus the summation index of each limit.
    - Anything else contributes the union over its ``args``; expressions
      with no args contribute nothing.

    A new set is built on every call, so callers may mutate the result.
    """
    m = Match(e)

    if m(Sum):
        free_vars = extract_free_variables(e.function)
        # Handle every limit, not just the first, so that all indices of a
        # multi-index Sum are treated as bound. Within one limit the bounds
        # are added before that limit's index is removed. Processing limits
        # in stored order means an index that appears in an earlier limit's
        # bound is removed once its own limit is reached.
        for limit in e.limits:
            index = limit[0]
            for bound in limit[1:]:
                free_vars.update(extract_free_variables(bound))
            free_vars.discard(index)
        return free_vars

    if m(Symbol):
        return set([e])

    if m.type(FunctionClass):
        # The function itself is a dependency, and so is anything its
        # arguments depend on (e.g. the ``x`` in ``f(x)``).
        free_vars = set([type(e)])
        for arg in e.args:
            free_vars.update(extract_free_variables(arg))
        return free_vars

    # Default: union over the sub-expressions; empty when there are none.
    free_vars = set()
    for arg in e.args:
        free_vars.update(extract_free_variables(arg))
    return free_vars
示例#3
0
    def __call__(self, e):
        """Replace each Integral inside ``e`` with a fresh placeholder symbol.

        Placeholders are named ``'val' + str(self.idx)``; ``self.idx`` is
        advanced once per Integral replaced. Returns a pair
        ``(integrals, new_expr)``:

        - ``integrals`` is a list of ``(placeholder_name, integral)`` tuples
          in left-to-right order.
        - ``new_expr`` is ``e`` rebuilt with the placeholders substituted.

        Integrals nested inside another Integral are not searched.
        """
        if Match(e)(Integral):
            placeholder = 'val' + str(self.idx)
            replaced = ([(placeholder, e)], Symbol(placeholder))
            self.idx += 1
            return replaced

        # Leaves (no args) are returned unchanged.
        if not e.args:
            return ([], e)

        found = []
        rebuilt_args = []
        for child in e.args:
            child_ints, child_expr = self(child)
            found.extend(child_ints)
            rebuilt_args.append(child_expr)
        return (found, e.func(*rebuilt_args))
示例#4
0
def extract_sum(e):
    """Pull a Sum out of ``e``, replacing it with the symbol ``total``.

    Returns ``(found_sum, new_expr)``:

    - ``found_sum`` is the Sum that was removed, or ``None`` if there is none.
    - ``new_expr`` is ``e`` rebuilt with ``Symbol('total')`` in place of the
      Sum.

    Every Sum encountered is replaced by ``total``, but if several occur only
    the last one found (scanning arguments left to right) is returned. A Sum
    nested inside another Sum is not searched.
    """
    if Match(e)(Sum):
        return (e, Symbol('total'))

    # Leaves (no args) cannot contain a Sum.
    if not e.args:
        return (None, e)

    found = None
    rebuilt_args = []
    for child in e.args:
        candidate, replacement = extract_sum(child)
        if candidate:
            found = candidate
        rebuilt_args.append(replacement)
    return (found, e.func(*rebuilt_args))
示例#5
0
def find_func(func, e):
    """Return every sub-expression of ``e`` that matches ``func``.

    Results are in pre-order (left to right). Once a node matches, its own
    sub-expressions are not searched further.
    """
    if Match(e)(func):
        return [e]
    return [hit for child in e.args for hit in find_func(func, child)]
示例#6
0
def sum_to_c(e, **kw):
    """Lower a ``Sum`` expression into a C ``for`` loop accumulating into ``total``.

    For ``Sum(body, (i, lo, hi))`` this builds (judging by the ``c_*``
    helper names) roughly::

        for (int i = lo; i < hi + 1; ++i) total += body;

    Each sub-expression is converted with ``expr_to_c(**kw)``. Returns
    ``None`` when ``e`` does not match ``Sum``.

    NOTE(review): no initialisation of ``total`` is emitted here; presumably
    the caller declares and zeroes it — confirm.
    """
    v = AutoVar()
    m = Match(e)
    to_c = expr_to_c(**kw)
    if not m(Sum, v.e1, v.e2):
        return None

    limits = v.e2
    start = to_c(limits[1])
    # Testing against hi + 1 keeps the Sum's upper bound inclusive.
    stop = to_c((limits[2] + 1))
    body = c_assign_plus(c_var('total'), to_c(v.e1))

    index_name = str(limits[0])
    counter = c_var(index_name)
    declared_counter = c_int(index_name)
    return c_for(c_assign(declared_counter, start),
                 c_less_than(counter, stop),
                 c_pre_incr(counter),
                 body)
示例#7
0
    def __call__(self, e):
        """Translate expression ``e`` into c_* code-generation objects.

        Cases are tried in order and the first one that matches returns. The
        specific shapes (subtraction, reciprocal, division) must therefore
        stay ahead of the general Add/Mul/Pow cases, and Integer must stay
        ahead of the general rational case.

        Returns ``None`` in two cases:

        - an ``Indexed`` whose base name has no entry in
          ``self._index_trans``;
        - any expression no case handles (after printing a ``no match``
          diagnostic).
        """
        v = AutoVar()
        m = Match(e)

        # subtraction: a + (-1)*b  ->  a - b
        if m(Add, v.e1, (Mul, S.NegativeOne, v.e2)):
            return c_sub(self(v.e1), self(v.e2))

        if m(Add, v.e1, v.e2):
            return c_add(self(v.e1), self(v.e2))

        # reciprocal: b**-1  ->  1.0 / b  (the 1.0 literal presumably keeps
        # the C division floating-point)
        if m(Pow, v.e2, S.NegativeOne):
            return c_expr(c_expr.C_OP_DIVIDE, c_num(1.0), self(v.e2))

        # division: a * b**-1  ->  a / b
        if m(Mul, v.e1, (Pow, v.e2, S.NegativeOne)):
            return c_div(self(v.e1), self(v.e2))

        if m(Mul, v.e1, v.e2):
            return c_expr(c_expr.C_OP_TIMES, self(v.e1), self(v.e2))

        if m(exp, v.e1):
            return c_function_call('exp', self(v.e1))

        if m(Pow, v.e1, v.e2):
            return c_function_call('pow', self(v.e1), self(v.e2))

        # Indexed access: translate through self._index_trans when the base
        # name is registered there; otherwise give up.
        if m(Indexed, (IndexedBase, v.e1), v.e2):
            if str(v.e1) in self._index_trans:
                idx_var, idx_expr = self._index_trans[str(v.e1)]
                return self(idx_expr.subs(idx_var, v.e2))
            return None

        # Generic function application. Emit the translated arguments,
        # followed by the extra arguments named in self._extra_args. The
        # function name may be renamed through self._func_trans.
        if m.type(FunctionClass):
            args = [self(a) for a in e.args]
            args.extend([c_var(a) for a in self._extra_args])
            name = str(type(e))
            if name in self._func_trans:
                name = self._func_trans[name]
            return c_function_call(name, *args)

        # Symbols become plain C variable references.
        if m(Symbol):
            return c_var(str(e))

        if m(Integer):
            return c_num(e.p)

        # Any non-integer rational p/q becomes a floating-point division;
        # for S.Half this yields c_div(c_num(1.0), c_num(2.0)). Integer also
        # has is_Rational set, so this case must follow the Integer case.
        if getattr(e, 'is_Rational', False):
            return c_div(c_num(float(e.p)), c_num(float(e.q)))

        if m(Float):
            return c_num(e.num)

        # A parenthesised single string prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print('no match %s' % (type(e),))
        return None
    def __call__(self, e):
        """Translate expression ``e`` into py_* code-generation objects.

        Cases are tried in order and the first one that matches returns.
        The specific shapes (subtraction, reciprocal, division, negation)
        therefore sit ahead of the general Add/Mul/Pow cases, and Integer
        sits ahead of Rational.

        A matched ``Sum`` is lowered into statements appended to
        ``self._pre_statements`` and is replaced by ``py_var('total')``.
        A matched ``Integral`` is delegated to ``self.convert_integral``.
        """
        v = AutoVar()
        m = Match(e)

        # subtraction: (-1)*b + a  ->  a - b
        if m(Add, (Mul, S.NegativeOne, v.e1), v.e2):
            return py_expr(py_expr.PY_OP_MINUS, self(v.e2), self(v.e1))


        if m(Add, v.e1, v.e2):
            return py_expr(py_expr.PY_OP_PLUS, self(v.e1), self(v.e2))

        # reciprocal: b**-1  ->  1.0 / b
        if m(Pow, v.e2, S.NegativeOne):
            return py_expr(py_expr.PY_OP_DIVIDE, py_num(1.0), self(v.e2))

        # division: a * b**-1  ->  a / b
        if m(Mul, v.e1, (Pow, v.e2, S.NegativeOne)):
            return py_expr(py_expr.PY_OP_DIVIDE, self(v.e1), self(v.e2))

        # negation: (-1)*a -> PY_OP_MINUS with a single operand (presumably
        # rendered as unary minus — confirm in py_expr)
        if m(Mul, S.NegativeOne, v.e1):
            return py_expr(py_expr.PY_OP_MINUS, self(v.e1))

        if m(Mul, v.e1, v.e2):
            return py_expr(py_expr.PY_OP_TIMES, self(v.e1), self(v.e2))

        if m(exp, v.e1):
            return py_function_call('math.exp', self(v.e1))

        if m(numbers.Pi):
            return py_var("math.pi")

        # Indexed access: if the base name is registered in
        # self._index_trans, substitute the index into the stored expression
        # and translate that; otherwise return None.
        if m(Indexed, (IndexedBase, v.e1), v.e2):
            if str(v.e1) in self._index_trans:
                idx_var,idx_expr = self._index_trans[str(v.e1)]
                ex = self(idx_expr.subs(idx_var,v.e2))
                return ex
            return None


        # Open question: could the pattern be written with operator syntax
        # instead, e.g. m(Pow, v.e1 ** v.e2)?
        #if m(Pow, v.e1 ** v.e2):
        if m(Pow, v.e1, v.e2):
            return py_expr(py_expr.PY_OP_POW, self(v.e1), self(v.e2))

        # Function call: translated arguments followed by the extra
        # arguments named in self._extra_args. The function name may be
        # renamed through self._func_trans.
        if m.type(FunctionClass):
            args = [self(a) for a in e.args]
            args.extend([py_var(a) for a in self._extra_args])
            name = str(type(e))
            if name in self._func_trans:
                name = self._func_trans[name]
            return py_function_call(name, *args)


        if m(Integral):
            return self.convert_integral(e)

        # Sum(body, (i, lo, hi)) is lowered into two pre-statements:
        #   - an initialisation of 'total' with 0 (PY_ASSIGN_EQUAL);
        #   - a for-loop over range(lo, hi + 1), so the upper bound is
        #     inclusive, whose body updates 'total' with the summand
        #     (PY_ASSIGN_PLUS, presumably '+=').
        # NOTE(review): every Sum uses the fixed accumulator name 'total', so
        # two Sums translated into the same pre-statement list would share
        # it — confirm this is acceptable.
        if m(Sum, v.e1, v.e2):
            init = py_assign_stmt(py_var('total'),py_num(0),py_assign_stmt.PY_ASSIGN_EQUAL)

            lower_limit = self(v.e2[1])
            upper_limit = self((v.e2[2]+1))
            body = py_assign_stmt(py_var('total'),self(v.e1),py_assign_stmt.PY_ASSIGN_PLUS)
            loop =  py_for(self(v.e2[0]),py_function_call('range',lower_limit,upper_limit),body)
            self._pre_statements.append(init)
            self._pre_statements.append(loop)
            return py_var('total')

        if m(Symbol):
            return py_var(str(e))

        if m(Integer):
            return py_num(e.p)

        if m(Float):
            return py_num(e.num,promote_to_fp=True)

        # Rational n/d is emitted as a division. The second positional
        # argument to py_num for the denominator is presumably promote_to_fp
        # (as used for Float above), forcing float division — confirm.
        # Open question: could this use a structured pattern instead, e.g.
        # m(Rational, m.numerator(v.e1), m.denominator(v.e2))
        # m(Rational, v.e1 / v.e2)
        if m(Rational):
            (n,d) = e.as_numer_denom()
            return py_expr(py_expr.PY_OP_DIVIDE, py_num(n), py_num(d, True))