Esempio n. 1
0
    def pseudo_call(self, pseudo_inputs, params, state):
        """Computes shapes and types this layer would produce for the given inputs.

        Args:
          pseudo_inputs: A ShapeType instance (input data minus the actual values)
              or a tuple of ShapeType instances, following the same conventions as
              Layer.call's input arg.
          params: Parameters for this layer.
          state: start state.

        Returns:
          A ShapeType instance representing the shape and type of the output (if
          this layer has one output) or a tuple of ShapeType instances (if this
          layer has more than one output).

        Raises:
          LayerError: if anything inside shape inference raises; the original
              exception is attached as the explicit cause (``__cause__``).
        """
        try:
            # Beware: using an actual RNG (as opposed to this ShapeType stub) would
            # cause a large number of dropout masks to be computed and permanently
            # stored in global memory.
            rng = ShapeType(shape=(2, ), dtype=onp.uint32)

            def call_on_input(x, params, state, rng):
                return self.call(x, params=params, state=state, rng=rng)

            # Replace each parameter with a shape/dtype-only stub so no real
            # parameter values are needed for the evaluation.
            params_shapes = nested_map(
                params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
            return backend.eval_on_shapes(call_on_input)(
                pseudo_inputs, params_shapes, state, rng)
        except Exception as e:
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            # Chain explicitly so the underlying failure is reported as the
            # direct cause of the LayerError, not as incidental context.
            raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs,
                             trace) from e
Esempio n. 2
0
def sizes(x):
  """Returns a structure of sizes mirroring a structure of nested arrays.

  Every leaf is replaced by its `size` attribute; any leaf whose `size` cannot
  be read is reported as 0 instead of raising.
  """
  def _leaf_size(leaf):
    try:
      result = leaf.size
    except Exception:  # pylint: disable=broad-except
      result = 0
    return result
  return nested_map(x, _leaf_size)
Esempio n. 3
0
def shapes(x):
  """Returns a structure of shapes mirroring a structure of nested arrays.

  Each leaf becomes a tuple of Python ints built from its `shape` attribute.
  A leaf whose shape cannot be read or converted to ints becomes an empty
  list (note: a list, not a tuple).
  """
  def _leaf_shape(leaf):
    try:
      dims = leaf.shape
      return tuple(map(int, dims))
    except Exception:  # pylint: disable=broad-except
      return []
  return nested_map(x, _leaf_shape)