def define_function(f, *param_types, _name=None):
    """Register the expression built by ``f`` as a named IR function.

    ``f`` is called once with one placeholder expression per entry of
    ``param_types``; the dtype of the expression it returns becomes the
    function's return type. The rendered body is parsed and handed to
    ``IRFunctionRegistry.pyRegisterIR``, and the function is also recorded
    through ``register_session_function``.

    :param f: callable building the function body from placeholder expressions.
    :param param_types: one type per positional parameter of ``f``.
    :param _name: registered name; a fresh ``Env.get_uid()`` is used when None.
    :return: a ``Function`` wrapping a type-checked callable that applies the
        registered function to its arguments.
    """
    mname = Env.get_uid() if _name is None else _name
    param_names = [Env.get_uid() for _ in param_types]

    # Build one reference expression per parameter and let `f` assemble the body.
    placeholders = [
        construct_expr(Ref(param_name), param_type)
        for param_name, param_type in zip(param_names, param_types)
    ]
    body = f(*placeholders)
    ret_type = body.dtype

    # Render with stop_at_jir=True; the renderer's `jirs` are passed to the
    # parser as `ir_map` alongside the parameter name -> type mapping.
    renderer = CSERenderer(stop_at_jir=True)
    rendered = renderer(body._ir)
    jbody = body._ir.parse(
        rendered,
        ref_map=dict(zip(param_names, param_types)),
        ir_map=renderer.jirs,
    )

    registry = Env.hail().expr.ir.functions.IRFunctionRegistry
    registry.pyRegisterIR(
        mname,
        param_names,
        [param_type._parsable_string() for param_type in param_types],
        ret_type._parsable_string(),
        jbody,
    )
    register_session_function(mname, param_types, ret_type)

    # NOTE: this deliberately rebinds `f`; the user-supplied callable is no
    # longer needed once the body has been built above.
    @typecheck(args=expr_any)
    def f(*args):
        indices, aggregations = unify_all(*args)
        call = Apply(mname, ret_type, *(arg._ir for arg in args))
        return construct_expr(call, ret_type, indices, aggregations)

    return Function(f, param_types, ret_type, mname)
def test_agg_let(self):
    """An aggregation repeated inside an AggLet body is bound by a Let within it."""
    expected = ('(AggLet foo False (I32 2)'
                ' (Let __cse_1 (ApplyAggOp AggOp () ((Ref foo)))'
                ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))')

    # The same ApplyAggOp object is used for both operands of the addition.
    agg_op = ir.ApplyAggOp('AggOp', [], [ir.Ref('foo')])
    doubled = ir.ApplyBinaryPrimOp('+', agg_op, agg_op)
    agg_let = ir.AggLet('foo', ir.I32(2), doubled, False)

    rendered = CSERenderer()(agg_let)
    assert expected == rendered
def test_cse(self):
    """A literal used twice in one addition is hoisted into a single Let."""
    expected = ('(Let __cse_1 (I32 5)'
                ' (ApplyBinaryPrimOp `+`'
                ' (Ref __cse_1)'
                ' (Ref __cse_1)))')

    five = ir.I32(5)
    doubled = ir.ApplyBinaryPrimOp('+', five, five)

    rendered = CSERenderer()(doubled)
    assert expected == rendered
def test_refs(self):
    """A Ref shared by two struct fields is rendered inline, with no Let binding."""
    expected = ('(TableMapRows (TableRange 10 1)'
                ' (MakeStruct'
                ' (foo (GetField idx (Ref row)))'
                ' (bar (GetField idx (Ref row)))))')

    # One Ref object feeds both GetField nodes.
    row_ref = ir.Ref('row')
    table = ir.TableRange(10, 1)
    new_row = ir.MakeStruct([
        ('foo', ir.GetField(row_ref, 'idx')),
        ('bar', ir.GetField(row_ref, 'idx')),
    ])
    mapped = ir.TableMapRows(table, new_row)

    rendered = CSERenderer()(mapped)
    assert expected == rendered
def register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, body):
    """Render, parse and register ``body`` as the IR function ``name``.

    :param name: name under which the function is registered.
    :param type_parameters: type parameters, passed on as parsable strings.
    :param argument_names: names of the function's arguments.
    :param argument_types: argument types, positionally matching
        ``argument_names``; passed on as parsable strings.
    :param return_type: return type, passed on as a parsable string.
    :param body: expression whose ``_ir`` is rendered as the function body.
    """
    # Render with stop_at_jir=True; the renderer's `jirs` are handed to the
    # parser as `ir_map` together with the argument name -> type mapping.
    renderer = CSERenderer(stop_at_jir=True)
    rendered = renderer(body._ir)
    jbody = self._parse_value_ir(
        rendered,
        ref_map=dict(zip(argument_names, argument_types)),
        ir_map=renderer.jirs,
    )

    registry = Env.hail().expr.ir.functions.IRFunctionRegistry
    registry.pyRegisterIR(
        name,
        [type_param._parsable_string() for type_param in type_parameters],
        argument_names,
        [arg_type._parsable_string() for arg_type in argument_types],
        return_type._parsable_string(),
        jbody,
    )
def test_shadowing(self):
    """A node reused across a rebinding of ``row`` gets a separate Let per scope."""
    expected = (
        '(Let row (I32 5)'
        ' (Let __cse_1 (GetField idx (Ref row))'
        ' (Let row (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
        ' (Let __cse_2 (GetField idx (Ref row))'
        ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))))')

    field = ir.GetField(ir.Ref('row'), 'idx')
    doubled = ir.ApplyBinaryPrimOp('+', field, field)
    # The same addition is both the value and the body of the inner Let,
    # which rebinds the name `row` already bound by the outer Let.
    inner_let = ir.Let('row', doubled, doubled)
    outer_let = ir.Let('row', ir.I32(5), inner_let)

    rendered = CSERenderer()(outer_let)
    assert expected == rendered
def test_agg_cse(self):
    """Repeats inside an aggregation argument use AggLet; repeated aggregations use Let."""
    expected = ('(TableAggregate (TableRange 5 1)'
                ' (AggLet __cse_1 False (GetField idx (Ref row))'
                ' (Let __cse_2 (ApplyAggOp AggOp () None'
                ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))'
                ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))')

    field = ir.GetField(ir.Ref('row'), 'idx')
    seq_arg = ir.ApplyBinaryPrimOp('+', field, field)
    agg_op = ir.ApplyAggOp('AggOp', [], [], [seq_arg])
    # The aggregation itself is also used twice at the top level.
    query = ir.ApplyBinaryPrimOp('+', agg_op, agg_op)
    aggregated = ir.TableAggregate(ir.TableRange(5, 1), query)

    rendered = CSERenderer()(aggregated)
    assert expected == rendered
def test_cse_ifs(self):
    """A value shared across If branches stays inline; a repeat within one branch is bound there."""
    expected = ('(If (True)'
                ' (Let __cse_1 (I32 1)'
                ' (ApplyBinaryPrimOp `*`'
                ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
                ' (I32 5)))'
                ' (I32 5))')

    # `across_branches` appears once in each branch of the If;
    # `within_branch` appears twice inside the consequent only.
    across_branches = ir.I32(5)
    within_branch = ir.I32(1)
    doubled = ir.ApplyBinaryPrimOp('+', within_branch, within_branch)
    consequent = ir.ApplyBinaryPrimOp('*', doubled, across_branches)
    conditional = ir.If(ir.TrueIR(), consequent, across_branches)

    rendered = CSERenderer()(conditional)
    assert expected == rendered
def test_init_op(self):
    """A node shared by the top level, the first and the second argument list of an ApplyAggOp is bound separately in each."""
    expected = (
        '(Let __cse_1 (I32 5)'
        ' (AggLet __cse_3 False (I32 5)'
        ' (ApplyBinaryPrimOp `+`'
        ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
        ' (ApplyAggOp CallStats'
        ' ((Let __cse_2 (I32 5)'
        ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))'
        ' ((ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))')

    five = ir.I32(5)
    doubled = ir.ApplyBinaryPrimOp('+', five, five)
    # The same addition object is passed in both argument lists of the
    # aggregation and is also the left operand of the top-level addition.
    call_stats = ir.ApplyAggOp('CallStats', [doubled], [doubled])
    root = ir.ApplyBinaryPrimOp('+', doubled, call_stats)

    rendered = CSERenderer()(root)
    assert expected == rendered
def test_cse2(self):
    """Nested repeats produce nested Lets; a value used only once stays inline."""
    expected = (
        '(Let __cse_1 (I32 5)'
        ' (Let __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
        ' (ApplyBinaryPrimOp `/`'
        ' (ApplyBinaryPrimOp `*`'
        ' (Ref __cse_2)'
        ' (I32 4))'
        ' (Ref __cse_2))))')

    five = ir.I32(5)
    four = ir.I32(4)
    doubled = ir.ApplyBinaryPrimOp('+', five, five)
    product = ir.ApplyBinaryPrimOp('*', doubled, four)
    # `doubled` is used again here, so it is repeated as well as `five`.
    quotient = ir.ApplyBinaryPrimOp('/', product, doubled)

    rendered = CSERenderer()(quotient)
    assert expected == rendered
def test_agg_cse(self):
    """An aggregation shared inside and outside an AggFilter is bound once per scope."""
    expected = (
        '(TableAggregate (TableRange 5 1)'
        ' (AggLet __cse_1 False (GetField idx (Ref row))'
        ' (AggLet __cse_3 False (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
        ' (Let __cse_2 (ApplyAggOp AggOp () ((Ref __cse_3)))'
        ' (MakeTuple (0 1)'
        ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))'
        ' (AggFilter False (True)'
        ' (Let __cse_4 (ApplyAggOp AggOp () ((Ref __cse_3)))'
        ' (ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))))')

    field = ir.GetField(ir.Ref('row'), 'idx')
    seq_arg = ir.ApplyBinaryPrimOp('+', field, field)
    agg_op = ir.ApplyAggOp('AggOp', [], [seq_arg])
    query = ir.ApplyBinaryPrimOp('+', agg_op, agg_op)
    # The same `query` object appears directly in the tuple and again
    # wrapped in an AggFilter.
    filtered = ir.AggFilter(ir.TrueIR(), query, False)
    aggregated = ir.TableAggregate(ir.TableRange(5, 1),
                                   ir.MakeTuple([query, filtered]))

    rendered = CSERenderer()(aggregated)
    assert expected == rendered
def _render(self, ir):
    """Render ``ir`` with a fresh ``CSERenderer`` and return the result.

    :param ir: the IR node to render.
    :return: whatever calling the renderer on ``ir`` returns.
    """
    renderer = CSERenderer()
    # Checked on the freshly built renderer, before anything is rendered.
    assert len(renderer.jirs) == 0
    rendered = renderer(ir)
    return rendered
def _to_java_ir(self, ir, parse): if not hasattr(ir, '_jir'): r = CSERenderer(stop_at_jir=True) # FIXME parse should be static ir._jir = parse(r(ir), ir_map=r.jirs) return ir._jir