예제 #1
0
 def __call__(self, *args, **kwargs) -> Union[JaxArray, List[JaxArray]]:
     """Feed ``*args`` through each layer in turn and return what the last layer produces.

     Keyword arguments reach each layer as ``**util.local_kwargs(kwargs, layer)``
     (presumably only the ones that layer accepts — see ``util.local_kwargs``).
     The output of every non-final layer becomes the positional arguments of the next
     one; a non-tuple output is first wrapped in a 1-tuple so it can be unpacked.

     When the sequence holds no layers, the inputs are handed straight back: the ``args``
     tuple itself if several positional arguments were given, otherwise ``args[0]``
     (which raises ``IndexError`` if no positional argument was given at all).
     """
     if not self:
         if len(args) > 1:
             return args
         return args[0]
     values = args
     for layer in self[:-1]:
         output = layer(*values, **util.local_kwargs(kwargs, layer))
         # Normalize to a tuple so the next layer can always be called with ``*values``.
         values = output if isinstance(output, tuple) else (output,)
     final_layer = self[-1]
     return final_layer(*values, **util.local_kwargs(kwargs, final_layer))
예제 #2
0
파일: layers.py 프로젝트: google/objax
 def run_layer(layer: int, f: Callable, args: List, kwargs: Dict):
     """Call ``f`` with ``args``/``kwargs``, tagging any failure with the layer position.

     Args:
         layer: position of ``f`` in the sequence; only used in the error message.
         f: the layer callable to run.
         args: positional arguments unpacked into ``f``.
         kwargs: keyword arguments, passed to ``f`` as ``**util.local_kwargs(kwargs, f)``.

     Returns:
         Whatever ``f`` returns.

     Raises:
         An exception of the same type as the one raised while running ``f``, whose message
         is prefixed with ``Sequential layer[<layer>] <f>`` and whose ``__cause__`` is the
         original exception. If that exception type cannot be rebuilt from a single message
         argument, the original exception is re-raised unchanged instead.
     """
     try:
         return f(*args, **util.local_kwargs(kwargs, f))
     except Exception as e:
         try:
             wrapped = type(e)(f'Sequential layer[{layer}] {f} {e}')
         except Exception:
             # Some exception types cannot be built from one string (e.g. UnicodeDecodeError
             # takes five arguments). Raising ``type(e)(msg)`` would then replace the real
             # error with an unrelated TypeError, so fall back to re-raising the original.
             wrapped = None
         if wrapped is None:
             # Bare ``raise`` re-raises ``e`` (the exception still being handled) untouched.
             raise
         raise wrapped from e