def __call__(self, *args, **kwargs) -> Union[JaxArray, List[JaxArray]]:
    """Execute the sequence of operations contained in ``*args`` and ``**kwargs`` and return result.

    Each layer receives the previous layer's output as positional arguments.
    A non-tuple output is wrapped into a 1-tuple before being passed on. The
    last layer's output is returned as-is. Keyword arguments are forwarded to
    every layer after filtering through ``util.local_kwargs`` (presumably
    keeping only the ones that layer accepts -- confirm against ``util``).

    If the sequence is empty, it acts as the identity: a single positional
    argument is returned unwrapped; otherwise the ``args`` tuple is returned
    (an empty tuple when called with no positional arguments).

    Raises:
        Any exception raised by a layer. When the exception type can be rebuilt
        from a single message string, a new exception of the same type is raised
        whose message is prefixed with ``Sequential layer[<index>] <layer>``,
        chained to the original via ``from``. Otherwise the original exception
        is re-raised unchanged.
    """
    if not self:
        # Identity; avoids an IndexError on ``args[0]`` when called with no args.
        return args[0] if len(args) == 1 else args

    def apply(layer, f, inputs):
        # Run one layer, tagging any failure with the layer's position.
        try:
            return f(*inputs, **util.local_kwargs(kwargs, f))
        except Exception as e:
            try:
                wrapped = type(e)(f'Sequential layer[{layer}] {f} {e}')
            except Exception:
                # Exception type can't be built from one message; keep the original.
                wrapped = None
            if wrapped is None:
                raise  # re-raises ``e`` with its original traceback
            raise wrapped from e

    body = self[:-1]
    for layer, f in enumerate(body):
        args = apply(layer, f, args)
        if not isinstance(args, tuple):
            args = (args,)
    return apply(len(body), self[-1], args)
def run_layer(layer: int, f: Callable, args: List, kwargs: Dict):
    """Call ``f`` on ``args`` and re-raise any failure tagged with the layer index.

    Args:
        layer: position of ``f`` in the sequence; used only in the error message.
        f: the callable to run.
        args: positional inputs, unpacked into ``f`` (annotated ``List``, but any
            iterable, e.g. a tuple, works).
        kwargs: keyword arguments, filtered through ``util.local_kwargs`` before
            being passed to ``f`` (presumably down to the ones ``f`` accepts --
            confirm against ``util``).

    Returns:
        Whatever ``f`` returns.

    Raises:
        Any exception raised while calling ``f``. When the exception type can be
        rebuilt from a single message string, a new exception of the same type is
        raised whose message is prefixed with ``Sequential layer[<layer>] <f>``,
        chained to the original via ``from``. Otherwise the original exception is
        re-raised unchanged, rather than being masked by a ``TypeError`` from the
        failed re-construction.
    """
    try:
        return f(*args, **util.local_kwargs(kwargs, f))
    except Exception as e:
        try:
            wrapped = type(e)(f'Sequential layer[{layer}] {f} {e}')
        except Exception:
            # e.g. UnicodeDecodeError or custom exceptions with required __init__ args.
            wrapped = None
        if wrapped is None:
            raise  # re-raises ``e`` with its original traceback
        raise wrapped from e