Пример #1
0
    def _forward_abstract(self, input_signature):
        """Infers the shapes and dtypes a forward pass would produce.

        `self.pure_fn` is run through `math.abstract_eval` on signatures only
        (inputs, weights, rng); the layer's current `self.state` is passed as is.

        Args:
          input_signature: A `ShapeDtype` instance (single-input layer) or a
              list/tuple of `ShapeDtype` instances (multi-input layer).

        Returns:
          An `(output, state)` tuple. `output` is a `ShapeDtype` instance for a
          single-output layer, or a tuple of `ShapeDtype` instances when the
          layer has more than one output.

        Raises:
          LayerError: if abstract evaluation fails for any reason. The original
              exception context is suppressed (`from None`); a shortened
              traceback is attached to the error instead.
        """
        try:
            # Only a *signature* of an rng is handed to abstract evaluation, not
            # a real rng, so that a large number of dropout masks are not
            # computed and permanently stored in global memory.
            # TODO(jonni): Check if using an rng still carries this cost.
            prng_key = math.random.get_prng(0)
            rng_sig = ShapeDtype(prng_key.shape, prng_key.dtype)
            weights_sig = nested_map(signature, self.weights)
            abstract_pure_fn = math.abstract_eval(self.pure_fn)
            return abstract_pure_fn(
                input_signature, weights_sig, self.state, rng_sig)
        except Exception:
            layer_name = self._name
            # Skip 13 traceback entries: they are all JAX abstract-evaluation
            # wrappers and only add noise to the report.
            short_trace = _short_traceback(skip=13)
            raise LayerError(layer_name, '_forward_abstract', self._caller,
                             input_signature, short_trace) from None
Пример #2
0
    def _forward_abstract(self, input_signature):
        """Infers the shapes and dtypes a forward pass would produce.

        `self.forward_with_state` is evaluated abstractly (via
        `math.abstract_eval`) with a weights signature, the layer's current
        `self.state`, and a `ShapeDtype` stand-in for the rng.

        Args:
          input_signature: A ShapeDtype instance (single-input layer) or a
              list/tuple of ShapeDtype instances; the signatures of the inputs.

        Returns:
          A tuple of (output, state). `output` is a ShapeDtype instance for a
          single-output layer, or a tuple of ShapeDtype instances when the layer
          has more than one output.

        Raises:
          LayerError: if abstract evaluation fails for any reason; the original
              exception remains attached as the implicit context.
        """
        try:
            # Beware: an actual RNG (instead of this ShapeDtype stub) would make
            # the abstract pass compute a large number of dropout masks and keep
            # them permanently in global memory.
            rng_stub = ShapeDtype((2, ), onp.uint32)

            def _run_forward(inputs, weights, state, rng):
                # Adapter so abstract evaluation can pass everything positionally
                # while forward_with_state receives keyword arguments.
                return self.forward_with_state(
                    inputs, weights=weights, state=state, rng=rng)

            weights_sig = nested_map(signature, self.weights)
            abstract_forward = math.abstract_eval(_run_forward)
            return abstract_forward(
                input_signature, weights_sig, self.state, rng_stub)
        except Exception:
            layer_name = self.__class__.__name__
            short_trace = _short_traceback(skip=3)
            raise LayerError(layer_name, '_forward_abstract', self._caller,
                             input_signature, short_trace)
Пример #3
0
    def _forward_abstract(self, input_signature):
        """Infers the shapes and dtypes a forward pass would produce.

        `self.forward_with_state` is evaluated abstractly (via
        `math.abstract_eval`) on an input signature, a weights signature, the
        layer's current `self.state`, and an rng signature.

        Args:
          input_signature: A ShapeDtype instance (single-input layer) or a
              list/tuple of ShapeDtype instances (multi-input layer).

        Returns:
          Tuple of (output, state). `output` is a ShapeDtype instance for a
          single-output layer, or a tuple of ShapeDtype instances when the layer
          has more than one output.

        Raises:
          LayerError: if abstract evaluation fails for any reason; the original
              exception is chained as the cause (`from`).
        """
        try:
            # Note: giving abstract evaluation an rng *signature* instead of a
            # real rng avoids computing, and permanently storing in global
            # memory, a large number of dropout masks.
            # TODO(jonni): Check if using an rng still carries this cost.
            rng_sig = ShapeDtype((2, ), np.uint32)
            weights_sig = nested_map(signature, self.weights)
            abstract_forward = math.abstract_eval(self.forward_with_state)
            return abstract_forward(
                input_signature, weights_sig, self.state, rng_sig)
        except Exception as err:
            layer_name = self._name
            short_trace = _short_traceback(skip=3)
            raise LayerError(layer_name, '_forward_abstract', self._caller,
                             input_signature, short_trace) from err
Пример #4
0
 def weights_and_state_signature(self, input_signature):
     """Returns a pair holding the signatures of the weights and the state.

     Obtained by running `self.init` through `math.abstract_eval` on
     `input_signature`, rather than by calling `self.init` directly.
     """
     return math.abstract_eval(self.init)(input_signature)