def visit_unary(self, node: AstUnary):
    """Simplify a unary operation (`+`, `-` or `not`).

    - `+x` is replaced by the visited operand.
    - `not` is pushed into a non-chained comparison (using the comparison's
      `neg_op`, presumably the negated operator), distributed over `and`/`or`
      (De Morgan), or folded when the operand is a boolean constant.
    - Two nested identical operators are collapsed to the inner operand.
    - `-` applied to a numeric constant is folded.
    Returns `node` itself when the visited operand is unchanged.
    """
    op = node.op
    if op == '+':
        # Unary plus contributes nothing: keep only the visited operand.
        return self.visit(node.item)
    if op == 'not':
        # NOTE(review): `_visit_expr` is called on the AST node with the
        # visitor as argument (not `self._visit_expr(...)`); confirm the node
        # class provides this method.
        item = node.item._visit_expr(self)
        if isinstance(item, AstCompare) and item.second_right is None:
            # not (a OP b)  ->  a NEG_OP b, only for non-chained comparisons.
            return self.visit(
                _cl(AstCompare(item.left, item.neg_op, item.right), node))
        if isinstance(item, AstBinary) and item.op in ('and', 'or'):
            # De Morgan: not (a and b) -> (not a) or (not b), and vice versa.
            return self.visit(
                _cl(
                    AstBinary(AstUnary('not', item.left),
                              'and' if item.op == 'or' else 'or',
                              AstUnary('not', item.right)), node))
        if is_boolean(item):
            return _cl(AstValue(not item.value), node)
    # NOTE(review): for `not`, this turns `not not x` into `x`, which only
    # matches Python semantics when `x` is already a bool; confirm intent.
    if isinstance(node.item, AstUnary) and op == node.item.op:
        return self.visit(node.item.item)
    # NOTE(review): for `not`, the operand was already processed through
    # `_visit_expr` above; it is visited again here.
    item = self.visit(node.item)
    if is_number(item):
        if op == '-':
            return _cl(AstValue(-item.value), node)
    if item is node.item:
        return node
    else:
        return node.clone(item=item)
def visit_list_for(self, node: AstListFor):
    """Simplify a list comprehension, unrolling it when its length is known.

    The source is visited first. Its length is known when it is a vector
    (via `len`) or when its static type is a `SequenceType` (the type's
    `size`, which may itself be None).

    Without a filter (`test is None`):
    - a `_` target with a known length yields either one sized `AstSample`
      (when the expression is an unsized sample) or a vector repeating the
      expression;
    - otherwise a vector source, or a known length, is unrolled into a vector
      of `AstLet`s binding the target to each element / subscript.

    In every other case the names changed by the expression are locked and a
    new `AstListFor` is built from the visited source, expression and test.
    """
    source = self.visit(node.source)
    if is_vector(source):
        length = len(source)
    else:
        seq_type = self.get_type(source)
        length = seq_type.size if isinstance(seq_type, ppl_types.SequenceType) else None

    if node.test is None:
        if node.target == '_' and length is not None:
            template = node.expr
            if isinstance(template, AstSample) and template.size is None:
                # N independent draws from the same sample become one sized draw.
                return self.visit(template.clone(size=AstValue(length)))
            repeated = makeVector([template for _ in range(length)])
            return self.visit(_cl(repeated, node))
        if is_vector(source):
            unrolled = [AstLet(node.target, element, node.expr,
                               original_target=node.original_target)
                        for element in source]
            return self.visit(_cl(makeVector(unrolled), node))
        if length is not None:
            unrolled = [AstLet(node.target, makeSubscript(source, idx), node.expr,
                               original_target=node.original_target)
                        for idx in range(length)]
            return self.visit(_cl(makeVector(unrolled), node))

    # The comprehension stays: names it modifies must not be substituted.
    for name in get_info(node.expr).changed_vars:
        self.lock_name(name)
    test = self.visit(node.test)
    expr = self.visit(node.expr)
    return _cl(AstListFor(node.target, source, expr, test,
                          original_target=node.original_target), node)
def visit_call_clojure_core_concat(self, node: AstCall):
    """Constant-fold `clojure.core/concat` when every argument is known.

    Without keyword arguments, the arguments are visited and then:
    - if all are strings, they are joined into a single string value;
    - if all are `AstValueVector`s, their values are concatenated into a
      single list value;
    - if all are vectors, each is converted to an `AstVector` and their
      items are concatenated into a single vector.
    Anything else falls back to the generic `visit_call`.
    """
    import itertools
    if not node.has_keyword_args:
        args = [self.visit(arg) for arg in node.args]
        if all(is_string(item) for item in args):
            return _cl(AstValue(''.join(item.value for item in args)), node)
        elif all(isinstance(item, AstValueVector) for item in args):
            # chain.from_iterable flattens one level; plain chain([...]) would
            # only yield the per-argument lists, not their elements.
            return _cl(
                AstValue(list(itertools.chain.from_iterable(
                    item.value for item in args))), node)
        elif all(is_vector(item) for item in args):
            vectors = [
                item if isinstance(item, AstVector) else item.to_vector()
                for item in args
            ]
            # AstVector keeps its elements in `.items`; concatenate those.
            return _cl(
                makeVector(list(itertools.chain.from_iterable(
                    vec.items for vec in vectors))), node)
    return self.visit_call(node)
def clean_locals(ast, f_locals):
    """Drop dead local definitions from an inlined function body.

    For an `AstBody`, each `AstDef` whose name is in `f_locals` and does not
    occur among the free variables of any statement still in the body is
    removed. Statements are examined front to back, and removed statements no
    longer count as users of later definitions. An `AstReturn` is cleaned
    through its value. Anything else, or a body from which nothing was
    removed, is returned as the very same object.
    """
    if isinstance(ast, AstBody):
        # Keep each statement paired with its free variables so that a
        # removal drops both together.
        entries = [(stmt, get_info(stmt).free_vars) for stmt in ast.items]
        pos = 0
        while pos < len(entries):
            stmt = entries[pos][0]
            is_dead = (isinstance(stmt, AstDef)
                       and stmt.name in f_locals
                       and not any(stmt.name in fv for _, fv in entries))
            if is_dead:
                entries.pop(pos)
            else:
                pos += 1
        if len(entries) == len(ast.items):
            return ast
        return _cl(makeBody([stmt for stmt, _ in entries]), ast)
    if isinstance(ast, AstReturn):
        cleaned = clean_locals(ast.value, f_locals)
        return ast if cleaned is ast.value else _cl(AstReturn(cleaned), ast)
    return ast
def visit_return(self, node: AstReturn):
    """Simplify a `return` statement by pushing it into its value.

    - If the visited value is a non-empty body, the return is applied to the
      body's last statement and the leading statements are kept in front.
    - If the value is a `let`, the return moves into the let's body, which is
      visited while the let target is locked. The rebuilt `AstLet` keeps its
      `original_target`.
    Returns `node` itself when the value was not modified.
    """
    value = self.visit(node.value)
    # Guard against an empty body: `items[-1]` would raise IndexError.
    if isinstance(value, AstBody) and len(value.items) > 0:
        items = value.items
        ret = self.visit(_cl(AstReturn(items[-1]), node))
        return _cl(makeBody(items[:-1], ret), value)
    elif isinstance(value, AstLet):
        with self.create_lock(value.target):
            ret = self.visit(_cl(AstReturn(value.body), node))
        # Preserve original_target, as every other AstLet rebuild does.
        return _cl(AstLet(value.target, value.source, ret,
                          original_target=value.original_target), value)
    if value is not node.value:
        return _cl(AstReturn(value), node)
    return node
def visit_call_abs(self, node: AstCall):
    """Constant-fold `abs(x)` when `x` is a literal value.

    Only a single positional argument is folded. If the literal is not a
    valid operand of `abs` (e.g. a string or None), the call is left for run
    time instead of aborting compilation with a TypeError.
    """
    if node.arg_count == 1 and not node.has_keyword_args:
        arg = self.visit_expr(node.args[0])
        if isinstance(arg, AstValue):
            try:
                folded = abs(arg.value)
            except TypeError:
                pass  # not foldable: fall through to the generic call handling
            else:
                return _cl(AstValue(folded), node)
    return self.visit_call(node)
def visit_call(self, node: AstCall):
    """Visit a call, inlining it when the callee is a known function.

    - If the visited callee is an `AstFunction` and no argument reports
      changed variables (`get_info(arg).has_changed_vars`), the parameters
      are bound to the arguments and the function body is visited (then
      passed through `clean_locals` when the function has `f_locals`). When
      the result contains exactly one `return`, the call is replaced by the
      returned expression (`AstValue(None)` for a bare return), preceded by
      the `prefix` from `parse_args` and any statements before the return.
    - If the callee is an `AstDict`, the call becomes a subscript.
    - Otherwise the call is rebuilt with the visited callee and arguments,
      behind the argument prefix.

    Raises:
        TypeError: if a dict is called with anything other than exactly one
            argument and no keyword arguments.
    """
    function = self.visit(node.function)
    prefix, args = self.parse_args(node.args)
    if isinstance(function, AstFunction) and all(
            [not get_info(arg).has_changed_vars for arg in args]):
        # NOTE(review): if inlining is abandoned below (return_count != 1),
        # these parameter bindings stay defined and the visited body is
        # discarded; confirm define_all is scoped appropriately.
        self.define_all(function.parameters, args, vararg=function.vararg)
        result = self.visit(function.body)
        if function.f_locals is not None:
            result = clean_locals(result, function.f_locals)
        if get_info(result).return_count == 1:
            if isinstance(result, AstReturn):
                # Body reduced to a single return: the call is its value.
                result = result.value
                result = result if result is not None else AstValue(None)
                if len(prefix) > 0:
                    result = makeBody(prefix, result)
                return result
            elif isinstance(result, AstBody) and result.last_is_return:
                # Keep the statements before the final return, then its value.
                items = prefix + result.items[:-1]
                result = result.items[-1].value
                result = result if result is not None else AstValue(None)
                return makeBody(items, result)
    elif isinstance(function, AstDict):
        # Calling a dict with one key argument is a lookup.
        if len(args) != 1 or node.has_keyword_args:
            raise TypeError(
                "dict access requires exactly one argument ({} given)".
                format(node.arg_count))
        return _cl(makeSubscript(function, args[0]), node)
    result = node.clone(function=function, args=args)
    return makeBody(prefix, result)
def visit_let(self, node: AstLet):
    """Visit a `let`, hoisting statements produced by its source.

    Returns `node` itself when neither the source nor the body changed.
    Otherwise returns a body made of the source's hoisted prefix followed by
    the rebuilt `AstLet`, which keeps `original_target`.
    """
    hoisted, new_source = self._visit_expr(node.source)
    new_body = self.visit(node.body)
    if new_source is node.source and new_body is node.body:
        return node
    rebuilt = AstLet(node.target, new_source, new_body,
                     original_target=node.original_target)
    return _cl(makeBody(hoisted, rebuilt), node)
def visit_import(self, node: AstImport):
    """Bind the names introduced by an import statement in the current scope.

    - `from m import a, b` (no alias): each name is bound to the predefined
      symbol `"<module>.<name>"`.
    - `from m import a as b`: the alias is bound to the symbol of the first
      imported name.
    - `import m` / `import m as b`: a namespace holding one predefined symbol
      per name the module exposes is bound to the alias if given, otherwise
      to the module name as written.

    Returns a plain `AstImport` of the resolved module name.
    """
    module_name, names = ppl_namespaces.namespace_from_module(
        node.module_name)
    if node.imported_names is not None:
        if node.alias is None:
            for name in node.imported_names:
                self.define(
                    name,
                    AstSymbol("{}.{}".format(module_name, name), predef=True))
        else:
            self.define(
                node.alias,
                AstSymbol("{}.{}".format(module_name, node.imported_names[0]),
                          predef=True))
    else:
        bindings = {
            key: AstSymbol("{}.{}".format(module_name, key), predef=True)
            for key in names
        }
        ns = AstNamespace(module_name, bindings)
        # `import m as alias` must make `alias`, not `m`, resolvable.
        bound_name = node.alias if node.alias is not None else node.module_name
        self.define(bound_name, ns)
    return _cl(AstImport(module_name), node)
def visit_call_math_sqrt(self, node: AstCall):
    """Constant-fold `math.sqrt(x)` when `x` is a literal value.

    Folding is skipped, and the call is left to the generic handler, in
    three cases:
    - keyword arguments are present;
    - `math.sqrt` rejects the literal with ValueError (e.g. a negative number);
    - `math.sqrt` rejects the literal with TypeError (e.g. a string).
    Skipping avoids aborting compilation.
    """
    if node.arg_count == 1 and not node.has_keyword_args:
        value = self.visit_expr(node.args[0])
        if isinstance(value, AstValue):
            try:
                folded = math.sqrt(value.value)
            except (TypeError, ValueError):
                pass  # not foldable: defer to run time
            else:
                return _cl(AstValue(folded), node)
    return self.visit_call(node)
def visit_observe(self, node: AstObserve):
    """Visit an `observe` statement's distribution and observed value.

    Returns `node` itself if neither child changed. Otherwise returns a new
    `AstObserve` of the visited children, passed through `_cl` with `node`.
    """
    new_dist, new_value = self.visit(node.dist), self.visit(node.value)
    untouched = new_dist is node.dist and new_value is node.value
    return node if untouched else _cl(AstObserve(new_dist, new_value), node)
def visit_while(self, node: AstWhile):
    """Visit a `while` loop's condition and body.

    Returns `node` itself if neither part changed. Otherwise returns a new
    `AstWhile`, passed through `_cl` with `node`.
    """
    cond = self.visit(node.test)
    loop_body = self.visit(node.body)
    if cond is node.test and loop_body is node.body:
        return node
    return _cl(AstWhile(cond, loop_body), node)
def visit_binary(self, node: AstBinary):
    """Visit both operands of a binary operation.

    A new `AstBinary` with the same operator is returned only if an operand
    changed; otherwise `node` itself is returned.
    """
    lhs = self.visit(node.left)
    rhs = self.visit(node.right)
    if lhs is not node.left or rhs is not node.right:
        return _cl(AstBinary(lhs, node.op, rhs), node)
    return node
def visit_compare(self, node: AstCompare):
    """Visit the operands of a (possibly chained) comparison.

    The left, right and, for chained comparisons such as `a < b < c`, second
    right operands are visited. Returns `node` itself if none changed.
    Otherwise returns a new `AstCompare` that keeps the second operator and
    operand, so chained comparisons are not truncated.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    second_right = (self.visit(node.second_right)
                    if node.second_right is not None else None)
    if (left is node.left and right is node.right
            and second_right is node.second_right):
        return node
    return _cl(AstCompare(left, node.op, right, node.second_op, second_right),
               node)
def visit_vector(self, node: AstVector):
    """Visit a vector literal, merging identical unsized samples.

    The vector is merged when it is non-empty, every visited element is an
    `AstSample` with no `size`, and every element's `dist` compares equal to
    the first one's. In that case the result is a single `AstSample` of that
    distribution with `size=AstValue(<element count>)`. Otherwise the visited
    elements are rebuilt with `makeVector`.
    """
    visited = [self.visit(element) for element in node.items]
    if visited and all(isinstance(e, AstSample) and e.size is None for e in visited):
        shared_dist = visited[0].dist
        if all(e.dist == shared_dist for e in visited):
            return _cl(AstSample(shared_dist, size=AstValue(len(visited))), node)
    return makeVector(visited)
def visit_sample(self, node: AstSample):
    """Visit a `sample`'s distribution and optional size.

    Returns `node` itself when both are unchanged. Otherwise returns a new
    `AstSample`, passed through `_cl` with `node`.
    """
    new_dist = self.visit(node.dist)
    new_size = self.visit(node.size)
    unchanged = new_dist is node.dist and new_size is node.size
    return node if unchanged else _cl(AstSample(new_dist, size=new_size), node)
def visit_function(self, node: AstFunction):
    """Visit a function definition's body with all names locked.

    The body is visited inside a new `create_lock()` context after
    `lock_all()`. Returns `node` itself if the body is unchanged. Otherwise
    returns a new `AstFunction` with the same name, parameters, vararg,
    doc string and `f_locals`.
    """
    with self.create_lock():
        self.lock_all()
        new_body = self.visit(node.body)
    if new_body is node.body:
        return node
    rebuilt = AstFunction(node.name, node.parameters, new_body,
                          vararg=node.vararg, doc_string=node.doc_string,
                          f_locals=node.f_locals)
    return _cl(rebuilt, node)
def visit_subscript(self, node: AstSubscript):
    """Visit a subscript, hoisting statements produced by base and index.

    Returns `node` itself when neither part changed. Otherwise returns a body
    of the base prefix, then the index prefix, followed by the rebuilt
    subscript.
    """
    pre_base, new_base = self._visit_expr(node.base)
    pre_index, new_index = self._visit_expr(node.index)
    changed = new_base is not node.base or new_index is not node.index
    if not changed:
        return node
    hoisted = pre_base + pre_index
    return _cl(makeBody(hoisted, makeSubscript(new_base, new_index)), node)
def visit_for(self, node: AstFor):
    """Visit a `for` loop, hoisting statements produced by its source.

    If the visited body no longer has the loop target among its free
    variables, the target is replaced by `_`. Returns `node` itself when
    target, source and body are all unchanged (by identity). Otherwise
    returns the source prefix followed by the rebuilt `AstFor`, which keeps
    `original_target`.
    """
    hoisted, new_source = self._visit_expr(node.source)
    new_body = self.visit(node.body)
    target_used = node.target in get_info(new_body).free_vars
    new_target = node.target if target_used else '_'
    if new_target is node.target and new_source is node.source and new_body is node.body:
        return node
    loop = AstFor(new_target, new_source, new_body,
                  original_target=node.original_target)
    return _cl(makeBody(hoisted, loop), node)
def visit_for(self, node: AstFor):
    """Unroll a `for` loop whose iteration count is statically known.

    - A loop over a vector becomes a body of one `AstLet` per element.
    - A loop over a source whose type is a `SequenceType` with a known size
      becomes one `AstLet` per index, each bound to a subscript.
    Both unrolled forms keep `original_target` and are visited again.
    Otherwise the names changed by the loop body are locked and the loop is
    rebuilt (via `clone`) from the visited source and body.
    """
    source = self.visit(node.source)
    if is_vector(source):
        # Keep original_target here too, matching the sized-sequence branch.
        result = makeBody([
            AstLet(node.target, item, node.body,
                   original_target=node.original_target)
            for item in source
        ])
        return self.visit(_cl(result, node))
    src_type = self.get_type(source)
    if isinstance(src_type, ppl_types.SequenceType) and src_type.size is not None:
        result = makeBody([
            AstLet(node.target, makeSubscript(source, i), node.body,
                   original_target=node.original_target)
            for i in range(src_type.size)
        ])
        return self.visit(_cl(result, node))
    # The loop stays: names it modifies must not be substituted.
    for name in get_info(node.body).changed_vars:
        self.lock_name(name)
    body = self.visit(node.body)
    return node.clone(source=source, body=body)
def visit_binary(self, node: AstBinary):
    """Visit a binary operation, hoisting statements produced by its operands.

    Returns `node` itself when neither operand changed. Otherwise returns a
    body of the left prefix, the right prefix, and finally the rebuilt
    `AstBinary`.

    NOTE(review): both prefixes are placed before the operation. For
    `and`/`or`, the right operand's prefix therefore runs even when the
    operation would short-circuit; confirm this is acceptable.
    """
    left_pre, lhs = self._visit_expr(node.left)
    right_pre, rhs = self._visit_expr(node.right)
    if lhs is node.left and rhs is node.right:
        return node
    statements = left_pre + right_pre + [AstBinary(lhs, node.op, rhs)]
    return _cl(makeBody(statements), node)
def visit_let(self, node: AstLet):
    """Simplify a `let` binding.

    - If the body never uses the target, the let becomes a plain body that
      runs the source and then the body.
    - If the visited source is a body with more than one statement, its
      leading statements are hoisted in front of the let (whose source
      becomes the last statement), and the result is visited again.
    - If the source is independent of the body and the target is used once
      (or the source reports `can_embed`), the target is defined to the
      visited source and the body is visited in place of the let.
    - Otherwise the let becomes a definition followed by the body.
    """
    # count_variable_usage only inspects the unchanged node.body, so compute
    # it once instead of up to three times.
    usage = count_variable_usage(node.target, node.body)
    if usage == 0:
        return self.visit(_cl(makeBody(node.source, node.body), node))
    source = self.visit_expr(node.source)
    src_info = get_info(source)
    if isinstance(source, AstBody) and len(source) > 1:
        result = node.clone(source=source.items[-1])
        result = _cl(makeBody(source.items[:-1], result), node.source)
        return self.visit(result)
    elif src_info.is_independent(get_info(node.body)) and \
            (usage == 1 or src_info.can_embed):
        # NOTE(review): this visits node.source again instead of reusing
        # `source`; kept as-is since visit and visit_expr may differ.
        self.define(node.target, self.visit(node.source))
        return _cl(self.visit(node.body), node)
    return self.visit(makeBody(AstDef(node.target, node.source), node.body))
def visit_slice(self, node: AstSlice):
    """Visit a slice, evaluating it when the base is a literal vector.

    Folding applies when the visited start/stop are each an integer constant
    or absent, and the base is an `AstValueVector` or `AstVector`. The slice
    is then taken on `base.items` and returned as a new vector. Otherwise an
    `AstSlice` of the visited parts is returned.
    """
    base = self.visit(node.base)
    start = self.visit(node.start)
    stop = self.visit(node.stop)
    static_bounds = all(is_integer(bound) or bound is None for bound in (start, stop))
    if static_bounds and isinstance(base, (AstValueVector, AstVector)):
        lo = None if start is None else start.value
        hi = None if stop is None else stop.value
        # With no bounds at all, pass the items list itself (no copy).
        selected = base.items if lo is None and hi is None else base.items[lo:hi]
        return _cl(makeVector(selected), node)
    return _cl(AstSlice(base, start, stop), node)
def visit_slice(self, node: AstSlice):
    """Visit a slice, hoisting statements produced by base, start and stop.

    Returns `node` itself when no part changed. Otherwise returns a body of
    the base, start and stop prefixes (in that order) followed by the
    rebuilt `AstSlice`.
    """
    base_prefix, base = self._visit_expr(node.base)
    start_prefix, start = self._visit_expr(node.start)
    stop_prefix, stop = self._visit_expr(node.stop)
    if base is node.base and start is node.start and stop is node.stop:
        return node
    # Concatenate into a new list rather than extending the list returned by
    # `_visit_expr` in place (same approach as visit_subscript/visit_binary).
    prefix = base_prefix + start_prefix + stop_prefix
    return _cl(makeBody(prefix, AstSlice(base, start, stop)), node)
def visit_call_range(self, node: AstCall):
    """Constant-fold `range(...)` with integer literal arguments.

    The forms `range(stop)`, `range(start, stop)` and
    `range(start, stop, step)` are folded when every argument is an integer
    constant and there are no keyword arguments. The result is an
    `AstValueVector` of the produced integers. A zero step is not folded,
    because `range` would raise ValueError during compilation; such calls,
    and all others, go to the generic `visit_call`.
    """
    args = [self.visit(arg) for arg in node.args]
    if (not node.has_keyword_args and 1 <= len(args) <= 3
            and all(is_integer(arg) for arg in args)):
        bounds = [arg.value for arg in args]
        if len(bounds) < 3 or bounds[2] != 0:
            return _cl(AstValueVector(list(range(*bounds))), node)
    return self.visit_call(node)
def visit_unary(self, node: AstUnary):
    """Visit a unary operation, hoisting statements produced by its operand.

    Two nested identical arithmetic operators (`- -x`, `+ +x`) cancel and
    are replaced by the visited inner operand. A double `not` is not
    cancelled: `not not x` is a bool, while `x` may be any value, so the
    rewrite would change the result. Returns `node` itself when the operand
    did not change. Otherwise returns the operand's prefix followed by the
    rebuilt `AstUnary`.
    """
    if (isinstance(node.item, AstUnary) and node.op == node.item.op
            and node.op in ('+', '-')):
        return self.visit(node.item.item)
    prefix, item = self._visit_expr(node.item)
    if item is node.item:
        return node
    # Build a new list instead of appending to the one `_visit_expr` returned.
    return _cl(makeBody(list(prefix) + [AstUnary(node.op, item)]), node)
def visit_in_scope(self, node: AstNode, is_loop: bool = False):
    """Visit `node` inside a freshly opened scope.

    The statements of an `AstBody` (or the single node otherwise) are visited
    in order and collected into a new body.

    Returns:
        A `(symbols, result)` pair: `symbols` is what `end_scope()` returns
        for the scope opened here, and `result` is the rebuilt body.
    """
    items = []
    # `items` is handed to the scope before it is filled, so statements are
    # appended to this same list object. NOTE(review): presumably the scope
    # may add statements to it as well; confirm in begin_scope.
    self.begin_scope(items, is_loop)
    if isinstance(node, AstBody):
        for item in node.items:
            items.append(self.visit(item))
    else:
        items.append(self.visit(node))
    result = _cl(makeBody(items), node)
    # NOTE(review): end_scope() is not reached if a visit above raises.
    symbols = self.end_scope()
    return symbols, result
def visit_dict(self, node: AstDict):
    """Visit a dict literal, hoisting statements produced by its values.

    Each value goes through `_visit_expr`, which yields a `(prefix, expr)`
    pair. All prefixes are collected, in key order, in front of the rebuilt
    `AstDict`. Returns `node` itself when the dict is empty or no value
    changed.
    """
    if len(node) == 0:
        return node
    prefix = []
    result = {}
    changed = False
    for key in node.items:
        value = node.items[key]
        # `_visit_expr` (not `visit`) returns the (prefix, expr) pair.
        value_prefix, new_value = self._visit_expr(value)
        prefix += value_prefix
        result[key] = new_value
        if new_value is not value:
            changed = True
    if not changed and len(prefix) == 0:
        return node
    return _cl(makeBody(prefix, AstDict(result)), node)
def visit_vector(self, node: AstVector):
    """Visit a vector literal, merging identical unsized samples.

    If the vector is non-empty and every visited element is an `AstSample`
    with no `size` whose `dist` compares equal to the first element's, the
    vector is replaced by one `AstSample` of that distribution with
    `size=AstValue(<element count>)`. The vector's `original_name` attribute,
    when present and not None, is copied onto the merged sample. Otherwise
    the visited elements are rebuilt with `makeVector`.
    """
    elements = [self.visit(element) for element in node.items]

    def _is_unsized_sample(candidate):
        return isinstance(candidate, AstSample) and candidate.size is None

    mergeable = (len(elements) > 0
                 and all(_is_unsized_sample(e) for e in elements)
                 and all(e.dist == elements[0].dist for e in elements))
    if not mergeable:
        return makeVector(elements)
    merged = _cl(AstSample(elements[0].dist, size=AstValue(len(elements))), node)
    name = getattr(node, 'original_name', None)
    if name is not None:
        merged.original_name = name
    return merged
def visit_compare(self, node: AstCompare):
    """Simplify and, where possible, constant-fold a comparison.

    For a non-chained comparison (no `second_right`):
    - negations are moved across the operator with the operands swapped:
      `-a OP -b` -> `b OP a`, `-a OP c` -> `-c OP a`, `c OP -a` -> `a OP -c`
      (c a number);
    - an addition/subtraction (per `is_binary_add_sub`) compared with a
      number is rewritten as a comparison against 0, e.g.
      `(x + 1) OP 5` -> `(x + 1 - 5) OP 0`.

    Two numeric operands are evaluated with `op_function`. For a chained
    comparison whose second operand is numeric, the result is combined with
    `op_function_2`. `in` / `not in` against a vector is decided by comparing
    `left` with each element. Otherwise a new `AstCompare` of the visited
    operands is returned.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    second_right = self.visit(node.second_right)
    if second_right is None:
        # NOTE(review): these negation rewrites are sound for order/equality
        # operators; confirm they cannot fire for `in`/`is`.
        if is_unary_neg(left) and is_unary_neg(right):
            left, right = right.item, left.item
        elif is_unary_neg(left) and is_number(right):
            left, right = AstValue(-right.value), left.item
        elif is_number(left) and is_unary_neg(right):
            right, left = AstValue(-left.value), right.item
        if is_binary_add_sub(left) and is_number(right):
            left = self.visit(AstBinary(left, '-', right))
            right = AstValue(0)
        elif is_binary_add_sub(right) and is_number(left):
            right = self.visit(AstBinary(right, '-', left))
            left = AstValue(0)
    if is_number(left) and is_number(right):
        result = node.op_function(left.value, right.value)
        if second_right is None:
            return _cl(AstValue(result), node)
        elif is_number(second_right):
            # Chained comparison a OP b OP2 c == (a OP b) and (b OP2 c).
            result = result and node.op_function_2(right.value, second_right.value)
            return _cl(AstValue(result), node)
    if node.op in ('in', 'not in') and is_vector(right) and second_right is None:
        # NOTE(review): a failed `left == item` on non-constant AST nodes does
        # not prove the runtime values differ, yet the loop ends in a definite
        # "not found" result. Also, these results are not wrapped in `_cl`.
        op = node.op
        for item in right:
            if left == item:
                return AstValue(True if op == 'in' else False)
        return AstValue(False if op == 'in' else True)
    return _cl(
        AstCompare(left, node.op, right, node.second_op, second_right), node)