def batched(*args, **kwargs):
  """Run the wrapped function through the auto-batching machinery.

  Accepts configuration via keyword arguments (popped before dispatch):

  Args:
    *args: Positional inputs forwarded to the auto-batched program.
    **kwargs: Configuration options:
      - `max_stack_depth`: int, VM stack depth bound (default 15).
      - `backend`: execution backend (default `TF_BACKEND`).
      - `dry_run`: bool, if True run once at batch size one without
        batching (default False).
      - `stackless`: bool, if True use the stackless executor `st`
        instead of the stack-based VM (default False).
      - `block_code_cache`: dict used to cache compiled block code
        (default: fresh dict per call).

  Returns:
    The result of executing the auto-batched program on `args`.

  Raises:
    TypeError: If any unrecognized keyword arguments are supplied.
  """
  # Accepting kwargs here because Python 2.7 gets confused otherwise.
  # See, e.g.,
  # https://stackoverflow.com/questions/14003939/python-function-args-and-kwargs-with-other-specified-keyword-arguments
  # The whole strategy of adding arguments to the decorated function has
  # the unfortunate consequence that pylint uses the argument signature of
  # the decoratee to issue warnings, with the result that
  # disable=unexpected-keyword-arg needs to be added to use sites of
  # decorated functions, for example in `frontend_test.py`.
  # TODO(axch) Figure out how to avoid pylint warnings in callers.
  # Options:
  # - Add `backend` and `max_stack_depth` arguments to the `Context`
  #   constructor, and use those as defaults here.
  #   - Likely insufficient, because dynamic selection of the stack size
  #     seems important.
  # - Carry the same information in a context manager, used like
  #   with ctx.config(backend=TF_BACKEND, max_stack_depth=15):
  #     some_decorated_function(normal, arguments)
  # - Just invoke batched functions from the `Context` object:
  #   ctx.run(some_decorated_function, input_1, max_stack_depth=15)
  # - Maybe even make magic methods for the above?
  #   ctx.some_decorated_function(input_1, max_stack_depth=15)
  max_stack_depth = kwargs.pop('max_stack_depth', 15)
  backend = kwargs.pop('backend', TF_BACKEND)
  dry_run = kwargs.pop('dry_run', False)
  stackless = kwargs.pop('stackless', False)
  block_code_cache = kwargs.pop('block_code_cache', {})
  # Validate kwargs BEFORE the dry-run early returns; previously a
  # misspelled keyword was silently ignored in dry-run mode.
  if kwargs:
    msg = 'Auto-batched function given unexpected keyword arguments {}.'
    raise TypeError(msg.format(list(kwargs.keys())))
  if self._in_dry_run:
    # Already inside a dry run (e.g. a recursive call); just execute
    # unbatched at batch size one.
    return _run_at_batch_size_one(backend, function, args)
  if dry_run:
    with self._dry_run():
      return _run_at_batch_size_one(backend, function, args)
  # Infer the type signature of the program on these concrete inputs.
  sig = ab_type_inference.signature(self.program(main=name), args, backend)
  if stackless:
    # Stackless execution path: compile and run via the `st` executor.
    compiled = self.program_compiled(main=name, sig=sig, backend=backend)
    return st.execute(compiled, backend, block_code_cache, *args)
  else:
    # Stack-based VM path: lower the program and run it with a bounded
    # stack depth.
    lowered = self.program_lowered(main=name, sig=sig, backend=backend)
    return vm.execute(
        lowered, args, max_stack_depth, backend,
        block_code_cache=block_code_cache)
def _execute(prog, inputs, stack_depth, backend):
  """Run `prog` on one input tuple via the auto-batching VM.

  Args:
    prog: The lowered program to execute.
    inputs: A single input, wrapped into a one-element list for the VM.
    stack_depth: Bound on the VM stack depth.
    backend: Execution backend to run against.

  Returns:
    The result of `vm.execute` on the wrapped input.
  """
  wrapped_inputs = [inputs]
  return vm.execute(
      prog, wrapped_inputs, max_stack_depth=stack_depth, backend=backend)