def test_extend_pretrained_tokens(self):
    model_dir = tempfile.TemporaryDirectory().name
    os.makedirs(model_dir)
    save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
    tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"),
                                   do_lower_case=True)

    ckpt_dir = os.path.dirname(save_path)
    bert_params = bert.params_from_pretrained_ckpt(ckpt_dir)
    self.assertEqual(bert_params.token_type_vocab_size, 2)
    bert_params.extra_tokens_vocab_size = 3

    l_bert = bert.BertModelLayer.from_params(bert_params)

    # call the layer once on dummy input in order to instantiate the weights
    l_bert([np.array([[1, 1, 0]]), np.array([[1, 0, 0]])],
           mask=[[True, True, False]])

    # extending the vocabulary with extra tokens does not touch the pre-trained
    # embeddings, so loading the stock weights should report no shape mismatches
    mismatched = bert.load_stock_weights(l_bert, save_path)
    self.assertEqual(0, len(mismatched),
                     "no weights should have a mismatched shape")

    # negative input ids address the extra tokens vocabulary
    l_bert([np.array([[1, -3, 0]]), np.array([[1, 0, 0]])],
           mask=[[True, True, False]])
def test_freeze(self):
    model_dir = tempfile.TemporaryDirectory().name
    os.makedirs(model_dir)
    save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
    tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"),
                                   do_lower_case=True)

    # prepare input
    max_seq_len = 24
    input_str_batch = ["hello, bert!", "how are you doing!"]
    input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer, max_seq_len)

    bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")

    bert_params = bert.params_from_pretrained_ckpt(model_dir)
    bert_params.adapter_size = 4
    l_bert = bert.BertModelLayer.from_params(bert_params)

    model = keras.models.Sequential([
        l_bert,
    ])
    model.build(input_shape=(None, max_seq_len))
    model.summary()

    l_bert.apply_adapter_freeze()
    model.summary()

    bert.load_stock_weights(l_bert, bert_ckpt_file)
    # l_bert.embeddings_layer.trainable = False
    model.summary()

    orig_weight_values = []
    for weight in l_bert.weights:
        orig_weight_values.append(weight.numpy())

    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.mean_squared_error,
                  run_eagerly=True)

    orig_pred = model.predict(input_ids)
    model.fit(x=input_ids, y=np.zeros_like(orig_pred),
              batch_size=2,
              epochs=4)

    for ndx, weight in enumerate(l_bert.weights):
        print("{}: {}".format(np.array_equal(weight.numpy(), orig_weight_values[ndx]),
                              weight.name))

    model.summary()
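# The tests above and below rely on a prepare_input_batch helper defined elsewhere in
# the test suite. The function below is only an illustrative sketch of what such a
# helper might look like, assuming the bert.FullTokenizer API (tokenize and
# convert_tokens_to_ids): it wraps each string in [CLS]/[SEP], pads the ids to
# max_seq_len, and emits single-segment token type ids. The extra_token_count
# argument is accepted only for signature compatibility and is ignored here; the
# real helper's handling of extra tokens may differ.
def prepare_input_batch_sketch(input_str_batch, tokenizer, max_seq_len, extra_token_count=0):
    input_ids_batch = []
    token_type_ids_batch = []
    for input_str in input_str_batch:
        tokens = ["[CLS]"] + tokenizer.tokenize(input_str) + ["[SEP]"]
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_ids = input_ids + [0] * (max_seq_len - len(input_ids))   # pad to max_seq_len
        input_ids_batch.append(input_ids)
        token_type_ids_batch.append([0] * max_seq_len)                 # single segment only
    return (np.array(input_ids_batch, dtype=np.int32),
            np.array(token_type_ids_batch, dtype=np.int32))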
def setUp(self) -> None:
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.enable_eager_execution()
    print("Eager Execution:", tf.executing_eagerly())

    # build a dummy bert
    self.ckpt_path = MiniBertFactory.create_mini_bert_weights()
    self.ckpt_dir = os.path.dirname(self.ckpt_path)
    self.tokenizer = bert.FullTokenizer(vocab_file=os.path.join(self.ckpt_dir, "vocab.txt"),
                                        do_lower_case=True)
def test_extend_pretrained_segments(self):
    model_dir = tempfile.TemporaryDirectory().name
    os.makedirs(model_dir)
    save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
    tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"),
                                   do_lower_case=True)

    ckpt_dir = os.path.dirname(save_path)
    bert_params = bert.params_from_pretrained_ckpt(ckpt_dir)
    self.assertEqual(bert_params.token_type_vocab_size, 2)
    bert_params.token_type_vocab_size = 4

    l_bert = bert.BertModelLayer.from_params(bert_params)

    # call the layer once on dummy input in order to instantiate the weights
    l_bert([np.array([[1, 1, 0]]), np.array([[1, 0, 0]])])  # , mask=[[True, True, False]])

    #
    # - load the weights from a pre-trained model,
    # - expect a mismatch for the token_type embeddings
    # - use the segment/token type id=0 embedding for the missing token types
    #
    mismatched = bert.load_stock_weights(l_bert, save_path)
    self.assertEqual(1, len(mismatched),
                     "token_type embeddings should have a mismatched shape")

    for weight, value in mismatched:
        if re.match("(.*)embeddings/token_type_embeddings/embeddings:0", weight.name):
            seg0_emb = value[:1, :]
            new_segment_embeddings = np.repeat(seg0_emb, (weight.shape[0] - value.shape[0]), axis=0)
            new_value = np.concatenate([value, new_segment_embeddings], axis=0)
            keras.backend.batch_set_value([(weight, new_value)])

    tte = l_bert.embeddings_layer.token_type_embeddings_layer.weights[0]
    if not tf.executing_eagerly():
        with tf.keras.backend.get_session() as sess:
            tte, = sess.run((tte, ))

    self.assertTrue(np.allclose(seg0_emb, tte[0], 1e-6))
    self.assertFalse(np.allclose(seg0_emb, tte[1], 1e-6))
    self.assertTrue(np.allclose(seg0_emb, tte[2], 1e-6))
    self.assertTrue(np.allclose(seg0_emb, tte[3], 1e-6))

    bert_params.token_type_vocab_size = 4
    print("token_type_vocab_size", bert_params.token_type_vocab_size)
    print(l_bert.embeddings_layer.trainable_weights[1])
def test_adapter_albert_freeze(self):
    model_dir = tempfile.TemporaryDirectory().name
    os.makedirs(model_dir)

    # for tokenizer only
    save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
    tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"),
                                   do_lower_case=True)

    # prepare input
    max_seq_len = 28
    input_str_batch = ["hello, albert!", "how are you doing!"]
    input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer,
                                                         max_seq_len, extra_token_count=3)

    bert_params = bert.BertModelLayer.Params(
        attention_dropout=0.1,
        hidden_act="gelu",
        hidden_dropout=0.1,
        hidden_size=8,
        initializer_range=0.02,
        intermediate_size=32,
        max_position_embeddings=32,
        num_heads=2,
        num_layers=2,
        token_type_vocab_size=2,
        vocab_size=len(tokenizer.vocab),

        adapter_size=2,
        embedding_size=4,
        extra_tokens_vocab_size=3,
        shared_layer=True,
    )

    l_bert = bert.BertModelLayer.from_params(bert_params)

    model = keras.models.Sequential([
        l_bert,
    ])
    model.build(input_shape=(None, max_seq_len))
    model.summary()

    l_bert.apply_adapter_freeze()
    model.summary()

    orig_weight_values = []
    for weight in l_bert.weights:
        orig_weight_values.append(weight.numpy())

    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.mean_squared_error,
                  run_eagerly=True)

    trainable_count = len(l_bert.trainable_weights)

    orig_pred = model.predict(input_ids)
    model.fit(x=input_ids, y=np.zeros_like(orig_pred),
              batch_size=2,
              epochs=4)

    trained_count = 0
    for ndx, weight in enumerate(l_bert.weights):
        weight_equal = np.array_equal(weight.numpy(), orig_weight_values[ndx])
        print("trained:[{}]: {}".format(not weight_equal, weight.name))
        if not weight_equal:
            trained_count += 1

    print("  trained weights:", trained_count)
    print("trainable weights:", trainable_count)
    self.assertEqual(trained_count, trainable_count)

    model.summary()