Beispiel #1
0
    def _forward_abstract(self, input_signature):
        """Computes shapes and dtypes this layer would produce in a forward pass.

        The forward pass is evaluated abstractly: `self.pure_fn` is wrapped with
        `fastmath.abstract_eval` and called with signatures of the input,
        weights, state and rng rather than with concrete values.

        Args:
          input_signature: `ShapeDtype` instance (if this layer takes one input)
              or list/tuple of `ShapeDtype` instances.

        Returns:
          Tuple of (output, state).

          The output part of the tuple is a `ShapeDtype` instance representing the
          shape and type of the output (if this layer has one output) or a tuple
          of `ShapeDtype` instances (if this layer has more than one output).
          The state part is presumably the abstract signature of the layer's
          state after the pass -- NOTE(review): confirm against `pure_fn`.

        Raises:
          LayerError: if anything inside the abstract evaluation raises. The
              error carries the layer name, this method's name, `self._caller`,
              `input_signature` and a shortened traceback; exception chaining
              is suppressed (`from None`).
        """
        try:
            # Note: By using rng_signature in place of an rng, we avoid computing and
            # permanently storing in global memory a large number of dropout masks.
            # TODO(jonni): Check if using an rng still carries this cost.
            # The concrete key is only used to read its shape and dtype.
            dummy_rng = fastmath.random.get_prng(0)
            rng_signature = ShapeDtype(dummy_rng.shape, dummy_rng.dtype)
            # Apply `signature` across the (possibly nested) weights and state.
            weights_signature = nested_map(signature, self.weights)
            state_signature = nested_map(signature, self.state)
            forward_infer_shapes = fastmath.abstract_eval(self.pure_fn)
            # Keep this call directly in this frame: the `skip=7` below assumes
            # a fixed traceback layout from here into the JAX wrappers.
            return forward_infer_shapes(input_signature, weights_signature,
                                        state_signature, rng_signature)
        except Exception:
            # TODO(lukaszkaiser): the choice of 7 is a heuristic, can we automate it?
            # Skipping 7 lines which are all JAX abstract'ifying wrappers.
            name, trace = self._name, _short_traceback(skip=7)
            raise LayerError(name, '_forward_abstract', self._caller,
                             input_signature, trace) from None
Beispiel #2
0
 def weights_and_state_signature(self, input_signature, unsafe=False):
     """Return a pair containing the signatures of weights and state.

     Runs `self.init` abstractly (wrapped by `fastmath.abstract_eval`) on
     `input_signature`. Afterwards, the layer attributes that the init call
     may overwrite are restored.

     Args:
       input_signature: signature(s) of the layer input(s); passed unchanged
           to the abstractly evaluated `self.init`.
       unsafe: if True and the abstract init succeeds, keep the weights and
           state that `self.init` left on the layer instead of restoring the
           previous ones. `self.rng` is restored in every case.

     Returns:
       Whatever the abstract `self.init` call returns. Per this method's name
       that is presumably (weights_signature, state_signature); confirm
       against `init`.

     Raises:
       Any exception raised by the abstract init. In that case `rng`,
       `state` and `weights` are all restored before it propagates,
       regardless of `unsafe`.
     """
     rng, state, weights = self.rng, self.state, self.weights
     abstract_init = fastmath.abstract_eval(self.init)
     succeeded = False
     try:
         sig = abstract_init(input_signature)
         succeeded = True
     finally:
         # Restore in `finally` so a failing init cannot leave the layer with
         # clobbered (possibly abstract) attributes.
         self.rng = rng
         if not (unsafe and succeeded):
             self.state, self.weights = state, weights
     return sig
Beispiel #3
0
 def weights_and_state_signature(self, input_signature):
     """Return a pair containing the signatures of weights and state.

     Runs `self.init` abstractly (wrapped by `fastmath.abstract_eval`) on
     `input_signature`. Afterwards `self.rng` is restored to the value it
     had before the call.

     Args:
       input_signature: signature(s) of the layer input(s); passed unchanged
           to the abstractly evaluated `self.init`.

     Returns:
       Whatever the abstract `self.init` call returns. Per this method's name
       that is presumably (weights_signature, state_signature); confirm
       against `init`.
     """
     rng = self.rng
     # `abstract_eval` only builds a wrapper; `init` runs when it is called.
     abstract_init = fastmath.abstract_eval(self.init)
     try:
         return abstract_init(input_signature)
     finally:
         # Restore only after init has actually run. Restoring before the call
         # was a no-op. Doing it in `finally` also covers a failing init.
         self.rng = rng