def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: `ShapeDtype` instance (if this layer takes one input)
        or list/tuple of `ShapeDtype` instances.

  Returns:
    Tuple of (output, state). The output part of the tuple is a `ShapeDtype`
    instance representing the shape and type of the output (if this layer
    has one output) or a tuple of `ShapeDtype` instances (if this layer has
    more than one output).
  """
  try:
    weight_signature = nested_map(signature, self.weights)
    # Handing abstract_eval an rng *signature* instead of a concrete rng
    # avoids computing -- and permanently storing in global memory -- a
    # large number of dropout masks.
    # TODO(jonni): Check if using an rng still carries this cost.
    prng = math.random.get_prng(0)
    rng_signature = ShapeDtype(prng.shape, prng.dtype)
    infer = math.abstract_eval(self.pure_fn)
    return infer(input_signature, weight_signature, self.state,
                 rng_signature)
  except Exception:
    # Skipping 13 lines which are all JAX abstract'ifying wrappers.
    name, trace = self._name, _short_traceback(skip=13)
    raise LayerError(name, '_forward_abstract', self._caller,
                     input_signature, trace) from None
def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: A ShapeDtype instance (if this layer takes one input)
        or a list/tuple of ShapeDtype instances; signatures of inputs.

  Returns:
    A tuple of (output, state). The output part of the tuple is a ShapeDtype
    instance representing the shape and type of the output (if this layer
    has one output) or a tuple of ShapeDtype instances (if this layer has
    more than one output).
  """
  try:
    # Beware: using an actual RNG (as opposed to this ShapeDtype stub) would
    # cause a large number of dropout masks to be computed and permanently
    # stored in global memory.
    rng = ShapeDtype((2, ), onp.uint32)

    def call_on_input(x, weights, state, rng):
      # Thin adapter so abstract_eval sees a plain positional-arg function.
      return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    weight_signature = nested_map(signature, self.weights)
    s = math.abstract_eval(call_on_input)(
        input_signature, weight_signature, self.state, rng)
    return s
  except Exception:
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    # `from None` suppresses implicit exception chaining: the shortened
    # traceback already carries the relevant frames, and chaining would
    # re-attach the verbose abstract-eval stack we just trimmed.
    raise LayerError(name, '_forward_abstract', self._caller,
                     input_signature, trace) from None
def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: ShapeDtype instance (if this layer takes one input) or
        list/tuple of ShapeDtype instances.

  Returns:
    Tuple of (output, state). The output part of the tuple is a ShapeDtype
    instance representing the shape and type of the output (if this layer
    has one output) or a tuple of ShapeDtype instances (if this layer has
    more than one output).
  """
  try:
    # A ShapeDtype stub stands in for a real rng; this avoids computing and
    # permanently storing in global memory a large number of dropout masks.
    # TODO(jonni): Check if using an rng still carries this cost.
    rng_signature = ShapeDtype((2, ), np.uint32)
    weight_signature = nested_map(signature, self.weights)
    shape_fn = math.abstract_eval(self.forward_with_state)
    return shape_fn(input_signature, weight_signature, self.state,
                    rng_signature)
  except Exception as e:
    name, trace = self._name, _short_traceback(skip=3)
    raise LayerError(name, '_forward_abstract', self._caller,
                     input_signature, trace) from e
def weights_and_state_signature(self, input_signature):
  """Return a pair containing the signatures of weights and state."""
  # Evaluate init abstractly to obtain the (weights, state) signatures.
  return math.abstract_eval(self.init)(input_signature)