예제 #1
0
def test_count_exprs():
    """``ast.count_exprs`` counts every node of an expression tree, root included."""
    cases = [
        # BinNumExpr plus its two NumVal operands.
        (ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD), 3),
        # ExpExpr plus its NumVal argument.
        (ast.ExpExpr(ast.NumVal(2)), 2),
        # VectorVal + NumVal + TanhExpr + the TanhExpr's NumVal.
        (ast.VectorVal([ast.NumVal(2), ast.TanhExpr(ast.NumVal(3))]), 4),
        # IfExpr + CompExpr (with two NumVal) + two NumVal branches.
        (
            ast.IfExpr(
                ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
                ast.NumVal(3),
                ast.NumVal(4),
            ),
            6,
        ),
        # A lone leaf counts as a single expression.
        (ast.NumVal(1), 1),
    ]

    for expr, expected in cases:
        assert ast.count_exprs(expr) == expected
예제 #2
0
def test_count_exprs_exclude_list():
    """Node types named in ``exclude_list`` are skipped by ``ast.count_exprs``."""
    def make_sum():
        return ast.BinNumExpr(
            ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD)

    # The BinNumExpr root is covered by ast.BinExpr and both leaves by
    # ast.NumVal, so nothing is left to count.
    excluded_all = {ast.BinExpr, ast.NumVal}
    assert ast.count_exprs(make_sum(), exclude_list=excluded_all) == 0

    # Excluding only the root's type still counts its two NumVal operands.
    excluded_root = {ast.BinNumExpr}
    assert ast.count_exprs(make_sum(), exclude_list=excluded_root) == 2
예제 #3
0
 def _calc_bin_depth_threshold(self, expr):
     """Return the binary-expression depth threshold to use for ``expr``.

     To account for large tree-like models, this counts the non-binary
     expressions (``exclude_list={ast.BinExpr}``) in the non-recursive
     operand of ``expr``, i.e. the operand that is not itself an
     ``ast.BinExpr``. The left operand is preferred if it qualifies.

     If that count is non-zero and below
     ``self.ast_size_per_subroutine_threshold``, the result is
     ``ceil(ast_size_per_subroutine_threshold / count)``. Otherwise
     (both operands are binary, the count is zero, or the count reaches
     the threshold) ``self.bin_depth_threshold`` is returned.
     """
     if not isinstance(expr.left, ast.BinExpr):
         leaf_side = expr.left
     elif not isinstance(expr.right, ast.BinExpr):
         leaf_side = expr.right
     else:
         # Both operands recurse into further binary expressions.
         return self.bin_depth_threshold

     non_bin_count = ast.count_exprs(leaf_side, exclude_list={ast.BinExpr})
     # A zero count is skipped so the division below cannot divide by zero.
     if non_bin_count and \
             non_bin_count < self.ast_size_per_subroutine_threshold:
         return math.ceil(
             self.ast_size_per_subroutine_threshold / non_bin_count)
     return self.bin_depth_threshold
예제 #4
0
 def _adjust_ast_check_frequency(self, expr):
     """
     Return the AST size-check frequency to use for the binary ``expr``.

     To account for large tree-like models, this counts the non-binary
     expressions (``exclude_list={ast.BinExpr}``) in the non-recursive
     operand of ``expr``, i.e. the operand that is not an ``ast.BinExpr``.
     The left operand is checked first.

     A non-zero count below ``self.ast_size_per_subroutine_threshold``
     yields ``ceil(ast_size_per_subroutine_threshold / count)``. In all
     other cases the current ``self.ast_size_check_frequency`` is
     returned unchanged.
     """
     if not isinstance(expr.left, ast.BinExpr):
         operand = expr.left
     elif not isinstance(expr.right, ast.BinExpr):
         operand = expr.right
     else:
         # No non-recursive operand to inspect.
         return self.ast_size_check_frequency

     count = ast.count_exprs(operand, exclude_list={ast.BinExpr})
     # Zero (would divide by zero) or already-large counts need no change.
     if not count or count >= self.ast_size_per_subroutine_threshold:
         return self.ast_size_check_frequency
     return math.ceil(self.ast_size_per_subroutine_threshold / count)
예제 #5
0
def test_count_all_exprs_types():
    """``ast.count_exprs`` handles a tree that mixes many expression kinds."""
    # 13 nodes: the VectorVal itself plus ExpExpr(2), SqrtExpr(2),
    # PowExpr(3), TanhExpr(2) and BinNumExpr(3) sub-trees.
    computed = ast.VectorVal([
        ast.ExpExpr(ast.NumVal(2)),
        ast.SqrtExpr(ast.NumVal(2)),
        ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
        ast.TanhExpr(ast.NumVal(1)),
        ast.BinNumExpr(ast.NumVal(0), ast.FeatureRef(0), ast.BinNumOpType.ADD),
    ])

    # 6 nodes: the VectorVal, four NumVal and one FeatureRef.
    plain = ast.VectorVal([
        ast.NumVal(1),
        ast.NumVal(2),
        ast.NumVal(3),
        ast.NumVal(4),
        ast.FeatureRef(1),
    ])

    # 6 nodes: IfExpr + CompExpr (3 nodes) + two NumVal branches.
    branch = ast.IfExpr(
        ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
        ast.NumVal(3),
        ast.NumVal(4),
    )

    difference = ast.BinVectorExpr(computed, plain, ast.BinNumOpType.SUB)
    expr = ast.BinVectorNumExpr(difference, branch, ast.BinNumOpType.MUL)

    # 1 (BinVectorNumExpr) + 1 (BinVectorExpr) + 13 + 6 + 6 == 27
    assert ast.count_exprs(expr) == 27
예제 #6
0
 def bin_depth_threshold_hook(self, expr, **kwargs):
     """Interpret ``expr`` in place or move it into a new subroutine.

     As a sanity check, ``expr`` is moved into its own subroutine only when
     its total expression count (``ast.count_exprs``) exceeds
     ``self.ast_size_per_subroutine_threshold``. In that case a fresh
     subroutine name is enqueued with ``expr`` and the result of
     ``self._cg.function_invocation(name, self._feature_array_name)`` is
     returned. Otherwise ``self._do_interpret(expr, **kwargs)`` is returned.
     """
     size = ast.count_exprs(expr)
     if size <= self.ast_size_per_subroutine_threshold:
         # Too small to be worth a separate subroutine.
         return self._do_interpret(expr, **kwargs)

     sub_name = self._get_subroutine_name()
     self.enqueue_subroutine(sub_name, expr)
     return self._cg.function_invocation(sub_name, self._feature_array_name)
예제 #7
0
파일: mixins.py 프로젝트: rubyroobs/m2vcl
    def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
        """Optionally split a large binary expression into a subroutine.

        Only ``ast.BinExpr`` instances whose ``to_reuse`` flag is false are
        considered. Anything else goes straight to
        ``BaseToCodeInterpreter._pre_interpret_hook``, and the incoming
        ``ast_size_check_counter`` is not forwarded.

        For an eligible expression:

        * ``self.ast_size_check_frequency`` is lowered (never raised) to the
          value from ``self._adjust_ast_check_frequency(expr)``;
        * the counter is incremented. Whenever it reaches the frequency, it
          is reset to 0 and the full size of ``expr`` is measured with
          ``ast.count_exprs``;
        * if that size exceeds ``self.ast_size_per_subroutine_threshold``,
          ``expr`` is enqueued as a new subroutine and the pair
          ``(self._cg.sub_invocation(name), kwargs)`` is returned without
          calling the base hook;
        * otherwise the updated counter is stored in
          ``kwargs['ast_size_check_counter']`` and the base hook is called.
        """
        if not isinstance(expr, ast.BinExpr) or expr.to_reuse:
            return BaseToCodeInterpreter._pre_interpret_hook(
                self, expr, **kwargs)

        suggested = self._adjust_ast_check_frequency(expr)
        self.ast_size_check_frequency = min(
            suggested, self.ast_size_check_frequency)

        counter = ast_size_check_counter + 1
        if counter >= self.ast_size_check_frequency:
            counter = 0
            size = ast.count_exprs(expr)
            if size > self.ast_size_per_subroutine_threshold:
                sub_name = self._get_subroutine_name()
                self.enqueue_subroutine(sub_name, expr)
                # Return early without calling the base hook.
                return self._cg.sub_invocation(sub_name), kwargs

        kwargs['ast_size_check_counter'] = counter
        return BaseToCodeInterpreter._pre_interpret_hook(self, expr, **kwargs)
예제 #8
0
def test_count_all_exprs_types():
    """``EXPR_WITH_ALL_EXPRS`` (defined outside this snippet) has 43 nodes."""
    expected_count = 43
    actual_count = ast.count_exprs(EXPR_WITH_ALL_EXPRS)
    assert actual_count == expected_count