Esempio n. 1
0
def define_function(f, *param_types, _name=None):
    """Build, register and wrap an IR function defined by a Python builder.

    ``f`` is called exactly once with one symbolic argument per entry of
    ``param_types`` (each a ``Ref`` to a freshly generated parameter name);
    the expression it returns becomes the function body. That body is
    rendered with ``CSERenderer(stop_at_jir=True)``, parsed, and registered
    under ``_name`` (or a fresh uid when ``_name`` is None) via
    ``IRFunctionRegistry.pyRegisterIR`` and ``register_session_function``.

    Returns a ``Function`` whose callable emits an ``Apply`` node invoking
    the registered function on the given expression arguments.
    """
    # The function name's uid is drawn before the parameter uids.
    if _name is None:
        mname = Env.get_uid()
    else:
        mname = _name
    param_names = [Env.get_uid() for _ in param_types]
    typed_params = list(zip(param_names, param_types))

    symbolic_args = [construct_expr(Ref(p_name), p_type)
                     for p_name, p_type in typed_params]
    body = f(*symbolic_args)
    ret_type = body.dtype

    renderer = CSERenderer(stop_at_jir=True)
    rendered = renderer(body._ir)
    jbody = body._ir.parse(rendered,
                           ref_map=dict(typed_params),
                           ir_map=renderer.jirs)

    registry = Env.hail().expr.ir.functions.IRFunctionRegistry
    registry.pyRegisterIR(
        mname,
        param_names,
        [p_type._parsable_string() for p_type in param_types],
        ret_type._parsable_string(),
        jbody)
    register_session_function(mname, param_types, ret_type)

    # Rebinds ``f``: the user-supplied builder is not needed past this point.
    @typecheck(args=expr_any)
    def f(*args):
        indices, aggregations = unify_all(*args)
        arg_irs = [arg._ir for arg in args]
        call = Apply(mname, ret_type, *arg_irs)
        return construct_expr(call, ret_type, indices, aggregations)

    return Function(f, param_types, ret_type, mname)
Esempio n. 2
0
 def test_agg_let(self):
     """An aggregation used twice inside an AggLet body is bound once by a Let."""
     agg = ir.ApplyAggOp('AggOp', [], [ir.Ref('foo')])
     # Named ``agg_sum`` rather than ``sum`` to avoid shadowing the builtin.
     agg_sum = ir.ApplyBinaryPrimOp('+', agg, agg)
     agglet = ir.AggLet('foo', ir.I32(2), agg_sum, False)
     expected = ('(AggLet foo False (I32 2)'
                 ' (Let __cse_1 (ApplyAggOp AggOp () ((Ref foo)))'
                 ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))')
     assert expected == CSERenderer()(agglet)
Esempio n. 3
0
 def test_cse(self):
     """A single node used as both operands is bound once with a Let."""
     five = ir.I32(5)
     doubled = ir.ApplyBinaryPrimOp('+', five, five)
     rendered = CSERenderer()(doubled)
     assert ('(Let __cse_1 (I32 5)'
             ' (ApplyBinaryPrimOp `+`'
             ' (Ref __cse_1)'
             ' (Ref __cse_1)))') == rendered
Esempio n. 4
0
 def test_refs(self):
     """Two separate ``GetField idx (Ref row)`` nodes render without CSE bindings."""
     row = ir.Ref('row')
     # Two distinct GetField objects, as in the original fixture.
     foo_field = ir.GetField(row, 'idx')
     bar_field = ir.GetField(row, 'idx')
     mapped = ir.TableMapRows(
         ir.TableRange(10, 1),
         ir.MakeStruct([('foo', foo_field), ('bar', bar_field)]))
     rendered = CSERenderer()(mapped)
     assert ('(TableMapRows (TableRange 10 1)'
             ' (MakeStruct'
             ' (foo (GetField idx (Ref row)))'
             ' (bar (GetField idx (Ref row)))))') == rendered
Esempio n. 5
0
    def register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, body):
        """Render, parse and register ``body`` as the IR function ``name``.

        ``body._ir`` is rendered with ``CSERenderer(stop_at_jir=True)`` and
        parsed by ``self._parse_value_ir`` with a ref map pairing each entry
        of ``argument_names`` with the matching entry of ``argument_types``.
        The parsed body is passed to ``IRFunctionRegistry.pyRegisterIR``,
        together with the parsable-string forms of ``type_parameters``,
        ``argument_types`` and ``return_type``.
        """
        renderer = CSERenderer(stop_at_jir=True)
        rendered = renderer(body._ir)
        # NOTE(review): zip truncates silently if the name/type lists differ in length.
        ref_types = dict(zip(argument_names, argument_types))
        jbody = self._parse_value_ir(rendered, ref_map=ref_types, ir_map=renderer.jirs)

        registry = Env.hail().expr.ir.functions.IRFunctionRegistry
        registry.pyRegisterIR(
            name,
            [tparam._parsable_string() for tparam in type_parameters],
            argument_names,
            [atype._parsable_string() for atype in argument_types],
            return_type._parsable_string(),
            jbody)
Esempio n. 6
0
 def test_shadowing(self):
     """CSE bindings respect shadowing: each ``Let row`` scope gets its own binding.

     ``GetField idx (Ref row)`` refers to a different ``row`` in the inner Let
     body than in its value, so the expected rendering binds it separately
     (``__cse_1`` outside, ``__cse_2`` inside).
     """
     x = ir.GetField(ir.Ref('row'), 'idx')
     # Named ``doubled`` rather than ``sum`` to avoid shadowing the builtin.
     doubled = ir.ApplyBinaryPrimOp('+', x, x)
     inner = ir.Let('row', doubled, doubled)
     outer = ir.Let('row', ir.I32(5), inner)
     expected = (
         '(Let row (I32 5)'
         ' (Let __cse_1 (GetField idx (Ref row))'
         ' (Let row (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
         ' (Let __cse_2 (GetField idx (Ref row))'
         ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))))')
     assert expected == CSERenderer()(outer)
Esempio n. 7
0
 def test_agg_cse(self):
     """Repeats inside an aggregation are bound with AggLet; repeated aggregations with Let."""
     idx = ir.GetField(ir.Ref('row'), 'idx')
     idx_doubled = ir.ApplyBinaryPrimOp('+', idx, idx)
     aggregation = ir.ApplyAggOp('AggOp', [], [], [idx_doubled])
     agg_doubled = ir.ApplyBinaryPrimOp('+', aggregation, aggregation)
     table_agg = ir.TableAggregate(ir.TableRange(5, 1), agg_doubled)
     rendered = CSERenderer()(table_agg)
     assert ('(TableAggregate (TableRange 5 1)'
             ' (AggLet __cse_1 False (GetField idx (Ref row))'
             ' (Let __cse_2 (ApplyAggOp AggOp () None'
             ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))'
             ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))') == rendered
Esempio n. 8
0
 def test_cse_ifs(self):
     """Bindings stay inside the If branch that uses them.

     ``I32 1`` is repeated only within the then-branch and is bound there.
     ``I32 5``, shared between the then- and else-branches, is not hoisted
     above the If.
     """
     outer_repeated = ir.I32(5)
     inner_repeated = ir.I32(1)
     # Named ``inner_sum`` rather than ``sum`` to avoid shadowing the builtin.
     inner_sum = ir.ApplyBinaryPrimOp('+', inner_repeated, inner_repeated)
     prod = ir.ApplyBinaryPrimOp('*', inner_sum, outer_repeated)
     cond = ir.If(ir.TrueIR(), prod, outer_repeated)
     expected = ('(If (True)'
                 ' (Let __cse_1 (I32 1)'
                 ' (ApplyBinaryPrimOp `*`'
                 ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
                 ' (I32 5)))'
                 ' (I32 5))')
     assert expected == CSERenderer()(cond)
Esempio n. 9
0
 def test_init_op(self):
     """A subexpression shared across contexts is bound separately in each one.

     The ``I32 5`` leaf is used at top level, in the ``CallStats`` init-op
     arguments and in its seq-op arguments. The expected rendering binds it
     with a top-level Let, a Let nested in the init-op args, and an AggLet
     for the seq-op args.
     """
     x = ir.I32(5)
     # Named ``doubled`` rather than ``sum`` to avoid shadowing the builtin.
     doubled = ir.ApplyBinaryPrimOp('+', x, x)
     agg = ir.ApplyAggOp('CallStats', [doubled], [doubled])
     top = ir.ApplyBinaryPrimOp('+', doubled, agg)
     expected = (
         '(Let __cse_1 (I32 5)'
         ' (AggLet __cse_3 False (I32 5)'
         ' (ApplyBinaryPrimOp `+`'
         ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
         ' (ApplyAggOp CallStats'
         ' ((Let __cse_2 (I32 5)'
         ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))'
         ' ((ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))')
     assert expected == CSERenderer()(top)
Esempio n. 10
0
 def test_cse2(self):
     """Nested sharing: both the repeated sum and its repeated operand are bound."""
     x = ir.I32(5)
     y = ir.I32(4)
     # Named ``doubled`` rather than ``sum`` to avoid shadowing the builtin.
     doubled = ir.ApplyBinaryPrimOp('+', x, x)
     prod = ir.ApplyBinaryPrimOp('*', doubled, y)
     div = ir.ApplyBinaryPrimOp('/', prod, doubled)
     expected = (
         '(Let __cse_1 (I32 5)'
         ' (Let __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
         ' (ApplyBinaryPrimOp `/`'
         ' (ApplyBinaryPrimOp `*`'
         ' (Ref __cse_2)'
         ' (I32 4))'
         ' (Ref __cse_2))))')
     assert expected == CSERenderer()(div)
Esempio n. 11
0
 def test_agg_cse(self):
     """Aggregation CSE with a repeated aggregation also used under AggFilter.

     The non-aggregated repeats are bound once with AggLet (``__cse_1``,
     ``__cse_3``) and shared. The aggregation is bound with a Let at top
     level (``__cse_2``) and separately inside the AggFilter (``__cse_4``).
     """
     x = ir.GetField(ir.Ref('row'), 'idx')
     inner_sum = ir.ApplyBinaryPrimOp('+', x, x)
     agg = ir.ApplyAggOp('AggOp', [], [inner_sum])
     outer_sum = ir.ApplyBinaryPrimOp('+', agg, agg)
     # Named ``agg_filter`` rather than ``filter`` to avoid shadowing the builtin.
     agg_filter = ir.AggFilter(ir.TrueIR(), outer_sum, False)
     table_agg = ir.TableAggregate(ir.TableRange(5, 1), ir.MakeTuple([outer_sum, agg_filter]))
     expected = (
         '(TableAggregate (TableRange 5 1)'
             ' (AggLet __cse_1 False (GetField idx (Ref row))'
             ' (AggLet __cse_3 False (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))'
             ' (Let __cse_2 (ApplyAggOp AggOp () ((Ref __cse_3)))'
             ' (MakeTuple (0 1)'
                 ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))'
                 ' (AggFilter False (True)'
                     ' (Let __cse_4 (ApplyAggOp AggOp () ((Ref __cse_3)))'
                     ' (ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))))')
     assert expected == CSERenderer()(table_agg)
Esempio n. 12
0
 def _render(self, ir):
     """Render ``ir`` with a fresh ``CSERenderer`` after checking its ``jirs`` starts empty."""
     renderer = CSERenderer()
     assert len(renderer.jirs) == 0
     rendered = renderer(ir)
     return rendered
Esempio n. 13
0
 def _to_java_ir(self, ir, parse):
     """Return the parsed Java IR for ``ir``, computing it only on first use.

     On a cache miss, ``ir`` is rendered with ``CSERenderer(stop_at_jir=True)``
     and the text is passed to ``parse`` along with the renderer's ``jirs``.
     The result is stored on ``ir._jir`` and returned on later calls.
     """
     if hasattr(ir, '_jir'):
         return ir._jir
     renderer = CSERenderer(stop_at_jir=True)
     # FIXME parse should be static
     ir._jir = parse(renderer(ir), ir_map=renderer.jirs)
     return ir._jir