def _forward_abstract(self, input_signature):
  """Infers the shapes and dtypes a forward pass of this layer would yield.

  No concrete forward computation is requested here: `self.pure_fn` is
  wrapped with `fastmath.abstract_eval` and called on signatures only.

  Args:
    input_signature: A `ShapeDtype` instance (for a layer with one input) or
        a list/tuple of `ShapeDtype` instances.

  Returns:
    A tuple `(output, state)`. `output` is a single `ShapeDtype` instance
    when the layer has one output, or a tuple of `ShapeDtype` instances when
    it has more than one.

  Raises:
    LayerError: If anything in the abstract evaluation fails. The original
        exception context is suppressed (`from None`) and replaced by a
        shortened traceback.
  """
  try:
    # Note: By using rng_signature in place of an rng, we avoid computing and
    # permanently storing in global memory a large number of dropout masks.
    # TODO(jonni): Check if using an rng still carries this cost.
    placeholder_rng = fastmath.random.get_prng(0)
    rng_sig = ShapeDtype(placeholder_rng.shape, placeholder_rng.dtype)
    weights_sig = nested_map(signature, self.weights)
    state_sig = nested_map(signature, self.state)
    abstract_forward = fastmath.abstract_eval(self.pure_fn)
    return abstract_forward(input_signature, weights_sig, state_sig, rng_sig)
  except Exception:
    # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it?
    # Skipping 7 lines which are all JAX abstract'ifying wrappers.
    layer_name = self._name
    short_trace = _short_traceback(skip=7)
    raise LayerError(layer_name, '_forward_abstract', self._caller,
                     input_signature, short_trace) from None
def weights_and_state_signature(self, input_signature, unsafe=False):
  """Return a pair containing the signatures of weights and state.

  `self.init` is run through `fastmath.abstract_eval` on `input_signature`.
  Because that call may reassign `self.rng`, `self.state` and
  `self.weights`, their prior values are captured first and put back
  afterwards -- also when the abstract init raises, so a failed call does
  not leave the layer holding partially initialized attributes.

  Args:
    input_signature: Signature passed to the abstractly evaluated
        `self.init`.
    unsafe: If False (default), `self.state` and `self.weights` are reset to
        the values they had before the call. If True, they are left as
        `self.init` set them. `self.rng` is restored in both cases.

  Returns:
    The value produced by the abstractly evaluated `self.init`.
  """
  saved_rng, saved_state, saved_weights = self.rng, self.state, self.weights
  abstract_init = fastmath.abstract_eval(self.init)
  try:
    sig = abstract_init(input_signature)
  finally:
    # Runs on both success and failure so the layer is never left with
    # attributes overwritten by an init that did not complete.
    self.rng = saved_rng
    if not unsafe:
      self.state, self.weights = saved_state, saved_weights
  return sig
def weights_and_state_signature(self, input_signature):
  """Return a pair containing the signatures of weights and state.

  `self.init` is run through `fastmath.abstract_eval` on `input_signature`.
  The layer's `rng` is captured beforehand and written back once the call
  has finished (or failed), so any rng assigned during the abstract init
  does not persist on the layer.

  Args:
    input_signature: Signature passed to the abstractly evaluated
        `self.init`.

  Returns:
    The value produced by the abstractly evaluated `self.init`.
  """
  saved_rng = self.rng
  abstract_init = fastmath.abstract_eval(self.init)
  try:
    return abstract_init(input_signature)
  finally:
    # The restore must happen after the init call; doing it before the call
    # would be a no-op because the attribute has not been touched yet.
    self.rng = saved_rng