Exemplo n.º 1
0
    def test_layer_serialization(self):
        """Round-trips the head through get_config() and from_config().

        Serialization counts as successful when the layer rebuilt from the
        config reports the same config as the layer it was rebuilt from.
        """
        head_cls = cls_head.GaussianProcessClassificationHead
        original = head_cls(
            inner_dim=5,
            num_classes=2,
            use_spec_norm=True,
            use_gp_layer=True,
            **self.spec_norm_kwargs,
            **self.gp_layer_kwargs)
        restored = head_cls.from_config(original.get_config())

        # Ask for a fresh config rather than reusing the dict handed to
        # from_config(), in case that call modified it.
        self.assertAllEqual(original.get_config(), restored.get_config())
Exemplo n.º 2
0
    def test_pooler_layer(self, inner_dim, num_weights_expected):
        """Checks how many weight arrays the head holds after its first call.

        Args:
          inner_dim: Value forwarded as the head's `inner_dim` (supplied by a
            test parameterization that is not shown here).
          num_weights_expected: Number of arrays `get_weights()` should return
            once the layer has been called.
        """
        head = cls_head.GaussianProcessClassificationHead(
            inner_dim=inner_dim,
            num_classes=2,
            use_spec_norm=True,
            use_gp_layer=True,
            initializer="zeros",
            **self.spec_norm_kwargs,
            **self.gp_layer_kwargs)

        # Invoke the layer once on a dummy batch; only its weights are
        # inspected afterwards, so the output itself is discarded.
        head(tf.zeros(shape=(2, 10, 10), dtype=tf.float32))

        self.assertEqual(len(head.get_weights()), num_weights_expected)
Exemplo n.º 3
0
    def test_sngp_kwargs_serialization(self):
        """Verifies that get_config() exposes the SNGP-specific settings."""
        config = cls_head.GaussianProcessClassificationHead(
            inner_dim=5,
            num_classes=2,
            use_spec_norm=True,
            use_gp_layer=True,
            **self.spec_norm_kwargs,
            **self.gp_layer_kwargs).get_config()

        # These values are meant to match the kwargs prepared in setUp(),
        # which lives outside this method.
        expected_entries = {"norm_multiplier": 1., "num_inducing": 512}
        for key, expected_value in expected_entries.items():
            self.assertEqual(config[key], expected_value)
Exemplo n.º 4
0
 def test_layer_invocation(self):
     """Calls a zero-initialized head on zero features and checks the result.

     Also checks which sub-layers the head reports in `checkpoint_items`.
     """
     head = cls_head.GaussianProcessClassificationHead(
         inner_dim=5,
         num_classes=2,
         use_spec_norm=True,
         use_gp_layer=True,
         initializer="zeros",
         **self.spec_norm_kwargs,
         **self.gp_layer_kwargs)
     zero_features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)

     # NOTE(review): this unpack assumes the call returns a two-item pair
     # (presumably logits plus a covariance term). If the layer version under
     # test returns a bare logits tensor by default, the unpack would split it
     # along its first axis instead -- confirm against that version's API.
     logits, _ = head(zero_features)
     self.assertAllClose(logits, [[0., 0.], [0., 0.]])

     checkpointed_names = head.checkpoint_items.keys()
     self.assertSameElements(checkpointed_names, ["pooler_dense"])
Exemplo n.º 5
0
  def test_sngp_train_logits(self):
    """Verifies that the temperature setting has no effect in training mode."""
    features = tf.zeros(shape=(5, 10, 10), dtype=tf.float32)
    gp_layer = cls_head.GaussianProcessClassificationHead(
        inner_dim=5, num_classes=2)

    # Run the same layer instance in training mode, first with temperature
    # disabled (None) and then with 10.; the two outputs must be identical.
    outputs_by_temperature = []
    for temperature in (None, 10.):
      gp_layer.temperature = temperature
      outputs_by_temperature.append(gp_layer(features, training=True))

    outputs_no_temp, outputs_with_temp = outputs_by_temperature
    self.assertAllEqual(outputs_no_temp, outputs_with_temp)
Exemplo n.º 6
0
  def test_sngp_output_shape(self, use_gp_layer, return_covmat):
    """Checks the output structure and shapes for a GP / covmat combination.

    A tuple whose first two entries are shaped (batch, classes) and
    (batch, batch) is expected only when the GP layer is enabled and the
    covariance matrix is requested. Every other combination should yield a
    single tensor shaped (batch, classes).

    Args:
      use_gp_layer: Forwarded to the head's `use_gp_layer` argument.
      return_covmat: Forwarded to the layer call's `return_covmat` argument.
    """
    batch_size, num_classes = 32, 2

    test_layer = cls_head.GaussianProcessClassificationHead(
        inner_dim=5,
        num_classes=num_classes,
        use_spec_norm=True,
        use_gp_layer=use_gp_layer,
        **self.spec_norm_kwargs,
        **self.gp_layer_kwargs)
    outputs = test_layer(
        tf.zeros(shape=(batch_size, 10, 10), dtype=tf.float32),
        return_covmat=return_covmat)

    expects_covmat = use_gp_layer and return_covmat
    if not expects_covmat:
      self.assertIsInstance(outputs, tf.Tensor)
      self.assertEqual(outputs.shape, (batch_size, num_classes))
      return

    # Index instead of unpacking so a tuple of unexpected length is handled
    # exactly as before (only the first two entries are inspected).
    self.assertIsInstance(outputs, tuple)
    self.assertEqual(outputs[0].shape, (batch_size, num_classes))
    self.assertEqual(outputs[1].shape, (batch_size, batch_size))