def testPopPushFusionPrettyPrint(self): # Testing two things: That pop-push fusion does the expected thing, and that # push-skipping PrimOps print with exclamation marks as expected. This test # is likely to be brittle, and may want to be rearranged later. prog = test_programs.fibonacci_program() fused = stack.fuse_pop_push(prog) self.verify_program_pretty_print(fib_fused_pretty, fused)
def program_lowered(self, main, sig=None, backend=None):
  """Constructs a lowered `instructions.Program` for this `Context`.

  This constructs the program with `self.module().program(main)`, and then
  performs type inference, optimization, and lowering, to emit a result that
  can be executed (or staged) by the auto-batching VM.

  The point of having this as a method in its own right is that it caches
  the compilation on the types of the arguments. The cache holds a single
  entry keyed on `(module, main, sig, backend)`; a call with a different key
  recompiles and replaces that entry, and also clears the module and compile
  caches (see b/119122199).

  If either `sig` or `backend` are omitted or `None`, type inference is
  skipped. The result is not executable, but it can be enlightening to
  inspect.

  Args:
    main: Python string name of the function that should be the entry
      point.
    sig: A `list` of (patterns of) `instructions.TensorType` aligned with
      the formal parameters to `main`.
    backend: Backend implementation.

  Returns:
    prog: An `instructions.Program` representing the batched computation
      defined by all the functions decorated with `batch` in this `Context` so
      far. Suitable for execution or staging on real data by the
      auto-batching VM.
  """
  module = self.module()
  if self._lowering_cache is not None:
    key, result = self._lowering_cache
    if key == (module, main, sig, backend):
      return result
    # Stale entry: clear the module and compile caches as well, because of
    # b/119122199, then re-fetch the module so the program below is built
    # from it rather than from the one just discarded.
    self._module = None
    self._compile_cache = None
    module = self.module()

  # Build the program only after the cache check: a hit needs no program at
  # all, and a miss must use the (possibly refreshed) module.
  prog = module.program(main)
  if sig is not None and backend is not None:
    typed = ab_type_inference.infer_types_from_signature(prog, sig, backend)
  else:
    typed = prog
  alloc = allocation_strategy.optimize(typed)
  lowered = lowering.lower_function_calls(alloc)
  result = stack.fuse_pop_push(lowered)
  self._lowering_cache = ((module, main, sig, backend), result)
  return result