Ejemplo n.º 1
0
    def test_lambda_skip_state_variable_from_initializer(self):
        # Trains a model whose only layer is a Lambda wrapping a Dense layer,
        # then checks that exactly one trainable weight is reported.

        # Switch the initializer's random generator to the stateful RNG type
        # so that it is backed by a tf.random.Generator, which carries a
        # state variable of its own.
        kernel_initializer = initializers.RandomNormalV2()
        generator = kernel_initializer._random_generator
        generator._rng_type = generator.RNG_STATEFUL

        dense = keras.layers.Dense(
            units=1,
            use_bias=False,
            kernel_initializer=kernel_initializer,
        )

        def lambda_fn(inputs):
            # The Dense layer gets built the first time this is called.
            shifted = inputs + 1
            return dense(shifted)

        # Mixing Variables with Lambda layers is generally discouraged, but
        # variables explicitly attached as attributes are still tracked,
        # which is consistent with the base Layer behavior.
        layer = keras.layers.Lambda(lambda_fn)
        layer.dense = dense

        model = test_utils.get_model_from_layers([layer], input_shape=(10,))
        sgd = keras.optimizers.optimizer_v2.gradient_descent.SGD(0.1)
        model.compile(
            optimizer=sgd,
            loss="mae",
            run_eagerly=test_utils.should_run_eagerly(),
        )

        # Features are all ones; targets are the same shape filled with twos.
        features = np.ones((10, 10), "float32")
        targets = 2 * features
        model.fit(
            features,
            targets,
            batch_size=2,
            epochs=2,
            validation_data=(features, targets),
        )

        self.assertLen(model.trainable_weights, 1)
Ejemplo n.º 2
0
  def test_partition(self):
    # Checks which initializers accept partition-related call arguments and
    # which ones reject them.
    partition_kwargs = dict(
        shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))

    with self.cached_session():
      # These initializers must return a tensor shaped like the requested
      # partition rather than the full shape.
      supported = (
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2(),
      )
      for init in supported:
        partition = init(**partition_kwargs)
        self.assertEqual(partition.shape, (2, 2))

      # These initializers must raise a ValueError when given the same
      # partition arguments.
      unsupported = (
          initializers.OrthogonalV2(),
          initializers.IdentityV2(),
      )
      expected_message = (
          "initializer doesn't support partition-related arguments")
      for init in unsupported:
        with self.assertRaisesRegex(ValueError, expected_message):
          init(**partition_kwargs)
Ejemplo n.º 3
0
 def test_normal(self):
     # Runs the shared `_runner` check on a seeded RandomNormalV2 initializer
     # (mean 0, stddev 1) for a rank-3 shape.
     with self.cached_session():
         normal_init = initializers.RandomNormalV2(mean=0, stddev=1, seed=153)
         self._runner(normal_init, (8, 12, 99))