Example #1
    def pseudo_call(self, pseudo_inputs, params, state):
        """Computes shapes and types this layer would produce for the given inputs.

    Args:
      pseudo_inputs: A ShapeType instance (input data minus the actual values)
          or a tuple of ShapeType instances, following the same conventions as
          Layer.call's input arg.
      params: Parameters for this layer.
      state: start state.

    Returns:
      A ShapeType instance representing the shape and type of the output (if
      this layer has one output) or a tuple of ShapeType instances (if this
      layer has more than one output).
    """
        try:
            # Beware: using an actual RNG (as opposed to this ShapeType stub) would
            # cause a large number of dropout masks to be computed and permanently
            # stored in global memory.
            rng = ShapeType(shape=(2, ), dtype=onp.uint32)

            def call_on_input(x, params, state, rng):
                return self.call(x, params=params, state=state, rng=rng)

            params_shapes = nested_map(
                params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
            s = backend.eval_on_shapes(call_on_input)(pseudo_inputs,
                                                      params_shapes, state,
                                                      rng)
            return s
        except Exception:
            name, trace = self.__class__.__name__, _short_traceback(skip=3)
            raise LayerError(name, 'pseudo_call', self._caller, pseudo_inputs,
                             trace)
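
For context, the shape-only evaluation that backend.eval_on_shapes performs above corresponds, in plain JAX, to jax.eval_shape applied to shape/dtype stubs. Below is a minimal self-contained sketch of that underlying mechanism; the toy dense_forward function is illustrative only and not part of the library.

import jax
import jax.numpy as jnp


def dense_forward(x, w, b):
    # Toy "layer call": matrix multiply plus bias.
    return jnp.dot(x, w) + b


# Shape/dtype stubs stand in for real arrays, so no values are ever computed.
x_spec = jax.ShapeDtypeStruct((8, 32), jnp.float32)
w_spec = jax.ShapeDtypeStruct((32, 16), jnp.float32)
b_spec = jax.ShapeDtypeStruct((16,), jnp.float32)

out_spec = jax.eval_shape(dense_forward, x_spec, w_spec, b_spec)
print(out_spec.shape, out_spec.dtype)  # (8, 16) float32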
Example #2
def check_shape_agreement(layer_fn, input_shapes, integer_inputs=False):
    """Checks if the layer's call output agrees its pseudo_call predictions.

  This function helps test layer mechanics and inter-layer connections that
  aren't dependent on specific data values.

  Args:
    layer_fn: A Layer instance, viewed as a function from input shapes to
        output shapes.
    input_shapes: A tuple representing a shape (if the layer takes one input)
        or a tuple of shapes (if this layer takes more than one input).
        For example: (210, 160, 3) or ((210, 160, 3), (105, 80, 3)).
    integer_inputs: If True, use numpy int32 as the type for the pseudo-data,
        else use float32.

  Returns:
    A tuple representing either a single shape (if the layer has one output) or
    a tuple of shape tuples (if the layer has more than one output).
  """
    rng1, rng2, rng3 = backend.random.split(backend.random.get_prng(0), 3)
    input_dtype = onp.int32 if integer_inputs else onp.float32
    if _is_tuple_of_shapes(input_shapes):
        pseudo_data = tuple(ShapeType(x, input_dtype) for x in input_shapes)
        input_dtype = tuple(input_dtype for _ in input_shapes)
    else:
        pseudo_data = ShapeType(input_shapes, input_dtype)
    params, state = layer_fn.initialize(input_shapes, input_dtype, rng1)
    pseudo_output, _ = layer_fn.pseudo_call(pseudo_data, params, state)
    if isinstance(pseudo_output, tuple):
        output_shape = tuple(x.shape for x in pseudo_output)
    else:
        output_shape = pseudo_output.shape

    random_input = _random_values(input_shapes, rng2, integer_inputs)
    real_output, _ = layer_fn(random_input, params, state=state, rng=rng3)
    result_shape = shapes(real_output)

    msg = 'output shape %s != real result shape %s' % (output_shape,
                                                       result_shape)
    assert output_shape == result_shape, msg
    # TODO(jonni): Remove this assert? It makes test logs harder to read.
    return output_shape
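
A hedged, self-contained illustration of the same shape-agreement idea, written in plain JAX rather than the library's Layer/ShapeType machinery; layer_call below is a stand-in forward pass, not a real layer from this codebase.

import jax
import jax.numpy as jnp


def layer_call(x):
    # Stand-in for a layer's forward pass: (batch, 4) -> (batch, 3).
    return jnp.tanh(jnp.dot(x, jnp.ones((4, 3))))


input_shape = (2, 4)

# Predicted output shape, computed from shapes alone (no real data).
predicted = jax.eval_shape(
    layer_call, jax.ShapeDtypeStruct(input_shape, jnp.float32)).shape

# Actual output shape, from running the same function on random inputs.
real = layer_call(
    jax.random.normal(jax.random.PRNGKey(0), input_shape)).shape

assert predicted == real == (2, 3), (predicted, real)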