def pseudo_call(self, pseudo_inputs, params, state):
    """Computes shapes and types this layer would produce for the given inputs.

    The layer's `call` is evaluated through `backend.eval_on_shapes`, with the
    parameters (and the rng) replaced by ShapeType stand-ins, so no real data
    values are needed.

    Args:
      pseudo_inputs: A ShapeType instance (input data minus the actual values)
          or a tuple of ShapeType instances, following the same conventions as
          Layer.call's input arg.
      params: Parameters for this layer. Each leaf is converted to a ShapeType
          carrying only its `shape` and `dtype` before evaluation.
      state: start state. Passed through to `self.call` unchanged (it is not
          converted to ShapeType here).

    Returns:
      Whatever `backend.eval_on_shapes` yields for `self.call`. Callers unpack
      this as a pair. The first element is a ShapeType instance representing
      the shape and type of the output (if this layer has one output) or a
      tuple of ShapeType instances (if this layer has more than one output).
      The second element corresponds to the second value returned by
      `self.call` (presumably the updated state -- TODO confirm).

    Raises:
      LayerError: if anything inside the shape evaluation raises. It is built
          from the layer's class name, the string 'pseudo_call',
          `self._caller`, `pseudo_inputs`, and a shortened traceback.
    """
    try:
        # Beware: using an actual RNG (as opposed to this ShapeType stub) would
        # cause a large number of dropout masks to be computed and permanently
        # stored in global memory.
        rng = ShapeType(shape=(2, ), dtype=onp.uint32)

        def call_on_input(x, params, state, rng):
            return self.call(x, params=params, state=state, rng=rng)

        # Replace every parameter array by its shape/dtype description so the
        # evaluation only ever sees abstract values for the weights.
        params_shapes = nested_map(
            params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
        return backend.eval_on_shapes(call_on_input)(
            pseudo_inputs, params_shapes, state, rng)
    except Exception:
        # Wrap any failure in a LayerError that names this layer and method,
        # with a shortened traceback for readability.
        name, trace = self.__class__.__name__, _short_traceback(skip=3)
        raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs, trace)
def check_shape_agreement(layer_fn, input_shapes, integer_inputs=False):
    """Asserts that a layer's real output shape matches its pseudo_call shape.

    The layer is initialized for the given input shapes. Its output shape is
    predicted via `pseudo_call` on ShapeType stand-ins. The layer is then run
    on random data of the same shapes, and the two results are compared. This
    exercises layer mechanics and inter-layer wiring without depending on
    particular data values.

    Args:
      layer_fn: A Layer instance, viewed as a function from input shapes to
          output shapes.
      input_shapes: One shape tuple (single-input layer) or a tuple of shape
          tuples (multi-input layer), e.g. (210, 160, 3) or
          ((210, 160, 3), (105, 80, 3)).
      integer_inputs: When True the pseudo-data and random inputs use numpy
          int32; otherwise float32.

    Returns:
      The predicted output shape. This is a single shape tuple for a
      single-output layer, or a tuple of shape tuples for a multi-output layer.

    Raises:
      AssertionError: if the predicted and actual output shapes differ.
    """
    init_rng, data_rng, call_rng = backend.random.split(
        backend.random.get_prng(0), 3)
    base_dtype = onp.int32 if integer_inputs else onp.float32

    if _is_tuple_of_shapes(input_shapes):
        # Multi-input layer: one stand-in and one dtype per input.
        pseudo_data = tuple(
            ShapeType(one_shape, base_dtype) for one_shape in input_shapes)
        input_dtype = tuple(base_dtype for _ in input_shapes)
    else:
        pseudo_data = ShapeType(input_shapes, base_dtype)
        input_dtype = base_dtype

    params, state = layer_fn.initialize(input_shapes, input_dtype, init_rng)

    predicted, _ = layer_fn.pseudo_call(pseudo_data, params, state)
    predicted_shape = (
        tuple(item.shape for item in predicted)
        if isinstance(predicted, tuple)
        else predicted.shape)

    actual, _ = layer_fn(
        _random_values(input_shapes, data_rng, integer_inputs),
        params, state=state, rng=call_rng)
    actual_shape = shapes(actual)

    msg = 'output shape %s != real result shape %s' % (
        predicted_shape, actual_shape)
    assert predicted_shape == actual_shape, msg
    # TODO(jonni): Remove this assert? It makes test logs harder to read.
    return predicted_shape