Example #1
0
def layer_test(
    layer_cls,
    kwargs=None,
    input_shape=None,
    input_dtype=None,
    input_data=None,
    expected_output=None,
    expected_output_dtype=None,
    expected_output_shape=None,
    validate_training=True,
    adapt_data=None,
    custom_objects=None,
    test_harness=None,
    supports_masking=None,
):
    """Test routine for a layer with a single input and single output.

    The layer is exercised in several ways:

    - instantiation and the optional `supports_masking` check;
    - an optional `adapt()` call;
    - `get_weights` / `set_weights`;
    - use inside a functional model, checking output dtype, output shape,
      `compute_output_shape` and `compute_output_signature`;
    - functional-model serialization;
    - an optional training step;
    - re-creation from its config as the first layer of a `Sequential`
      model, followed by Sequential-model serialization.

    Args:
      layer_cls: Layer class object.
      kwargs: Optional dictionary of keyword arguments for instantiating the
        layer. The dictionary is copied and never mutated.
      input_shape: Input shape tuple.
      input_dtype: Data type of the input data.
      input_data: Numpy array of input data.
      expected_output: Numpy array of the expected output.
      expected_output_dtype: Data type expected for the output.
      expected_output_shape: Shape tuple for the expected shape of the output.
      validate_training: Whether to attempt to validate training on this layer.
        This might be set to False for non-differentiable layers that output
        string or integer values.
      adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
        be tested for this layer. This is only relevant for PreprocessingLayers.
      custom_objects: Optional dictionary mapping name strings to custom objects
        in the layer class. This is helpful for testing custom layers.
      test_harness: The Tensorflow test, if any, that this function is being
        called in.
      supports_masking: Optional boolean to check the `supports_masking` property
        of the layer. If None, the check will not be performed.

    Returns:
      The output data (Numpy array) returned by the layer, for additional
      checks to be done by the calling code.

    Raises:
      ValueError: if both `input_data` and `input_shape` are None.
      AssertionError: if the `supports_masking`, output dtype or output shape
        checks fail. Output-value mismatches are reported by the selected
        comparison function (the test harness assertion or
        `string_test` / `numeric_test`).
    """
    if input_data is None:
        if input_shape is None:
            raise ValueError("input_shape is None")
        if not input_dtype:
            input_dtype = "float32"
        # Replace unknown (None) dimensions with a small random size so that
        # concrete data can be generated.
        input_data_shape = [
            np.random.randint(1, 4) if dim is None else dim
            for dim in input_shape
        ]
        input_data = 10 * np.random.random(input_data_shape)
        # `tf.as_dtype` accepts dtype strings, NumPy dtypes and `tf.DType`
        # objects alike, so this works however `input_dtype` was given.
        if tf.as_dtype(input_dtype).is_floating:
            input_data -= 0.5
        input_data = input_data.astype(input_dtype)
    elif input_shape is None:
        input_shape = input_data.shape
    if input_dtype is None:
        input_dtype = input_data.dtype
    if expected_output_dtype is None:
        expected_output_dtype = input_dtype

    # String outputs are compared for exact equality. Numeric outputs use
    # assertAllClose, or the `numeric_test` helper when no harness is given.
    if tf.as_dtype(expected_output_dtype) == tf.string:
        if test_harness:
            assert_equal = test_harness.assertAllEqual
        else:
            assert_equal = string_test
    else:
        if test_harness:
            assert_equal = test_harness.assertAllClose
        else:
            assert_equal = numeric_test

    # instantiation
    # Copy so that adding "weights" below never mutates the caller's dict.
    kwargs = dict(kwargs) if kwargs else {}
    layer = layer_cls(**kwargs)

    if (supports_masking is not None
            and layer.supports_masking != supports_masking):
        raise AssertionError(
            "When testing layer %s, the `supports_masking` property is %r "
            "but expected to be %r.\nFull kwargs: %s" % (
                layer_cls.__name__,
                layer.supports_masking,
                supports_masking,
                kwargs,
            ))

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    # test get_weights , set_weights at layer level
    weights = layer.get_weights()
    layer.set_weights(weights)

    # Test instantiation from weights. `getargspec` returns a named tuple, and
    # the parameter names live in its `.args` field, so test membership there.
    if "weights" in tf_inspect.getargspec(layer_cls.__init__).args:
        kwargs["weights"] = weights
        layer = layer_cls(**kwargs)

    # test in functional API
    x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
    y = layer(x)
    if backend.dtype(y) != expected_output_dtype:
        raise AssertionError(
            "When testing layer %s, for input %s, found output "
            "dtype=%s but expected to find %s.\nFull kwargs: %s" % (
                layer_cls.__name__,
                x,
                backend.dtype(y),
                expected_output_dtype,
                kwargs,
            ))

    def assert_shapes_equal(expected, actual):
        """Asserts that the output shape from the layer matches the actual shape.

        Unknown (None) dimensions in `expected` match any actual dimension.
        """
        if len(expected) != len(actual):
            raise AssertionError(
                "When testing layer %s, for input %s, found output_shape="
                "%s but expected to find %s.\nFull kwargs: %s" %
                (layer_cls.__name__, x, actual, expected, kwargs))

        for expected_dim, actual_dim in zip(expected, actual):
            if isinstance(expected_dim, tf.compat.v1.Dimension):
                expected_dim = expected_dim.value
            if isinstance(actual_dim, tf.compat.v1.Dimension):
                actual_dim = actual_dim.value
            if expected_dim is not None and expected_dim != actual_dim:
                raise AssertionError(
                    "When testing layer %s, for input %s, found output_shape="
                    "%s but expected to find %s.\nFull kwargs: %s" %
                    (layer_cls.__name__, x, actual, expected, kwargs))

    if expected_output_shape is not None:
        assert_shapes_equal(tf.TensorShape(expected_output_shape), y.shape)

    # check shape inference
    model = models.Model(x, y)
    computed_output_shape = tuple(
        layer.compute_output_shape(tf.TensorShape(input_shape)).as_list())
    computed_output_signature = layer.compute_output_signature(
        tf.TensorSpec(shape=input_shape, dtype=input_dtype))
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    assert_shapes_equal(computed_output_shape, actual_output_shape)
    assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
    if computed_output_signature.dtype != actual_output.dtype:
        raise AssertionError(
            "When testing layer %s, for input %s, found output_dtype="
            "%s but expected to find %s.\nFull kwargs: %s" % (
                layer_cls.__name__,
                x,
                actual_output.dtype,
                computed_output_signature.dtype,
                kwargs,
            ))
    if expected_output is not None:
        assert_equal(actual_output, expected_output)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = models.Model.from_config(model_config, custom_objects)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        assert_equal(output, actual_output)

    # Test training mode (e.g. useful for dropout tests).
    # Rebuild the model to avoid the graph being reused between predict() and
    # train(). See b/120160788 for more details. This should be mitigated
    # after 2.0.
    # Capture the layer weights BEFORE training so that the Sequential model
    # built below starts from the same (untrained) weights.
    layer_weights = layer.get_weights()
    if validate_training:
        model = models.Model(x, layer(x))
        if _thread_local_data.run_eagerly is not None:
            model.compile(
                "rmsprop",
                "mse",
                weighted_metrics=["acc"],
                run_eagerly=should_run_eagerly(),
            )
        else:
            model.compile("rmsprop", "mse", weighted_metrics=["acc"])
        model.train_on_batch(input_data, actual_output)

    # test as first layer in Sequential API
    layer_config = layer.get_config()
    layer_config["batch_input_shape"] = input_shape
    layer = layer.__class__.from_config(layer_config)

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    model = models.Sequential()
    model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
    model.add(layer)

    layer.set_weights(layer_weights)
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(computed_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            if expected_dim != actual_dim:
                raise AssertionError(
                    "When testing layer %s **after deserialization**, "
                    "for input %s, found output_shape="
                    "%s but expected to find inferred shape %s.\nFull kwargs: %s"
                    % (
                        layer_cls.__name__,
                        x,
                        actual_output_shape,
                        computed_output_shape,
                        kwargs,
                    ))
    if expected_output is not None:
        assert_equal(actual_output, expected_output)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = models.Sequential.from_config(model_config,
                                                    custom_objects)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        assert_equal(output, actual_output)

    # for further checks in the caller function
    return actual_output
    def _get_single_variable(
        self,
        name,
        shape=None,
        dtype=tf.float32,
        initializer=None,
        regularizer=None,
        partition_info=None,
        reuse=None,
        trainable=None,
        caching_device=None,
        validate_shape=True,
        constraint=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.compat.v1.VariableAggregation.NONE,
    ):
        """Get or create a single Variable (a shard or an entire variable).

        See the documentation of get_variable above (ignore partitioning
        components) for details.

        If a variable called `name` is already registered in `self._vars`, it
        is returned after checking that the requested shape and dtype are
        compatible with it. Otherwise a new `tf.Variable` is created under
        `tf.init_scope()` and registered in `self._vars`. Creation is logged,
        and `regularizer`, when truthy, is passed to `self.add_regularizer`.

        Args:
          name: see get_variable.
          shape: see get_variable.
          dtype: see get_variable.
          initializer: see get_variable.
          regularizer: see get_variable.
          partition_info: _PartitionInfo object.
          reuse: see get_variable.
          trainable: see get_variable.
          caching_device: see get_variable.
          validate_shape: see get_variable.
          constraint: see get_variable.
          synchronization: see get_variable.
          aggregation: see get_variable.

        Returns:
          A Variable.  See documentation of get_variable above.

        Raises:
          ValueError: if `initializer` is a constant (non-callable) and `shape`
            is also given; if an already-registered variable has a shape or
            dtype incompatible with the request; or if `reuse` is True and no
            variable called `name` has been registered yet.
        """
        # A non-None, non-callable initializer is treated as a constant value.
        initializing_from_value = (
            initializer is not None and not callable(initializer))
        if initializing_from_value and shape is not None:
            raise ValueError(
                "If initializer is a constant, do not specify shape.")

        dtype = tf.as_dtype(dtype)
        shape = as_shape(shape)

        if name in self._vars:
            # Sharing path: validate the request against the stored variable.
            existing_var = self._vars[name]
            existing_shape = existing_var.get_shape()
            if not shape.is_compatible_with(existing_shape):
                raise ValueError(
                    "Trying to share variable %s, but specified shape %s"
                    " and found shape %s." % (name, shape, existing_shape))
            if not dtype.is_compatible_with(existing_var.dtype):
                raise ValueError(
                    "Trying to share variable %s, but specified dtype %s"
                    " and found dtype %s."
                    % (name, dtype.name, existing_var.dtype.name))
            return existing_var

        # From here on, a brand new variable is being created.
        if reuse is True:  # pylint: disable=g-bool-id-comparison
            raise ValueError(
                "Variable %s does not exist, or was not created with "
                "tf.get_variable(). Did you mean to set "
                "reuse=tf.AUTO_REUSE in VarScope?" % name)

        if initializer is None:
            # Fall back to the store's default initializer. The helper also
            # reports whether that default is a constant value.
            initializer, initializing_from_value = (
                self._get_default_initializer(
                    name=name, shape=shape, dtype=dtype))

        # Build the initial value inside an init scope.
        with tf.init_scope():
            variable_dtype = None
            if not initializing_from_value and tf_inspect.isclass(initializer):
                # An initializer given as a class is instantiated first.
                initializer = initializer()
            init_val = initializer
            if not initializing_from_value and shape.is_fully_defined():
                # Bind shape and dtype now. `partition_info` is forwarded only
                # to initializers whose signature declares it.
                call_kwargs = {"dtype": dtype}
                if "partition_info" in tf_inspect.getargspec(initializer).args:
                    call_kwargs["partition_info"] = partition_info
                init_val = functools.partial(
                    initializer, shape.as_list(), **call_kwargs)
                variable_dtype = dtype.base_dtype

        # Create the variable (Always eagerly as a workaround for a strange
        # tpu / funcgraph / keras functional model interaction )
        with tf.init_scope():
            variable = tf.Variable(
                initial_value=init_val,
                name=name,
                trainable=trainable,
                caching_device=caching_device,
                dtype=variable_dtype,
                validate_shape=validate_shape,
                constraint=constraint,
                synchronization=synchronization,
                aggregation=aggregation,
            )

        self._vars[name] = variable
        logging.vlog(
            1,
            "Created variable %s with shape %s and init %s",
            variable.name,
            format(shape),
            initializer,
        )

        # Run the regularizer if requested and save the resulting loss.
        if regularizer:
            self.add_regularizer(variable, regularizer)

        return variable