Beispiel #1
0
def even_odd_program():
  """Builds a program with mutually recursive `even` and `odd` functions.

  `even(n)` yields True when `n <= 0` and otherwise calls `odd(n - 1)`.
  `odd(n)` yields False when `n <= 0` and otherwise calls `even(n - 1)`.
  Both functions are typed as scalar booleans.

  Returns:
    The program produced by `ProgramBuilder.program`, with `even` as its
    entry point.
  """
  builder = dsl.ProgramBuilder()

  def bool_scalar(_):
    # Both functions produce a scalar bool, independent of argument types.
    return instructions.TensorType(np.bool_, ())

  def emit_parity_body(base_value, other_fn):
    # Shared body for `even` and `odd`. If n <= 0, answer `base_value`;
    # otherwise answer `other_fn(n - 1)`. The lambda parameter names must
    # match the DSL variable names (`n`), so they are not renamed.
    builder.param('n')
    builder.var.cond = builder.primop(lambda n: n <= 0)
    with builder.if_(builder.var.cond, then_name='base-case'):
      builder.var.ans = builder.const(base_value)
    with builder.else_(else_name='recur', continue_name='finish'):
      builder.var.nm1 = builder.primop(lambda n: n - 1)
      builder.var.ans = builder.call(other_fn, [builder.var.nm1])
    builder.return_(builder.var.ans)

  # `odd` is declared first so that `even` can call it before it is defined.
  odd_fn = builder.declare_function('odd', type_inference=bool_scalar)
  with builder.function('even', type_inference=bool_scalar) as even_fn:
    emit_parity_body(True, odd_fn)
  with builder.define_function(odd_fn):
    emit_parity_body(False, even_fn)

  return builder.program(main=even_fn)
Beispiel #2
0
    def module(self):
        """Returns the `instructions.Module` for this `Context`, building it once.

        The first call lowers every function in `self._tagged_functions` into a
        single program and stores it in `self._module`. Later calls return that
        stored object without rebuilding.

        Returns:
          module: An `instructions.Module` representing the batched computation
            defined by all the functions decorated with `batch` in this
            `Context` so far.
        """
        if self._module is not None:
            return self._module

        program_builder = dsl.ProgramBuilder()
        # Declare every function before emitting any body. That way each
        # generated builder receives handles to all functions, including ones
        # that appear later in the list.
        declarations = [
            program_builder.declare_function(fn.__name__, inference)
            for fn, inference in self._tagged_functions
        ]

        for fn, _ in self._tagged_functions:
            fn_name = fn.__name__
            tree, analysis = _parse_and_analyze(fn, self.function_names())
            known_functions = self.function_names()
            captured = [key for key, _ in _environment(fn, [fn_name])]
            tree = _AutoBatchingTransformer(
                known_functions, captured, analysis).visit(tree)
            generated, _, _ = loader.load_ast(tree)
            # Inject the captured bindings into the freshly loaded module's
            # namespace before looking up the generated builder function.
            namespace = generated.__dict__
            for key, value in _environment(fn, [fn_name]):
                namespace[key] = value
            emit_body = getattr(generated, fn_name)
            emit_body(program_builder, declarations)

        self._module = program_builder.module()
        return self._module
Beispiel #3
0
def synthetic_pattern_program():
  """Builds a program that exercises nested tuple-pattern assignment.

  The single function `synthetic` takes a `batch_size_index` parameter. It
  destructures two constant-producing primops into a mix of anonymous locals
  and named variables. It returns the nested structure
  `((x1, x3), four, (x5, six))`, typed as int64 scalars.

  Returns:
    The program produced by `ProgramBuilder.program`, with `synthetic` as its
    entry point.
  """
  builder = dsl.ProgramBuilder()

  def pattern_type(_):
    # Output signature: ((int64, int64), int64, (int64, int64)) scalars.
    scalar = instructions.TensorType(np.int64, ())
    return ((scalar, scalar), scalar, (scalar, scalar))

  with builder.function('synthetic', type_inference=pattern_type) as main_fn:
    builder.param('batch_size_index')
    x1, x3, x5 = builder.locals_(3)
    # (1, (2, 3)) is destructured into (x1, (x5, x3)). The swapped order of
    # x5/x3 in the pattern is deliberate.
    builder((x1, (x5, x3))).pattern = builder.primop(lambda: (1, (2, 3)))
    # ((4, 5), 6) is destructured into ((four, x5), six). This re-targets x5
    # and introduces the named variables `four` and `six`.
    builder(((builder.var.four, x5), builder.var.six)).pattern = (
        builder.primop(lambda: ((4, 5), 6)))
    builder.return_(((x1, x3), builder.var.four, (x5, builder.var.six)))

  return builder.program(main=main_fn)
Beispiel #4
0
def fibonacci_program():
  """Builds a program computing Fibonacci numbers by naive double recursion.

  `fibonacci(n)` yields `fibonacci(n - 1) + fibonacci(n - 2)` when `n > 1`,
  and the constant 1 otherwise. Its result type is taken to be the type of
  its first argument.

  Returns:
    The program produced by `ProgramBuilder.program`, with `fibonacci` as its
    entry point.
  """
  builder = dsl.ProgramBuilder()

  def same_as_first_arg(arg_types):
    # The result has the same type as the input `n`.
    return arg_types[0]

  with builder.function(
      'fibonacci', type_inference=same_as_first_arg) as fib_fn:
    builder.param('n')
    # Lambda parameter names must match DSL variable names, so they are kept.
    builder.var.cond = builder.primop(lambda n: n > 1)
    with builder.if_(builder.var.cond, then_name='recur'):
      builder.var.nm1 = builder.primop(lambda n: n - 1)
      builder.var.fibm1 = builder.call(fib_fn, [builder.var.nm1])
      builder.var.nm2 = builder.primop(lambda n: n - 2)
      builder.var.fibm2 = builder.call(fib_fn, [builder.var.nm2])
      builder.var.ans = builder.primop(lambda fibm1, fibm2: fibm1 + fibm2)
    with builder.else_(else_name='base-case', continue_name='finish'):
      builder.var.ans = builder.const(1)
    builder.return_(builder.var.ans)

  return builder.program(main=fib_fn)