示例#1
0
    def visit_unary(self, node: AstUnary):
        """Simplify a unary-operator node.

        Rewrites tried, in this order:

        * ``+x`` is replaced by the simplified ``x``.
        * For ``not``: a non-chained comparison is rebuilt with
          ``item.neg_op`` (presumably the negated operator); ``and``/``or``
          is rewritten with De Morgan's laws; a boolean constant is folded.
        * Two directly nested applications of the same operator are dropped
          (e.g. ``-(-x)`` becomes ``x``).
        * ``-`` applied to a numeric constant is folded.

        Returns ``node`` itself when nothing changed, otherwise a new node.
        """
        op = node.op
        if op == '+':
            # Unary plus is dropped entirely.
            return self.visit(node.item)

        if op == 'not':
            # NOTE(review): ``_visit_expr`` is called on the operand node with
            # the visitor as argument (elsewhere it is ``self._visit_expr``);
            # confirm this is the intended API.
            item = node.item._visit_expr(self)
            # Only a non-chained comparison can be negated by swapping its
            # operator.
            if isinstance(item, AstCompare) and item.second_right is None:
                return self.visit(
                    _cl(AstCompare(item.left, item.neg_op, item.right), node))

            # De Morgan: not (a and b) -> (not a) or (not b), and vice versa.
            if isinstance(item, AstBinary) and item.op in ('and', 'or'):
                return self.visit(
                    _cl(
                        AstBinary(AstUnary('not', item.left),
                                  'and' if item.op == 'or' else 'or',
                                  AstUnary('not', item.right)), node))

            if is_boolean(item):
                return _cl(AstValue(not item.value), node)

        # The same operator applied twice is treated as cancelling out.
        # NOTE(review): for ``not`` this turns ``not not x`` into ``x`` rather
        # than ``bool(x)``; that is only equivalent when ``x`` is a boolean.
        if isinstance(node.item, AstUnary) and op == node.item.op:
            return self.visit(node.item.item)

        # NOTE(review): for ``not`` the operand has already been visited above
        # and is visited a second time here.
        item = self.visit(node.item)
        if is_number(item):
            if op == '-':
                return _cl(AstValue(-item.value), node)

        if item is node.item:
            return node
        else:
            return node.clone(item=item)
示例#2
0
    def visit_list_for(self, node:AstListFor):
        """Simplify a list comprehension ``[expr for target in source if test]``.

        Without a filter, the comprehension is unrolled into a vector literal
        when the length of ``source`` is known statically. That is the case for
        a vector literal, or for a ``SequenceType`` with a size. Otherwise, every
        variable changed by the element expression is locked. The comprehension
        is then rebuilt from its simplified test and element expression.
        """
        source = self.visit(node.source)
        source_is_vector = is_vector(source)
        length = None
        if source_is_vector:
            length = len(source)
        else:
            seq_type = self.get_type(source)
            if isinstance(seq_type, ppl_types.SequenceType):
                length = seq_type.size

        if node.test is None and length is not None:
            expr = node.expr
            if node.target == '_':
                # The loop variable is unused: either turn an unsized sample
                # into one sized sample, or repeat the expression.
                if isinstance(expr, AstSample) and expr.size is None:
                    return self.visit(expr.clone(size=AstValue(length)))
                return self.visit(_cl(makeVector([expr] * length), node))

            if source_is_vector:
                elements = [AstLet(node.target, element, expr,
                                   original_target=node.original_target)
                            for element in source]
            else:
                elements = [AstLet(node.target, makeSubscript(source, index), expr,
                                   original_target=node.original_target)
                            for index in range(length)]
            return self.visit(_cl(makeVector(elements), node))

        # The comprehension is kept: lock everything its element expression
        # changes before visiting it.
        for name in get_info(node.expr).changed_vars:
            self.lock_name(name)

        test = self.visit(node.test)
        expr = self.visit(node.expr)
        return _cl(AstListFor(node.target, source, expr, test,
                              original_target=node.original_target), node)
示例#3
0
    def visit_call_clojure_core_concat(self, node: AstCall):
        """Constant-fold ``clojure.core/concat`` over literal arguments.

        * All arguments are string constants: the strings are joined.
        * All arguments are value vectors: their values are concatenated into
          a single list constant.
        * All arguments are vectors of any kind: each is converted to an
          ``AstVector`` and their items are concatenated into one vector.

        Calls with keyword arguments, with no arguments, or that match none of
        the cases above are handled by the generic ``visit_call``.
        """
        import itertools
        if not node.has_keyword_args:
            args = [self.visit(arg) for arg in node.args]
            # With no arguments every ``all(...)`` below is vacuously true;
            # do not fold ``(concat)`` into an empty string.
            if len(args) > 0:
                if all(is_string(item) for item in args):
                    return _cl(AstValue(''.join(item.value for item in args)),
                               node)

                elif all(isinstance(item, AstValueVector) for item in args):
                    # ``chain.from_iterable`` flattens one level; a plain
                    # ``chain([...])`` would just yield the inner lists.
                    return _cl(
                        AstValue(list(itertools.chain.from_iterable(
                            item.value for item in args))),
                        node)

                elif all(is_vector(item) for item in args):
                    vectors = [
                        item if isinstance(item, AstVector) else item.to_vector()
                        for item in args
                    ]
                    # ``AstVector`` holds its elements in ``.items``.
                    return _cl(
                        makeVector(list(itertools.chain.from_iterable(
                            vec.items for vec in vectors))),
                        node)

        return self.visit_call(node)
示例#4
0
def clean_locals(ast, f_locals):
    """Drop unused local definitions from an (inlined) function body.

    For an ``AstBody``, removes every ``AstDef`` whose name is in ``f_locals``
    and is not a free variable of any item still present. Items are scanned
    front to back, and a removed item's free variables no longer count for
    later items. For an ``AstReturn``, the returned expression is cleaned
    recursively. Any other node, and any node from which nothing was removed,
    is returned unchanged.
    """
    if isinstance(ast, AstBody):
        kept = ast.items[:]
        kept_free = [get_info(item).free_vars for item in kept]
        pos = 0
        while pos < len(kept):
            stmt = kept[pos]
            removable = (isinstance(stmt, AstDef)
                         and stmt.name in f_locals
                         and not any(stmt.name in fv for fv in kept_free))
            if removable:
                kept.pop(pos)
                kept_free.pop(pos)
            else:
                pos += 1

        if len(kept) == len(ast.items):
            return ast
        return _cl(makeBody(kept), ast)

    if isinstance(ast, AstReturn):
        cleaned = clean_locals(ast.value, f_locals)
        return ast if cleaned is ast.value else _cl(AstReturn(cleaned), ast)

    return ast
示例#5
0
    def visit_return(self, node: AstReturn):
        """Simplify a ``return`` statement.

        If the simplified value is an ``AstBody``, its leading statements are
        kept in front and only its last item is returned. If it is an
        ``AstLet``, the ``return`` is moved into the ``let`` body, which is
        visited inside ``self.create_lock(value.target)``. Returns ``node`` when
        the value did not change.
        """
        value = self.visit(node.value)
        if isinstance(value, AstBody):
            items = value.items
            ret = self.visit(_cl(AstReturn(items[-1]), node))
            return _cl(makeBody(items[:-1], ret), value)
        elif isinstance(value, AstLet):
            with self.create_lock(value.target):
                ret = self.visit(_cl(AstReturn(value.body), node))
            # Carry ``original_target`` over so the rebuilt ``let`` keeps the
            # same metadata as the one it replaces.
            return _cl(AstLet(value.target, value.source, ret,
                              original_target=value.original_target), value)

        if value is not node.value:
            return _cl(AstReturn(value), node)
        else:
            return node
示例#6
0
    def visit_call_abs(self, node: AstCall):
        """Constant-fold ``abs(x)`` when ``x`` simplifies to a numeric constant.

        Non-numeric constants (strings, ``None``, ...) are not folded, so they
        do not make ``abs`` raise ``TypeError`` at compile time. Every
        non-folded call is handled by ``visit_call``.
        """
        if node.arg_count == 1 and not node.has_keyword_args:
            arg = self.visit_expr(node.args[0])
            if is_number(arg):
                return _cl(AstValue(abs(arg.value)), node)

        return self.visit_call(node)
示例#7
0
    def visit_call(self, node: AstCall):
        """Simplify a generic call, inlining a known function where possible.

        If the callee simplifies to an ``AstFunction`` and no argument has
        changed variables, the parameters are bound with ``define_all`` and the
        body is visited. Afterwards, unused locals are removed via
        ``clean_locals`` when ``f_locals`` is set. If the result then has
        exactly one ``return``, and it is either the whole result or the last
        statement of a body, the call is replaced by the returned expression.
        A missing value becomes ``AstValue(None)``, and the argument prefix plus
        any preceding statements come first.

        Calling an ``AstDict`` becomes a subscript. In every other case the call
        is rebuilt from the simplified callee and arguments, preceded by
        ``prefix``.

        Raises:
            TypeError: if an ``AstDict`` is called with a number of arguments
                other than one, or with keyword arguments.
        """
        function = self.visit(node.function)
        # ``prefix`` holds statements emitted in front of the call (see the
        # ``makeBody(prefix, ...)`` uses below); ``args`` are the arguments.
        prefix, args = self.parse_args(node.args)
        if isinstance(function, AstFunction) and all(
            [not get_info(arg).has_changed_vars for arg in args]):
            self.define_all(function.parameters, args, vararg=function.vararg)
            result = self.visit(function.body)
            if function.f_locals is not None:
                # Drop local definitions the inlined body no longer uses.
                result = clean_locals(result, function.f_locals)

            if get_info(result).return_count == 1:
                if isinstance(result, AstReturn):
                    result = result.value
                    result = result if result is not None else AstValue(None)
                    if len(prefix) > 0:
                        result = makeBody(prefix, result)
                    return result

                elif isinstance(result, AstBody) and result.last_is_return:
                    # Keep the statements before the final return and use the
                    # returned expression as the body's value.
                    items = prefix + result.items[:-1]
                    result = result.items[-1].value
                    result = result if result is not None else AstValue(None)
                    return makeBody(items, result)

            # NOTE(review): when inlining is abandoned here, the parameter
            # bindings made by ``define_all`` above are not undone; confirm
            # this is intended.

        elif isinstance(function, AstDict):
            # Calling a dict literal is treated as indexing it.
            if len(args) != 1 or node.has_keyword_args:
                raise TypeError(
                    "dict access requires exactly one argument ({} given)".
                    format(node.arg_count))
            return _cl(makeSubscript(function, args[0]), node)

        result = node.clone(function=function, args=args)
        return makeBody(prefix, result)
示例#8
0
 def visit_let(self, node: AstLet):
     """Visit a ``let`` binding, hoisting statements produced by its source.

     Returns ``node`` itself when neither the source nor the body changed.
     Otherwise returns a body made of the source's prefix statements followed
     by the rebuilt ``let``.
     """
     prefix, new_source = self._visit_expr(node.source)
     new_body = self.visit(node.body)
     unchanged = new_source is node.source and new_body is node.body
     if unchanged:
         return node
     rebuilt = AstLet(node.target, new_source, new_body, original_target=node.original_target)
     return _cl(makeBody(prefix, rebuilt), node)
示例#9
0
    def visit_import(self, node: AstImport):
        """Record the names bound by an import statement.

        With ``imported_names``, each name is bound to the predefined symbol
        ``<module>.<name>``. When ``node.alias`` is set, only the first imported
        name is bound, under the alias. Without ``imported_names``, the module
        is bound to an ``AstNamespace`` built from the names returned by
        ``namespace_from_module``. It is bound under the alias if one is given,
        otherwise under the module name. The statement is replaced by a plain
        import of the resolved module.
        """
        module_name, names = ppl_namespaces.namespace_from_module(
            node.module_name)
        if node.imported_names is not None:
            if node.alias is None:
                for name in node.imported_names:
                    self.define(
                        name,
                        AstSymbol("{}.{}".format(module_name, name),
                                  predef=True))
            else:
                self.define(
                    node.alias,
                    AstSymbol("{}.{}".format(module_name,
                                             node.imported_names[0]),
                              predef=True))

        else:
            bindings = {
                key: AstSymbol("{}.{}".format(module_name, key), predef=True)
                for key in names
            }
            ns = AstNamespace(module_name, bindings)
            # The emitted ``AstImport`` below carries no alias, so code written
            # against ``alias.name`` can only be resolved through this binding.
            bound_name = node.alias if node.alias is not None else node.module_name
            self.define(bound_name, ns)

        return _cl(AstImport(module_name), node)
示例#10
0
    def visit_call_math_sqrt(self, node: AstCall):
        """Constant-fold ``math.sqrt(x)`` for a non-negative numeric constant ``x``.

        Negative or non-numeric constants, and calls with keyword arguments,
        are not folded, because ``math.sqrt`` would raise at compile time.
        Those calls are handled by ``visit_call``.
        """
        if node.arg_count == 1 and not node.has_keyword_args:
            value = self.visit_expr(node.args[0])
            if is_number(value) and value.value >= 0:
                return _cl(AstValue(math.sqrt(value.value)), node)

        return self.visit_call(node)
示例#11
0
 def visit_observe(self, node: AstObserve):
     """Visit the distribution and the observed value of an ``observe``.

     Returns ``node`` itself if neither child changed, else a rebuilt node.
     """
     new_dist = self.visit(node.dist)
     new_value = self.visit(node.value)
     if new_dist is not node.dist or new_value is not node.value:
         return _cl(AstObserve(new_dist, new_value), node)
     return node
示例#12
0
 def visit_while(self, node: AstWhile):
     """Visit the condition and body of a ``while`` loop.

     Returns ``node`` itself if neither part changed, else a rebuilt loop.
     """
     new_test = self.visit(node.test)
     new_body = self.visit(node.body)
     if new_test is not node.test or new_body is not node.body:
         return _cl(AstWhile(new_test, new_body), node)
     return node
示例#13
0
 def visit_binary(self, node: AstBinary):
     """Visit both operands of a binary expression.

     Returns ``node`` itself if neither operand changed, else a rebuilt node
     with the same operator.
     """
     lhs = self.visit(node.left)
     rhs = self.visit(node.right)
     unchanged = lhs is node.left and rhs is node.right
     return node if unchanged else _cl(AstBinary(lhs, node.op, rhs), node)
示例#14
0
 def visit_compare(self, node: AstCompare):
     """Visit every operand of a comparison, including a chained second part.

     Returns ``node`` itself if no operand changed. Otherwise returns a rebuilt
     comparison that keeps ``op``, ``second_op`` and the visited
     ``second_right``.
     """
     left = self.visit(node.left)
     right = self.visit(node.right)
     # ``second_right`` is ``None`` for a non-chained comparison.
     second_right = self.visit(node.second_right)
     if left is node.left and right is node.right and second_right is node.second_right:
         return node
     else:
         # Pass the second operator/operand through so that chained
         # comparisons such as ``a < b < c`` keep their second part.
         return _cl(AstCompare(left, node.op, right, node.second_op, second_right), node)
示例#15
0
 def visit_vector(self, node: AstVector):
     """Visit every item of a vector literal.

     A non-empty vector whose items are all unsized samples from equal
     distributions is merged into one ``AstSample`` with ``size`` set to the
     item count. Otherwise a vector of the visited items is returned.
     """
     items = [self.visit(item) for item in node.items]
     if not items:
         return makeVector(items)
     first = items[0]
     mergeable = all(isinstance(x, AstSample) and x.size is None for x in items) and \
         all(x.dist == first.dist for x in items)
     if not mergeable:
         return makeVector(items)
     return _cl(AstSample(first.dist, size=AstValue(len(items))), node)
示例#16
0
 def visit_sample(self, node: AstSample):
     """Visit the distribution and size of a ``sample``.

     Returns ``node`` itself if neither changed, else a rebuilt sample.
     """
     new_dist = self.visit(node.dist)
     new_size = self.visit(node.size)
     if new_dist is node.dist and new_size is node.size:
         return node
     return _cl(AstSample(new_dist, size=new_size), node)
示例#17
0
 def visit_function(self, node:AstFunction):
     """Visit a function body inside a fresh lock on which ``lock_all`` is called.

     Returns ``node`` itself if the body did not change, else a rebuilt
     ``AstFunction`` with the same name, parameters, vararg, doc string and
     ``f_locals``.
     """
     with self.create_lock():
         self.lock_all()
         new_body = self.visit(node.body)
         if new_body is node.body:
             result = node
         else:
             result = _cl(AstFunction(node.name, node.parameters, new_body, vararg=node.vararg,
                                      doc_string=node.doc_string, f_locals=node.f_locals), node)
     return result
示例#18
0
 def visit_subscript(self, node: AstSubscript):
     """Visit ``base[index]``, hoisting statements produced by either part.

     Returns ``node`` itself if neither part changed. Otherwise returns a body
     of the base prefix, then the index prefix, then the rebuilt subscript.
     """
     base_prefix, new_base = self._visit_expr(node.base)
     index_prefix, new_index = self._visit_expr(node.index)
     if new_base is node.base and new_index is node.index:
         return node
     rebuilt = makeSubscript(new_base, new_index)
     return _cl(makeBody(base_prefix + index_prefix, rebuilt), node)
示例#19
0
 def visit_for(self, node: AstFor):
     """Visit a ``for`` loop, hoisting statements produced by its source.

     A loop variable that is not a free variable of the visited body is
     renamed to ``'_'``. Returns ``node`` itself if nothing changed.
     """
     prefix, new_source = self._visit_expr(node.source)
     new_body = self.visit(node.body)
     if node.target in get_info(new_body).free_vars:
         new_target = node.target
     else:
         new_target = '_'
     if new_target is node.target and new_source is node.source and new_body is node.body:
         return node
     loop = AstFor(new_target, new_source, new_body, original_target=node.original_target)
     return _cl(makeBody(prefix, loop), node)
示例#20
0
    def visit_for(self, node:AstFor):
        """Simplify a ``for`` loop, unrolling it when the iteration count is known.

        A loop over a vector literal, or over a ``SequenceType`` with a known
        size, becomes a body of ``AstLet`` nodes (one per element), which is
        then simplified. Otherwise every variable changed by the loop body is
        locked, and the loop is rebuilt with the simplified source and body.
        """
        source = self.visit(node.source)
        if is_vector(source):
            # Pass ``original_target`` here as well, matching the
            # sequence-type branch below.
            result = makeBody([AstLet(node.target, item, node.body,
                                      original_target=node.original_target)
                               for item in source])
            return self.visit(_cl(result, node))
        else:
            src_type = self.get_type(source)
            if isinstance(src_type, ppl_types.SequenceType) and src_type.size is not None:
                result = makeBody([
                             AstLet(node.target, makeSubscript(source, i), node.body,
                                    original_target=node.original_target) for i in range(src_type.size)
                         ])
                return self.visit(_cl(result, node))

        for name in get_info(node.body).changed_vars:
            self.lock_name(name)
        body = self.visit(node.body)
        return node.clone(source=source, body=body)
示例#21
0
    def visit_binary(self, node:AstBinary):
        """Visit both operands, hoisting any statements they produce.

        Returns ``node`` itself if neither operand changed. Otherwise returns a
        body of the left and right prefix statements ending with the rebuilt
        binary expression.
        """
        left_prefix, new_left = self._visit_expr(node.left)
        right_prefix, new_right = self._visit_expr(node.right)
        if new_left is node.left and new_right is node.right:
            return node
        statements = left_prefix + right_prefix + [AstBinary(new_left, node.op, new_right)]
        return _cl(makeBody(statements), node)
示例#22
0
    def visit_let(self, node:AstLet):
        """Simplify ``let target = source in body``.

        * ``target`` unused in ``body``: the binding is dropped and ``source``
          is kept as a statement in front of ``body``.
        * Simplified source is a body with several items: its leading
          statements are hoisted in front of the ``let``.
        * Source independent of the body (``is_independent``), and the target
          used once or the source embeddable (``can_embed``): ``target`` is
          bound via ``self.define`` and the simplified body replaces the
          ``let``.
        * Otherwise the ``let`` becomes an ``AstDef`` followed by the body.
        """
        # ``node.body`` is not modified below, so the usage count is computed once.
        usage_count = count_variable_usage(node.target, node.body)
        if usage_count == 0:
            return self.visit(_cl(makeBody(node.source, node.body), node))

        source = self.visit_expr(node.source)
        src_info = get_info(source)
        if isinstance(source, AstBody) and len(source) > 1:
            result = node.clone(source=source.items[-1])
            result = _cl(makeBody(source.items[:-1], result), node.source)
            return self.visit(result)

        elif src_info.is_independent(get_info(node.body)) and \
                (usage_count == 1 or src_info.can_embed):
            self.define(node.target, self.visit(node.source))
            return _cl(self.visit(node.body), node)

        return self.visit(makeBody(AstDef(node.target, node.source), node.body))
示例#23
0
    def visit_slice(self, node:AstSlice):
        """Simplify ``base[start:stop]``.

        If ``base`` is a vector literal and both bounds are integer constants
        or omitted, the slice is evaluated directly to a new vector. Otherwise
        an ``AstSlice`` is rebuilt from the simplified parts.
        """
        base = self.visit(node.base)
        start = self.visit(node.start)
        stop = self.visit(node.stop)

        bounds_constant = (start is None or is_integer(start)) and \
            (stop is None or is_integer(stop))
        if bounds_constant and isinstance(base, (AstValueVector, AstVector)):
            lo = None if start is None else start.value
            hi = None if stop is None else stop.value
            # A ``None`` slice bound behaves exactly like an omitted one.
            return _cl(makeVector(base.items[lo:hi]), node)

        return _cl(AstSlice(base, start, stop), node)
示例#24
0
 def visit_slice(self, node: AstSlice):
     """Visit ``base[start:stop]``, hoisting statements produced by its parts.

     Returns ``node`` itself if no part changed. Otherwise returns a body of
     the base, start and stop prefixes followed by the rebuilt slice.
     """
     prefix, new_base = self._visit_expr(node.base)
     start_prefix, new_start = self._visit_expr(node.start)
     stop_prefix, new_stop = self._visit_expr(node.stop)
     prefix.extend(start_prefix)
     prefix.extend(stop_prefix)
     if new_base is node.base and new_start is node.start and new_stop is node.stop:
         return node
     return _cl(makeBody(prefix, AstSlice(new_base, new_start, new_stop)), node)
示例#25
0
    def visit_call_range(self, node: AstCall):
        """Constant-fold ``range(...)`` with one to three integer constant arguments.

        The result is an ``AstValueVector`` of the produced integers. These
        calls are left to ``visit_call``:

        * calls with keyword arguments;
        * calls with non-constant arguments;
        * calls with a zero step, for which ``range`` would raise ``ValueError``.
        """
        if not node.has_keyword_args:
            args = [self.visit(arg) for arg in node.args]
            if 1 <= len(args) <= 3 and all(is_integer(arg) for arg in args):
                bounds = [arg.value for arg in args]
                if len(bounds) < 3 or bounds[2] != 0:
                    return _cl(AstValueVector(list(range(*bounds))), node)

        return self.visit_call(node)
示例#26
0
 def visit_unary(self, node: AstUnary):
     """Visit a unary expression, hoisting any statements its operand produces.

     Two directly nested ``+`` or two directly nested ``-`` operators cancel
     and are removed. Returns ``node`` itself if the operand did not change.
     """
     # ``+(+x)`` and ``-(-x)`` are both ``x``. ``not not x`` is not cancelled:
     # it evaluates to ``bool(x)``, which differs from ``x`` unless ``x`` is
     # already a boolean.
     if isinstance(node.item, AstUnary) and node.op == node.item.op:
         if node.op in ('+', '-'):
             return self.visit(node.item.item)
     prefix, item = self._visit_expr(node.item)
     if item is node.item:
         return node
     else:
         prefix.append(AstUnary(node.op, item))
         return _cl(makeBody(prefix), node)
 def visit_in_scope(self, node: AstNode, is_loop: bool = False):
     """Visit ``node`` between ``begin_scope`` and ``end_scope``.

     The list that collects the visited statements is handed to
     ``begin_scope`` before it is filled, and is appended to in place. The
     items of an ``AstBody`` are visited one by one; any other node is visited
     as a single statement.

     Returns ``(symbols, body)``, where ``symbols`` is the value returned by
     ``end_scope`` and ``body`` wraps the visited statements.
     """
     collected = []
     self.begin_scope(collected, is_loop)
     children = node.items if isinstance(node, AstBody) else [node]
     for child in children:
         collected.append(self.visit(child))
     body = _cl(makeBody(collected), node)
     return self.end_scope(), body
示例#28
0
 def visit_dict(self, node: AstDict):
     """Visit every value of a dict literal, hoisting statements they produce.

     Returns ``node`` itself for an empty dict or when no value changed.
     Otherwise returns a body of all prefix statements followed by the
     rebuilt dict.
     """
     if len(node) > 0:
         prefix = []
         result = {}
         changed = False
         for key in node.items:
             value = node.items[key]
             # ``_visit_expr`` returns ``(prefix, expr)``; ``visit`` returns
             # only the node and cannot be unpacked this way.
             p, item = self._visit_expr(value)
             prefix += p
             result[key] = item
             changed = changed or item is not value
         if not changed:
             return node
         return _cl(makeBody(prefix, AstDict(result)), node)
     else:
         return node
示例#29
0
 def visit_vector(self, node: AstVector):
     """Visit every item of a vector literal.

     A non-empty vector whose items are all unsized samples from equal
     distributions is merged into one ``AstSample`` with ``size`` set to the
     item count. The node's ``original_name`` is copied onto it when present.
     Otherwise a vector of the visited items is returned.
     """
     items = [self.visit(item) for item in node.items]
     if not items:
         return makeVector(items)
     head = items[0]
     if not all(isinstance(x, AstSample) and x.size is None for x in items):
         return makeVector(items)
     if not all(x.dist == head.dist for x in items):
         return makeVector(items)
     merged = _cl(AstSample(head.dist, size=AstValue(len(items))), node)
     name = getattr(node, 'original_name', None)
     if name is not None:
         merged.original_name = name
     return merged
示例#30
0
    def visit_compare(self, node: AstCompare):
        """Simplify a (possibly chained) comparison.

        For a non-chained comparison:

        * Unary negations are moved across the operator: ``-a < -b`` becomes
          ``b < a``, ``-a < c`` becomes ``-c < a`` and ``c < -a`` becomes
          ``a < -c``.
        * A numeric constant opposite an addition/subtraction
          (``is_binary_add_sub``) is moved over, e.g. ``x + 1 < 3`` becomes
          ``(x + 1) - 3 < 0``, which is then simplified.

        Comparisons of numeric constants are folded, including a chained
        comparison whose ``second_right`` is numeric as well. ``in`` /
        ``not in`` against a vector literal is folded only when the answer can
        be decided statically.
        """
        left = self.visit(node.left)
        right = self.visit(node.right)
        second_right = self.visit(node.second_right)

        if second_right is None:
            if is_unary_neg(left) and is_unary_neg(right):
                left, right = right.item, left.item
            elif is_unary_neg(left) and is_number(right):
                left, right = AstValue(-right.value), left.item
            elif is_number(left) and is_unary_neg(right):
                right, left = AstValue(-left.value), right.item

            if is_binary_add_sub(left) and is_number(right):
                left = self.visit(AstBinary(left, '-', right))
                right = AstValue(0)
            elif is_binary_add_sub(right) and is_number(left):
                right = self.visit(AstBinary(right, '-', left))
                left = AstValue(0)

        if is_number(left) and is_number(right):
            result = node.op_function(left.value, right.value)
            if second_right is None:
                return _cl(AstValue(result), node)

            elif is_number(second_right):
                result = result and node.op_function_2(right.value,
                                                       second_right.value)
                return _cl(AstValue(result), node)

        if node.op in ('in',
                       'not in') and is_vector(right) and second_right is None:
            elements = list(right)
            if isinstance(left, AstValue) and \
                    all(isinstance(item, AstValue) for item in elements):
                # Every operand is a constant, so membership is decidable.
                is_member = any(left.value == item.value for item in elements)
            elif any(left == item for item in elements):
                # An element equal to ``left`` makes membership true.
                is_member = True
            else:
                # ``left`` may still equal an element at run time (e.g. a
                # symbol tested against constants), so nothing is folded.
                is_member = None
            if is_member is not None:
                return _cl(AstValue(is_member if node.op == 'in' else not is_member), node)

        return _cl(
            AstCompare(left, node.op, right, node.second_op, second_right),
            node)