def test_embedding_layer_without_token_type(self):
    """Check the embedding output shape when no token type ids are passed.

    The layer is called with word ids only (no type ids argument), and the
    test asserts the resulting static output shape.
    """
    layer = mobile_bert_layers.MobileBertEmbedding(
        word_vocab_size=10,
        word_embed_size=8,
        type_vocab_size=2,
        output_embed_size=16)
    # The input ids are fixed test data, not trainable state, so a constant
    # tensor is used rather than a tf.Variable.
    input_seq = tf.constant([[2, 3, 4, 5]])
    output = layer(input_seq)
    # One batch row, four tokens, last dim equal to output_embed_size (16).
    self.assertListEqual(output.shape.as_list(), [1, 4, 16])
def test_embedding_layer_get_config(self):
    """Check that MobileBertEmbedding survives a get_config/from_config round trip.

    A layer is built with explicit keyword arguments, a second layer is
    rebuilt from the first one's config, and both configs must compare equal.
    """
    embedding_cls = mobile_bert_layers.MobileBertEmbedding
    source_layer = embedding_cls(
        word_vocab_size=16,
        word_embed_size=32,
        type_vocab_size=4,
        output_embed_size=32,
        max_sequence_length=32,
        normalization_type='layer_norm',
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
        dropout_rate=0.5)
    source_config = source_layer.get_config()
    rebuilt_layer = embedding_cls.from_config(source_config)
    self.assertEqual(source_config, rebuilt_layer.get_config())