コード例 #1
0
 def test_layer_invocation(self):
     """Runs the head on an all-zero input and checks output and checkpoints.

     Expects a 2x2 all-zero result and `checkpoint_items` keyed only by
     "pooler_dense".
     """
     head = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
     # NOTE(review): presumably (batch, seq_len, hidden) -- confirm against
     # ClassificationHead's expected input layout.
     input_shape = (2, 10, 10)
     zero_inputs = tf.zeros(shape=input_shape, dtype=tf.float32)
     result = head(zero_inputs)

     # Zero input is expected to map to zero output (presumably because the
     # layer's biases start at zero -- not verified by this test).
     expected_result = [[0., 0.], [0., 0.]]
     self.assertAllClose(result, expected_result)

     expected_checkpoint_keys = ["pooler_dense"]
     self.assertSameElements(head.checkpoint_items.keys(),
                             expected_checkpoint_keys)
コード例 #2
0
ファイル: cls_head_test.py プロジェクト: ykate1998/models
  def test_pooler_layer(self, inner_dim, num_weights_expected):
    """Checks how many weights the head holds after one forward pass.

    Args:
      inner_dim: value forwarded as `inner_dim` to `ClassificationHead`
        (presumably the pooler's hidden size).
      num_weights_expected: expected length of `get_weights()` once the layer
        has been called.
    """
    head = cls_head.ClassificationHead(inner_dim=inner_dim, num_classes=2)
    # Keras layers create their weights lazily, so call the layer once before
    # counting them; the output itself is not needed.
    head(tf.zeros(shape=(2, 10, 10), dtype=tf.float32))

    self.assertEqual(len(head.get_weights()), num_weights_expected)
コード例 #3
0
    def test_layer_serialization(self):
        """Checks that a ClassificationHead survives a config round trip.

        Builds a head, recreates it via `from_config(get_config())`, and
        asserts both layers report the same config.
        """
        # Keyword arguments match how the other tests in this file construct
        # the head.
        layer = cls_head.ClassificationHead(inner_dim=10, num_classes=2)
        new_layer = cls_head.ClassificationHead.from_config(layer.get_config())

        # If the serialization was successful, the new config should match the
        # old. `get_config()` returns a plain dict, so compare with
        # `assertEqual` (dict comparison with a key-level diff on failure)
        # rather than the array-oriented `assertAllEqual`.
        self.assertEqual(layer.get_config(), new_layer.get_config())