Ejemplo n.º 1
0
    def test_postorder(self):
        """postorder(fn) yields fn(node) for children before their parent.

        Checks the visit order (via each node's class name) and that the
        results can be fed directly to ``any`` to detect aggregates.
        """
        expr1 = e.MINUS(e.MAX(e.NamedAttributeRef("salary")),
                        e.MIN(e.NamedAttributeRef("salary")))
        expr2 = e.PLUS(e.LOG(e.NamedAttributeRef("salary")),
                       e.ABS(e.NamedAttributeRef("salary")))

        def is_aggregate(expr):
            return isinstance(expr, e.BuiltinAggregateExpression)

        def classname(expr):
            return expr.__class__.__name__

        # Operands are visited left to right, each before its parent.
        # Compare lists directly rather than their str() rendering.
        self.assertEqual(list(expr1.postorder(classname)),
                         ['NamedAttributeRef', 'MAX',
                          'NamedAttributeRef', 'MIN', 'MINUS'])
        self.assertEqual(list(expr2.postorder(classname)),
                         ['NamedAttributeRef', 'LOG',
                          'NamedAttributeRef', 'ABS', 'PLUS'])

        # expr1 (MAX/MIN) contains builtin aggregates; expr2 (LOG/ABS) does not.
        self.assertTrue(any(expr1.postorder(is_aggregate)))
        self.assertFalse(any(expr2.postorder(is_aggregate)))
Ejemplo n.º 2
0
 def p_sexpr_id(p):
     'sexpr : unreserved_id'
     # NOTE: the docstring above is the grammar production for this rule;
     # keep it byte-identical.
     try:
         # A bare identifier may be a call to a zero-argument function.
         p[0] = Parser.resolve_function(p, p[1], [])
     except Exception:
         # Not resolvable as a zero-argument function: treat it as an
         # attribute reference. ``except Exception`` (not a bare ``except``)
         # lets KeyboardInterrupt / SystemExit propagate.
         p[0] = sexpr.NamedAttributeRef(p[1])
Ejemplo n.º 3
0
    def add_state_func(p, name, args, inits, updates, emitters, is_aggregate):
        """Register a stateful apply or UDA.

        :param p: The parser production object; used for ``p.lineno(0)`` and
        passed through to ``Parser.check_for_undefined``.
        :param name: The name of the function
        :param args: A list of function argument names (strings)
        :param inits: A list of NaryEmitArg that describe init logic; each
        should contain exactly one emit expression.
        :param updates: A list of Expression that describe update logic
        :param emitters: An Expression list that returns the final results.
        If None, all statemod variables are returned in the order specified.
        :param is_aggregate: True if the state_func is a UDA

        :raises DuplicateFunctionDefinitionException: ``name`` is already in
        ``Parser.udf_functions``.
        :raises DuplicateVariableException: an argument name is repeated, or a
        state variable repeats another state variable or an argument name.
        :raises BadApplyDefinitionException: ``inits`` and ``updates`` have
        different lengths.
        :raises IllegalWildcardException: an init is not an NaryEmitArg.
        :raises NestedTupleExpressionException: an init has more than one
        emit expression.
        :raises UnnamedStateVariableException: an init has no column name.

        Errors raised by ``check_simple_expression`` and
        ``Parser.check_for_undefined`` propagate unchanged.

        On success, stores a ``StatefulFunc`` under ``name`` in
        ``Parser.udf_functions``.

        TODO: de-duplicate logic from add_udf.
        """
        lineno = p.lineno(0)
        if name in Parser.udf_functions:
            raise DuplicateFunctionDefinitionException(name, lineno)
        if len(args) != len(set(args)):
            raise DuplicateVariableException(name, lineno)
        if len(inits) != len(updates):
            raise BadApplyDefinitionException(name, lineno)

        # Unpack the (init, update) pairs into an ordered mapping of
        # state variable name -> (init expression, update expression).
        statemods = collections.OrderedDict()
        for init, update in zip(inits, updates):
            if not isinstance(init, emitarg.NaryEmitArg):
                raise IllegalWildcardException(name, lineno)

            if len(init.sexprs) != 1:
                raise NestedTupleExpressionException(lineno)

            # Init and update expressions may not contain tuples or
            # aggregates; check_simple_expression enforces this.
            check_simple_expression(init.sexprs[0], lineno)
            check_simple_expression(update, lineno)

            if not init.column_names:
                raise UnnamedStateVariableException(name, lineno)

            # Reject a state variable that repeats an earlier state variable
            # or shadows a function argument.
            sm_name = init.column_names[0]
            if sm_name in statemods or sm_name in args:
                raise DuplicateVariableException(name, lineno)

            statemods[sm_name] = (init.sexprs[0], update)

        # Materialize the state variable names once as a list. This works
        # whether keys() returns a list (Python 2) or a view (Python 3).
        state_names = list(statemods.keys())

        # Check for undefined variables:
        #  - Init expressions cannot reference any variables.
        #  - Update expressions can reference function arguments and state
        #    variables.
        #  - The emitter expressions can reference state variables.
        # list(args) also lets ``args`` be any sequence (e.g. a tuple).
        allvars = state_names + list(args)
        for init_expr, update_expr in statemods.values():
            Parser.check_for_undefined(p, name, init_expr, [])
            Parser.check_for_undefined(p, name, update_expr, allvars)

        if emitters is None:
            emitters = [sexpr.NamedAttributeRef(v) for v in state_names]

        for emitter in emitters:
            Parser.check_for_undefined(p, name, emitter, state_names)
            check_simple_expression(emitter, lineno)

        # If the function is a UDA, wrap the output expression(s) so
        # downstream users can distinguish stateful apply from
        # aggregate expressions.
        if is_aggregate:
            emitters = [sexpr.UdaAggregateExpression(em) for em in emitters]

        # NOTE(review): this assert is stripped under ``python -O``; with no
        # emitters the else-branch would then build an empty TupleExpression.
        assert len(emitters) > 0
        if len(emitters) == 1:
            emit_op = emitters[0]
        else:
            emit_op = TupleExpression(emitters)

        Parser.udf_functions[name] = StatefulFunc(args, statemods, emit_op)