コード例 #1
0
    def test_layer_serialization(self):
        """Checks that MultiClsHeads survives a get_config/from_config round trip.

        A layer is built, re-created from its own config via
        `MultiClsHeads.from_config`, and the two layers' configs are compared.
        """
        heads = [("foo", 2), ("bar", 3)]
        source_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=heads)
        rebuilt_layer = cls_head.MultiClsHeads.from_config(
            source_layer.get_config())

        # A faithful round trip means the rebuilt layer reports exactly the
        # same config as the layer it was rebuilt from. A fresh get_config()
        # call is used here rather than reusing the dict handed to
        # from_config, in case from_config modified that dict.
        expected_config = source_layer.get_config()
        actual_config = rebuilt_layer.get_config()
        self.assertAllEqual(expected_config, actual_config)
コード例 #2
0
ファイル: cls_head_test.py プロジェクト: ykate1998/models
  def test_pooler_layer(self, inner_dim, num_weights_expected):
    """Checks how many weight arrays MultiClsHeads holds for a given inner_dim.

    NOTE(review): the arguments are presumably supplied by a parameterization
    decorator that is not visible here — confirm against the test class.

    Args:
      inner_dim: Value passed to MultiClsHeads as `inner_dim`.
      num_weights_expected: Expected length of `get_weights()` after the
        layer has been called once.
    """
    heads = [("foo", 2), ("bar", 3)]
    layer = cls_head.MultiClsHeads(inner_dim=inner_dim, cls_list=heads)
    # The output is discarded; the call is made only for its side effect
    # (presumably lazy weight creation) before the weights are counted.
    layer(tf.zeros(shape=(2, 10, 10), dtype=tf.float32))
    self.assertEqual(len(layer.get_weights()), num_weights_expected)
コード例 #3
0
 def test_layer_invocation(self):
     """Calls MultiClsHeads on all-zero features and checks every head.

     For an all-zero input, each head listed in `cls_list` must produce an
     all-zero output. That output has two rows, matching the leading
     dimension of the (2, 10, 10) input, and one column per class. The
     layer's checkpoint items must be exactly "pooler_dense" plus one entry
     per head name.
     """
     heads = [("foo", 2), ("bar", 3)]
     layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=heads)
     zero_features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
     outputs = layer(zero_features)

     # Heads are checked in cls_list order ("foo" first, then "bar").
     for head_name, num_classes in heads:
         expected = [[0.] * num_classes for _ in range(2)]
         self.assertAllClose(outputs[head_name], expected)

     expected_items = ["pooler_dense"] + [name for name, _ in heads]
     self.assertSameElements(layer.checkpoint_items.keys(), expected_items)