def test_postorder(self):
    """postorder() visits children before their parent, left to right.

    Also checks that mapping a predicate over postorder() and feeding the
    results to any() detects whether a tree contains an aggregate.
    """
    expr1 = e.MINUS(e.MAX(e.NamedAttributeRef("salary")),
                    e.MIN(e.NamedAttributeRef("salary")))
    expr2 = e.PLUS(e.LOG(e.NamedAttributeRef("salary")),
                   e.ABS(e.NamedAttributeRef("salary")))

    def isAggregate(expr):
        return isinstance(expr, e.BuiltinAggregateExpression)

    def classname(expr):
        return expr.__class__.__name__

    e1cls = list(expr1.postorder(classname))
    e2cls = list(expr2.postorder(classname))

    # Compare the lists themselves rather than their str() forms so a
    # failure shows an element-wise diff and no repr quoting is involved.
    self.assertEqual(e1cls, ['NamedAttributeRef', 'MAX',
                             'NamedAttributeRef', 'MIN', 'MINUS'])
    self.assertEqual(e2cls, ['NamedAttributeRef', 'LOG',
                             'NamedAttributeRef', 'ABS', 'PLUS'])

    # expr1 contains MAX/MIN (expected to be builtin aggregates);
    # expr2 contains only LOG/ABS, which are not.
    self.assertTrue(any(expr1.postorder(isAggregate)))
    self.assertFalse(any(expr2.postorder(isAggregate)))
def p_sexpr_id(p):
    'sexpr : unreserved_id'
    # The docstring above is the PLY grammar production; keep it verbatim.
    try:
        # A bare identifier may be a call to a zero-argument function.
        p[0] = Parser.resolve_function(p, p[1], [])
    except Exception:
        # Any failure to resolve it as a zero-argument function means we
        # treat the identifier as an attribute reference instead.
        # Catch Exception (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate.
        p[0] = sexpr.NamedAttributeRef(p[1])
def add_state_func(p, name, args, inits, updates, emitters, is_aggregate):
    """Register a stateful apply or UDA.

    :param name: The name of the function
    :param args: A list of function argument names (strings)
    :param inits: A list of NaryEmitArg that describe init logic; each
    should contain exactly one emit expression.
    :param updates: A list of Expression that describe update logic
    :param emitters: An Expression list that returns the final results.
    If None, all statemod variables are returned in the order specified.
    :param is_aggregate: True if the state_func is a UDA

    :raises DuplicateFunctionDefinitionException: ``name`` is already
        registered in ``Parser.udf_functions``.
    :raises DuplicateVariableException: an argument name is repeated, or a
        state variable name repeats another state variable or an argument.
    :raises BadApplyDefinitionException: ``inits`` and ``updates`` differ
        in length.
    :raises IllegalWildcardException: an init is not an NaryEmitArg.
    :raises NestedTupleExpressionException: an init does not have exactly
        one emit expression.
    :raises UnnamedStateVariableException: an init has no column name.

    TODO: de-duplicate logic from add_udf.
    """
    lineno = p.lineno(0)
    if name in Parser.udf_functions:
        raise DuplicateFunctionDefinitionException(name, lineno)
    if len(args) != len(set(args)):
        raise DuplicateVariableException(name, lineno)
    if len(inits) != len(updates):
        raise BadApplyDefinitionException(name, lineno)

    # Unpack the update, init expressions into a statemod dictionary.
    # Insertion order matters: it is the default emit order below.
    statemods = collections.OrderedDict()
    for init, update in zip(inits, updates):
        if not isinstance(init, emitarg.NaryEmitArg):
            raise IllegalWildcardException(name, lineno)

        if len(init.sexprs) != 1:
            raise NestedTupleExpressionException(lineno)

        # Reject init/update expressions that contain tuples or
        # aggregates (via check_simple_expression).
        check_simple_expression(init.sexprs[0], lineno)
        check_simple_expression(update, lineno)

        if not init.column_names:
            raise UnnamedStateVariableException(name, lineno)

        # check for duplicate variable definitions
        sm_name = init.column_names[0]
        if sm_name in statemods or sm_name in args:
            raise DuplicateVariableException(name, lineno)

        statemods[sm_name] = (init.sexprs[0], update)

    # Check for undefined variables:
    #  - Init expressions cannot reference any variables.
    #  - Update expression can reference function arguments and state
    #    variables.
    #  - The emitter expressions can reference state variables.
    # Build plain lists explicitly: ``dict.keys() + list`` and
    # ``dict.itervalues()`` only work on Python 2.
    state_names = list(statemods)
    allvars = state_names + list(args)
    for init_expr, update_expr in statemods.values():
        Parser.check_for_undefined(p, name, init_expr, [])
        Parser.check_for_undefined(p, name, update_expr, allvars)

    if emitters is None:
        emitters = [sexpr.NamedAttributeRef(v) for v in state_names]

    for emitter in emitters:
        Parser.check_for_undefined(p, name, emitter, state_names)
        check_simple_expression(emitter, lineno)

    # If the function is a UDA, wrap the output expression(s) so
    # downstream users can distinguish stateful apply from
    # aggregate expressions.
    if is_aggregate:
        emitters = [sexpr.UdaAggregateExpression(em) for em in emitters]

    assert len(emitters) > 0
    if len(emitters) == 1:
        emit_op = emitters[0]
    else:
        emit_op = TupleExpression(emitters)

    Parser.udf_functions[name] = StatefulFunc(args, statemods, emit_op)