def test_count_exprs():
    """ast.count_exprs counts every node of an expression tree exactly once."""
    cases = [
        # BinNumExpr root + two NumVal operands.
        (ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD), 3),
        # ExpExpr + its NumVal argument.
        (ast.ExpExpr(ast.NumVal(2)), 2),
        # VectorVal + NumVal + TanhExpr + the TanhExpr's NumVal.
        (ast.VectorVal([ast.NumVal(2), ast.TanhExpr(ast.NumVal(3))]), 4),
        # IfExpr + CompExpr(2 NumVals) + two NumVal branches.
        (
            ast.IfExpr(
                ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
                ast.NumVal(3),
                ast.NumVal(4),
            ),
            6,
        ),
        # A lone leaf counts as a single node.
        (ast.NumVal(1), 1),
    ]
    for expr, expected in cases:
        assert ast.count_exprs(expr) == expected
def test_count_exprs_exclude_list():
    """ast.count_exprs skips node types listed in exclude_list."""
    def make_sum():
        # 1 + 2: one binary node with two numeric leaves.
        return ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD)

    # Excluding ast.BinExpr together with ast.NumVal leaves nothing counted.
    assert ast.count_exprs(make_sum(), exclude_list={ast.BinExpr, ast.NumVal}) == 0
    # Excluding only the root's type still counts its two NumVal children.
    assert ast.count_exprs(make_sum(), exclude_list={ast.BinNumExpr}) == 2
def _calc_bin_depth_threshold(self, expr):
    """Return the binary-expression depth threshold to use for *expr*.

    To account for large tree-like models, this inspects the first operand
    of *expr* that is not itself a binary expression (left first, then
    right). It counts that operand's nodes while excluding binary
    expressions. If the count is non-zero and below
    ``self.ast_size_per_subroutine_threshold``, the threshold becomes
    ``ceil(ast_size_per_subroutine_threshold / count)``. In every other
    case, ``self.bin_depth_threshold`` is returned unchanged.
    """
    if not isinstance(expr.left, ast.BinExpr):
        leaf = expr.left
    elif not isinstance(expr.right, ast.BinExpr):
        leaf = expr.right
    else:
        # Both operands are binary expressions: keep the configured depth.
        return self.bin_depth_threshold

    leaf_size = ast.count_exprs(leaf, exclude_list={ast.BinExpr})
    size_limit = self.ast_size_per_subroutine_threshold
    if leaf_size and leaf_size < size_limit:
        return math.ceil(size_limit / leaf_size)
    return self.bin_depth_threshold
def _adjust_ast_check_frequency(self, expr):
    """Return how often (in binary nodes) the AST size should be checked.

    To account for large tree-like models, this inspects the first operand
    of *expr* that is not itself a binary expression (left first, then
    right). It counts that operand's nodes while excluding binary
    expressions. If the count is non-zero and below
    ``self.ast_size_per_subroutine_threshold``, the frequency becomes
    ``ceil(ast_size_per_subroutine_threshold / count)``. In every other
    case, the current ``self.ast_size_check_frequency`` is returned.
    """
    if not isinstance(expr.left, ast.BinExpr):
        leaf = expr.left
    elif not isinstance(expr.right, ast.BinExpr):
        leaf = expr.right
    else:
        # Both operands are binary expressions: nothing to calibrate against.
        return self.ast_size_check_frequency

    leaf_size = ast.count_exprs(leaf, exclude_list={ast.BinExpr})
    size_limit = self.ast_size_per_subroutine_threshold
    if leaf_size and leaf_size < size_limit:
        return math.ceil(size_limit / leaf_size)
    return self.ast_size_check_frequency
def test_count_exprs_nested_vector_and_if():
    """ast.count_exprs counts every node of a mixed vector/if expression tree.

    Named distinctly from test_count_all_exprs_types. Under the shared name,
    the later definition rebinds it, and pytest never collects or runs this
    test.
    """
    expr = ast.BinVectorNumExpr(
        ast.BinVectorExpr(
            ast.VectorVal([
                ast.ExpExpr(ast.NumVal(2)),
                ast.SqrtExpr(ast.NumVal(2)),
                ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
                ast.TanhExpr(ast.NumVal(1)),
                ast.BinNumExpr(
                    ast.NumVal(0), ast.FeatureRef(0), ast.BinNumOpType.ADD),
            ]),
            ast.VectorVal([
                ast.NumVal(1),
                ast.NumVal(2),
                ast.NumVal(3),
                ast.NumVal(4),
                ast.FeatureRef(1),
            ]),
            ast.BinNumOpType.SUB),
        ast.IfExpr(
            ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
            ast.NumVal(3),
            ast.NumVal(4),
        ),
        ast.BinNumOpType.MUL)
    # Tally:
    #   BinVectorNumExpr root                                    1
    #   BinVectorExpr                                            1
    #     first VectorVal: 1 + (2 + 2 + 3 + 2 + 3)             13
    #     second VectorVal: 1 + 5 leaves                        6
    #   IfExpr: 1 + CompExpr(3) + 2 NumVal branches             6
    #                                                   total  27
    assert ast.count_exprs(expr) == 27
def bin_depth_threshold_hook(self, expr, **kwargs):
    """Interpret *expr* inline, or move it into a separate subroutine.

    The subroutine path is taken only when *expr* has more than
    ``self.ast_size_per_subroutine_threshold`` nodes (per ast.count_exprs).
    That path registers *expr* under a fresh subroutine name and returns
    a call to that function with ``self._feature_array_name`` as its
    argument. Otherwise, *expr* is interpreted directly via
    ``self._do_interpret`` with the given kwargs.
    """
    # Sanity check: only sufficiently large expressions are worth moving
    # into their own subroutine.
    is_large = ast.count_exprs(expr) > self.ast_size_per_subroutine_threshold
    if not is_large:
        return self._do_interpret(expr, **kwargs)

    subroutine = self._get_subroutine_name()
    self.enqueue_subroutine(subroutine, expr)
    return self._cg.function_invocation(subroutine, self._feature_array_name)
def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
    """Before interpreting *expr*, possibly replace it with a subroutine call.

    Only non-reused binary expressions (``ast.BinExpr`` with a falsy
    ``to_reuse``) are considered. For each one:

    * the check frequency is recomputed via ``_adjust_ast_check_frequency``
      and stored back on the instance (it can only decrease);
    * the counter is incremented. Once it reaches the frequency, it is reset
      and the subtree size is measured with ``ast.count_exprs``;
    * if that size exceeds ``self.ast_size_per_subroutine_threshold``, *expr*
      is enqueued as a subroutine and ``(sub_invocation, kwargs)`` is
      returned.

    Args:
        expr: The expression about to be interpreted.
        ast_size_check_counter: Binary nodes seen since the last size check.
            It is stored back into kwargs for the base hook. Presumably the
            base interpreter propagates it to child expressions (TODO confirm).
        **kwargs: Extra interpreter arguments, passed through.

    Returns:
        ``(self._cg.sub_invocation(...), kwargs)`` when *expr* was offloaded;
        otherwise whatever ``BaseToCodeInterpreter._pre_interpret_hook``
        returns.
    """
    if isinstance(expr, ast.BinExpr) and not expr.to_reuse:
        frequency = self._adjust_ast_check_frequency(expr)
        # Persisted on the instance: the frequency never goes back up.
        self.ast_size_check_frequency = min(frequency, self.ast_size_check_frequency)
        ast_size_check_counter += 1
        # Only measure the (potentially expensive) subtree size periodically.
        if ast_size_check_counter >= self.ast_size_check_frequency:
            ast_size_check_counter = 0
            ast_size = ast.count_exprs(expr)
            if ast_size > self.ast_size_per_subroutine_threshold:
                sub_name = self._get_subroutine_name()
                self.enqueue_subroutine(sub_name, expr)
                return self._cg.sub_invocation(sub_name), kwargs
        # NOTE(review): the original one-line formatting lost indentation.
        # This assignment is placed inside the BinExpr branch, so the counter
        # is forwarded only for binary nodes. Confirm against upstream.
        kwargs['ast_size_check_counter'] = ast_size_check_counter
    return BaseToCodeInterpreter._pre_interpret_hook(self, expr, **kwargs)
def test_count_all_exprs_types():
    """ast.count_exprs reports the expected node count for EXPR_WITH_ALL_EXPRS."""
    # EXPR_WITH_ALL_EXPRS is defined outside this view. Judging by its name,
    # it is meant to exercise every expression type.
    expected_node_count = 43
    assert ast.count_exprs(EXPR_WITH_ALL_EXPRS) == expected_node_count