def even_odd_program():
  """Builds a program with mutually recursive `even` and `odd` functions.

  `even(n)` yields True when `n <= 0` and otherwise returns `odd(n - 1)`.
  `odd(n)` yields False when `n <= 0` and otherwise returns `even(n - 1)`.
  Both functions are typed as returning a scalar boolean.

  Returns:
    prog: The program produced by `ProgramBuilder.program`, with `even` as
      its entry point.
  """
  builder = dsl.ProgramBuilder()

  def bool_scalar(_):
    # Both predicates return a scalar bool regardless of the argument types.
    return instructions.TensorType(np.bool_, ())

  def emit_parity_body(base_value, other):
    # Emits the shared body into the function currently being built:
    # if n <= 0 the answer is `base_value`, otherwise it is other(n - 1).
    # The lambda parameter name `n` selects the VM variable that is read.
    builder.param('n')
    builder.var.cond = builder.primop(lambda n: n <= 0)
    with builder.if_(builder.var.cond, then_name='base-case'):
      builder.var.ans = builder.const(base_value)
    with builder.else_(else_name='recur', continue_name='finish'):
      builder.var.nm1 = builder.primop(lambda n: n - 1)
      builder.var.ans = builder.call(other, [builder.var.nm1])
    builder.return_(builder.var.ans)

  # `odd` is declared before `even` is built so that `even` can call it;
  # its body is supplied afterwards via `define_function`.
  odd = builder.declare_function('odd', type_inference=bool_scalar)
  with builder.function('even', type_inference=bool_scalar) as even:
    emit_parity_body(True, odd)
  with builder.define_function(odd):
    emit_parity_body(False, even)
  return builder.program(main=even)
def module(self):
  """Constructs an `instructions.Module` for this `Context`.

  The module is built at most once. While `self._module` is not None,
  subsequent calls return that cached value without rebuilding.

  Returns:
    module: An `instructions.Module` representing the batched computation
      defined by all the functions decorated with `batch` in this `Context`
      so far.
  """
  if self._module is not None:
    return self._module
  ab = dsl.ProgramBuilder()
  # Declare every tagged function before building any body, so each
  # generated builder receives the full list of declared functions
  # regardless of definition order.
  function_objects = [
      ab.declare_function(function.__name__, type_inference)
      for function, type_inference in self._tagged_functions
  ]
  # Loop-invariant: the set of tagged names does not change while building.
  function_names = self.function_names()
  for function, _ in self._tagged_functions:
    name = function.__name__
    node, ctx = _parse_and_analyze(function, function_names)
    # Walk the function's visible environment once; it is needed both to
    # tell the transformer which names are in scope and to populate the
    # namespace of the generated module.
    environment = list(_environment(function, [name]))
    node = _AutoBatchingTransformer(
        function_names,
        [scoped_name for scoped_name, _ in environment],
        ctx).visit(node)
    builder_module, _, _ = loader.load_ast(node)
    builder_module.__dict__.update(environment)
    builder = getattr(builder_module, name)
    builder(ab, function_objects)
  self._module = ab.module()
  return self._module
def synthetic_pattern_program():
  """Builds a program exercising nested-tuple pattern assignment.

  The single function `synthetic` takes a `batch_size_index` parameter and
  binds the outputs of two constant primops to a mix of pre-declared locals
  (`one`, `three`, `five`) and named variables (`four`, `six`). The second
  pattern rebinds `five`. It returns a nested tuple typed as
  ((int64, int64), int64, (int64, int64)).

  Returns:
    prog: The program produced by `ProgramBuilder.program`, with `synthetic`
      as its entry point.
  """
  builder = dsl.ProgramBuilder()

  def result_type(_):
    scalar = instructions.TensorType(np.int64, ())
    return ((scalar, scalar), scalar, (scalar, scalar))

  with builder.function('synthetic', type_inference=result_type) as synthetic:
    builder.param('batch_size_index')
    one, three, five = builder.locals_(3)
    # Destructure a nested primop result into pre-declared locals.
    builder((one, (five, three))).pattern = builder.primop(
        lambda: (1, (2, 3)))
    # Mix named variables with a previously bound local in one pattern.
    builder(((builder.var.four, five), builder.var.six)).pattern = (
        builder.primop(lambda: ((4, 5), 6)))
    builder.return_(((one, three), builder.var.four, (five, builder.var.six)))
  return builder.program(main=synthetic)
def fibonacci_program():
  """Builds a recursive Fibonacci program.

  For n > 1 the function returns fibonacci(n - 1) + fibonacci(n - 2).
  Otherwise it returns the constant 1, so fibonacci(0) == fibonacci(1) == 1.

  Returns:
    prog: The program produced by `ProgramBuilder.program`, with `fibonacci`
      as its entry point.
  """
  builder = dsl.ProgramBuilder()

  def same_as_first_arg(arg_types):
    # The result is typed like the single argument `n`.
    return arg_types[0]

  with builder.function(
      'fibonacci', type_inference=same_as_first_arg) as fib:
    builder.param('n')
    # Primop lambda parameter names select which VM variables are read.
    builder.var.cond = builder.primop(lambda n: n > 1)
    with builder.if_(builder.var.cond, then_name='recur'):
      builder.var.nm1 = builder.primop(lambda n: n - 1)
      builder.var.fibm1 = builder.call(fib, [builder.var.nm1])
      builder.var.nm2 = builder.primop(lambda n: n - 2)
      builder.var.fibm2 = builder.call(fib, [builder.var.nm2])
      builder.var.ans = builder.primop(lambda fibm1, fibm2: fibm1 + fibm2)
    with builder.else_(else_name='base-case', continue_name='finish'):
      builder.var.ans = builder.const(1)
    builder.return_(builder.var.ans)
  return builder.program(main=fib)