def pseudo_call(self, pseudo_inputs, params, state):
  """Computes shapes and types this layer would produce for the given inputs.

  Runs ``self.call`` through ``backend.eval_on_shapes`` using shape/dtype
  stand-ins instead of real values, so no actual computation on data is
  needed to learn the output signature.

  Args:
    pseudo_inputs: A ShapeType instance (input data minus the actual values)
        or a tuple of ShapeType instances, following the same conventions as
        Layer.call's input arg.
    params: Parameters for this layer. Each leaf is replaced by a ShapeType
        built from its ``shape`` and ``dtype`` before evaluation.
    state: Start state. Passed through to ``self.call`` as given (it is not
        converted to ShapeType here).

  Returns:
    A ShapeType instance representing the shape and type of the output (if
    this layer has one output) or a tuple of ShapeType instances (if this
    layer has more than one output).

  Raises:
    LayerError: If anything in the evaluation raises. The error carries this
        layer's class name, the string 'pseudo_call', ``self._caller``, the
        given ``pseudo_inputs``, and a shortened traceback.
  """
  try:
    # Beware: using an actual RNG (as opposed to this ShapeType stub) would
    # cause a large number of dropout masks to be computed and permanently
    # stored in global memory.
    # NOTE(review): shape (2,) uint32 presumably mirrors the layout of a real
    # PRNG key used by the backend -- confirm against the backend's RNG type.
    rng = ShapeType(shape=(2, ), dtype=onp.uint32)

    # Adapter with a purely positional signature. It forwards everything to
    # self.call by keyword, so eval_on_shapes can trace it with positional
    # stand-ins.
    def call_on_input(x, params, state, rng):
      return self.call(x, params=params, state=state, rng=rng)

    # Replace every parameter leaf with a shape/dtype-only stand-in.
    params_shapes = nested_map(
        params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
    s = backend.eval_on_shapes(call_on_input)(pseudo_inputs, params_shapes,
                                              state, rng)
    return s
  except Exception:  # pylint: disable=broad-except
    # Any failure is re-reported as a LayerError naming this layer.
    # NOTE(review): skip=3 is tied to the current call/frame structure of the
    # try body above; if that structure changes, the skip count probably
    # needs revisiting -- confirm against _short_traceback's definition.
    name, trace = self.__class__.__name__, _short_traceback(skip=3)
    raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs, trace)
def sizes(x):
  """Return a structure mirroring ``x`` with each leaf replaced by its size.

  The traversal is done by ``nested_map``. A leaf whose ``size`` attribute
  cannot be read (for example a plain Python object without one) maps to 0.
  """
  def _leaf_size(leaf):
    try:
      count = leaf.size
    except Exception:  # pylint: disable=broad-except
      count = 0
    return count

  return nested_map(x, _leaf_size)
def shapes(x):
  """Return a structure mirroring ``x`` with each leaf replaced by its shape.

  The traversal is done by ``nested_map``. Each shape is returned as a tuple
  of Python ints. A leaf whose ``shape`` is missing, or has a dimension that
  cannot be converted with ``int``, maps to an empty list instead.
  """
  def _leaf_shape(leaf):
    try:
      dims = leaf.shape
      result = tuple(int(d) for d in dims)
    except Exception:  # pylint: disable=broad-except
      result = []
    return result

  return nested_map(x, _leaf_shape)