def check_encoder_decoder_model_from_pretrained_configs(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, **kwargs):
    """Build an EncoderDecoderModel from a merged config and verify that the
    decoder flag is set and the output shapes are correct."""
    combined_config = EncoderDecoderConfig.from_encoder_decoder_configs(
        config, decoder_config)
    # Merging the configs must mark the decoder side as a decoder.
    self.assertTrue(combined_config.decoder.is_decoder)

    model = EncoderDecoderModel(combined_config)
    model.to(torch_device)
    model.eval()
    self.assertTrue(model.config.is_encoder_decoder)

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )

    # logits: (batch, tgt_len, vocab); encoder state: (batch, src_len, hidden)
    self.assertEqual(outputs["logits"].shape,
                     decoder_input_ids.shape + (decoder_config.vocab_size,))
    self.assertEqual(outputs["encoder_last_hidden_state"].shape,
                     input_ids.shape + (config.hidden_size,))
def check_encoder_decoder_model_labels(self, config, input_ids,
                                       attention_mask, encoder_hidden_states,
                                       decoder_config, decoder_input_ids,
                                       decoder_attention_mask, labels,
                                       **kwargs):
    """Run a labelled forward pass and verify that the loss backpropagates
    and the output shapes are correct."""
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.to(torch_device)

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
        return_dict=True,
    )

    # Supplying labels must yield a differentiable scalar loss.
    outputs["loss"].backward()

    self.assertEqual(outputs["logits"].shape,
                     decoder_input_ids.shape + (decoder_config.vocab_size,))
    self.assertEqual(outputs["encoder_last_hidden_state"].shape,
                     input_ids.shape + (config.hidden_size,))
def create_and_check_bert_encoder_decoder_model(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, **kwargs):
    """Check output shapes for a BERT encoder + BERT-MLM decoder pair, both
    from raw input ids and from precomputed encoder hidden states."""
    model = EncoderDecoderModel(encoder=BertModel(config),
                                decoder=BertForMaskedLM(decoder_config))
    model.to(torch_device)

    expected_logits = decoder_input_ids.shape + (decoder_config.vocab_size,)
    expected_hidden = input_ids.shape + (config.hidden_size,)

    # Pass 1: the encoder runs on the raw input ids.
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(outputs[0].shape, expected_logits)
    self.assertEqual(outputs[1].shape, expected_hidden)

    # Pass 2: skip the encoder by feeding hidden states directly.
    outputs = model(
        encoder_outputs=(encoder_hidden_states,),
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(outputs[0].shape, expected_logits)
    self.assertEqual(outputs[1].shape, expected_hidden)
def check_save_and_load(self, config, input_ids, attention_mask,
                        encoder_hidden_states, decoder_config,
                        decoder_input_ids, decoder_attention_mask, **kwargs):
    """Save the model, reload it from disk, and check the reloaded model
    produces the same outputs (NaNs zeroed, tolerance 1e-5).

    Bug fix: previously the result of
    ``EncoderDecoderModel.from_pretrained(tmpdirname)`` was discarded and the
    original in-memory model was run a second time, so the save/load
    round-trip was never actually exercised. The reloaded model is now the
    one evaluated.
    """
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    enc_dec_model = EncoderDecoderModel(encoder=encoder_model,
                                        decoder=decoder_model)
    enc_dec_model.to(torch_device)
    enc_dec_model.eval()
    with torch.no_grad():
        outputs = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        out_2 = outputs[0].cpu().numpy()
        out_2[np.isnan(out_2)] = 0  # ignore NaN positions in the comparison

        with tempfile.TemporaryDirectory() as tmpdirname:
            enc_dec_model.save_pretrained(tmpdirname)
            # Use the RELOADED model for the second pass — this is the point
            # of the test.
            enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
            enc_dec_model.to(torch_device)
            enc_dec_model.eval()

            after_outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_1 = after_outputs[0].cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)
def create_and_check_bert_encoder_decoder_model_lm_labels(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, lm_labels,
        **kwargs):
    """With ``lm_labels`` supplied, output 0 is a loss that must
    backpropagate; logits and encoder states follow at indices 1 and 2."""
    model = EncoderDecoderModel(encoder=BertModel(config),
                                decoder=BertForMaskedLM(decoder_config))
    model.to(torch_device)

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        lm_labels=lm_labels,
    )

    lm_loss = outputs[0]
    self.check_loss_output(lm_loss)
    # check that backprop works
    lm_loss.backward()

    self.assertEqual(outputs[1].shape,
                     decoder_input_ids.shape + (decoder_config.vocab_size,))
    self.assertEqual(outputs[2].shape,
                     input_ids.shape + (config.hidden_size,))
def check_encoder_decoder_model_output_attentions(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    labels,
    **kwargs
):
    """Verify the shapes of encoder, decoder, and cross attention maps when
    ``output_attentions=True``."""
    # make the decoder inputs a different shape from the encoder inputs to harden the test
    decoder_input_ids = decoder_input_ids[:, :-1]
    decoder_attention_mask = decoder_attention_mask[:, :-1]

    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.to(torch_device)

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        output_attentions=True,
    )

    src_len = input_ids.shape[-1]
    tgt_len = decoder_input_ids.shape[-1]

    encoder_attentions = outputs["encoder_attentions"]
    self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
    self.assertEqual(
        encoder_attentions[0].shape[-3:],
        (config.num_attention_heads, src_len, src_len),
    )

    # Some decoder configs (e.g. T5-style) name their depth differently.
    num_decoder_layers = getattr(
        decoder_config, "num_decoder_layers", decoder_config.num_hidden_layers
    )

    decoder_attentions = outputs["decoder_attentions"]
    self.assertEqual(len(decoder_attentions), num_decoder_layers)
    self.assertEqual(
        decoder_attentions[0].shape[-3:],
        (decoder_config.num_attention_heads, tgt_len, tgt_len),
    )

    cross_attentions = outputs["cross_attentions"]
    self.assertEqual(len(cross_attentions), num_decoder_layers)
    # ProphetNet-style decoders attend over (1 + ngram) prediction streams.
    cross_attention_input_seq_len = tgt_len * (
        1 + getattr(decoder_config, "ngram", 0)
    )
    self.assertEqual(
        cross_attentions[0].shape[-3:],
        (decoder_config.num_attention_heads, cross_attention_input_seq_len, src_len),
    )
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
    """Greedy generation should produce a (batch, max_length) id tensor."""
    encoder, decoder = self.get_encoder_decoder_model(config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.to(torch_device)

    # Bert does not have a bos token id, so use pad_token_id instead
    generated = model.generate(
        input_ids,
        decoder_start_token_id=model.config.decoder.pad_token_id,
    )
    self.assertEqual(generated.shape,
                     (input_ids.shape[0], decoder_config.max_length))
def check_encoder_decoder_model_output_attentions(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    labels,
    **kwargs
):
    """Verify encoder/decoder/cross attention shapes with
    ``output_attentions=True`` and ``return_dict=True``."""
    encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.to(torch_device)

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        output_attentions=True,
        return_dict=True,
    )

    src_len = input_ids.shape[-1]
    tgt_len = decoder_input_ids.shape[-1]

    encoder_attentions = outputs["encoder_attentions"]
    self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
    self.assertListEqual(
        list(encoder_attentions[0].shape[-3:]),
        [config.num_attention_heads, src_len, src_len],
    )

    # Fall back to num_hidden_layers when the config has no dedicated
    # decoder-depth field.
    num_decoder_layers = getattr(
        decoder_config, "num_decoder_layers", decoder_config.num_hidden_layers
    )

    decoder_attentions = outputs["decoder_attentions"]
    self.assertEqual(len(decoder_attentions), num_decoder_layers)
    self.assertListEqual(
        list(decoder_attentions[0].shape[-3:]),
        [decoder_config.num_attention_heads, tgt_len, tgt_len],
    )

    cross_attentions = outputs["cross_attentions"]
    self.assertEqual(len(cross_attentions), num_decoder_layers)
    # ProphetNet-style decoders attend over (1 + ngram) prediction streams.
    cross_attention_input_seq_len = src_len * (
        1 + getattr(decoder_config, "ngram", 0)
    )
    self.assertListEqual(
        list(cross_attentions[0].shape[-3:]),
        [decoder_config.num_attention_heads, cross_attention_input_seq_len, tgt_len],
    )
def check_encoder_decoder_model(
    self,
    config,
    input_ids,
    attention_mask,
    encoder_hidden_states,
    decoder_config,
    decoder_input_ids,
    decoder_attention_mask,
    **kwargs,
):
    """Exercise the composed model both from raw ids and from precomputed
    encoder hidden states, checking config flags and output shapes."""
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

    # Composition must configure the decoder side correctly.
    self.assertTrue(model.config.decoder.is_decoder)
    self.assertTrue(model.config.decoder.add_cross_attention)
    self.assertTrue(model.config.is_encoder_decoder)

    model.to(torch_device)

    logits_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
    hidden_shape = input_ids.shape + (config.hidden_size,)

    # Pass 1: encode from input ids.
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )
    self.assertEqual(outputs["logits"].shape, logits_shape)
    self.assertEqual(outputs["encoder_last_hidden_state"].shape, hidden_shape)

    # Pass 2: reuse precomputed encoder hidden states.
    outputs = model(
        encoder_outputs=BaseModelOutput(
            last_hidden_state=encoder_hidden_states),
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        return_dict=True,
    )
    self.assertEqual(outputs["logits"].shape, logits_shape)
    self.assertEqual(outputs["encoder_last_hidden_state"].shape, hidden_shape)
def check_encoder_decoder_model(self, config, input_ids, attention_mask,
                                encoder_hidden_states, decoder_config,
                                decoder_input_ids, decoder_attention_mask,
                                **kwargs):
    """Exercise the composed model from raw ids and from a precomputed
    encoder-output tuple, checking config flags and output shapes."""
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

    # Composition must configure the decoder side correctly.
    self.assertTrue(model.config.decoder.is_decoder)
    self.assertTrue(model.config.decoder.add_cross_attention)
    self.assertTrue(model.config.is_encoder_decoder)

    model.to(torch_device)

    logits_shape = decoder_input_ids.shape + (decoder_config.vocab_size,)
    hidden_shape = input_ids.shape + (config.hidden_size,)

    # Pass 1: encode from input ids.
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(outputs[0].shape, logits_shape)
    self.assertEqual(outputs[1].shape, hidden_shape)

    # Pass 2: feed precomputed encoder states as a tuple.
    outputs = model(
        encoder_outputs=(encoder_hidden_states,),
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    self.assertEqual(outputs[0].shape, logits_shape)
    self.assertEqual(outputs[1].shape, hidden_shape)
def get_model(args):
    """Build (or reload) a BERT-encoder / GPT-2-decoder model plus tokenizers,
    optimizer, and scheduler.

    Returns (model, src_tokenizer, tgt_tokenizer, optimizer, scheduler).

    NOTE(review): reads module-level globals `local_rank` and `device` and
    calls `get_optimizer_and_schedule` / `build_inputs_with_special_tokens`
    defined elsewhere in this file — confirm they exist in the outer scope.
    """
    if args.model_path:
        # Resume path: model and both tokenizers were saved under model_path.
        model = EncoderDecoderModel.from_pretrained(args.model_path)
        src_tokenizer = BertTokenizer.from_pretrained(
            os.path.join(args.model_path, "src_tokenizer")
        )
        tgt_tokenizer = GPT2Tokenizer.from_pretrained(
            os.path.join(args.model_path, "tgt_tokenizer")
        )
        # Re-bind the custom special-token builder, which is not persisted.
        tgt_tokenizer.build_inputs_with_special_tokens = types.MethodType(
            build_inputs_with_special_tokens, tgt_tokenizer
        )
        if local_rank == 0 or local_rank == -1:
            # Only the main process (or single-process run) logs.
            print("model and tokenizer load from save success")
    else:
        # Fresh start: compose pretrained encoder and decoder from scratch.
        src_tokenizer = BertTokenizer.from_pretrained(args.src_pretrain_dataset_name)
        tgt_tokenizer = GPT2Tokenizer.from_pretrained(args.tgt_pretrain_dataset_name)
        # GPT-2 lacks BOS/EOS/PAD for this task; add them explicitly.
        tgt_tokenizer.add_special_tokens(
            {"bos_token": "[BOS]", "eos_token": "[EOS]", "pad_token": "[PAD]"}
        )
        tgt_tokenizer.build_inputs_with_special_tokens = types.MethodType(
            build_inputs_with_special_tokens, tgt_tokenizer
        )
        encoder = BertGenerationEncoder.from_pretrained(args.src_pretrain_dataset_name)
        # The decoder needs cross-attention layers to attend to the encoder.
        decoder = GPT2LMHeadModel.from_pretrained(
            args.tgt_pretrain_dataset_name, add_cross_attention=True, is_decoder=True
        )
        # Grow the embedding matrix to cover the newly added special tokens.
        decoder.resize_token_embeddings(len(tgt_tokenizer))
        decoder.config.bos_token_id = tgt_tokenizer.bos_token_id
        decoder.config.eos_token_id = tgt_tokenizer.eos_token_id
        decoder.config.vocab_size = len(tgt_tokenizer)
        decoder.config.add_cross_attention = True
        decoder.config.is_decoder = True
        model_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            encoder.config, decoder.config
        )
        model = EncoderDecoderModel(
            encoder=encoder, decoder=decoder, config=model_config
        )
    # NOTE(review): the model is moved to `device` only when local_rank != -1
    # (distributed mode) — confirm single-process runs intend to stay on CPU.
    if local_rank != -1:
        model = model.to(device)
    if args.ngpu > 1:
        print("{}/{} GPU start".format(local_rank, torch.cuda.device_count()))
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )
    optimizer, scheduler = get_optimizer_and_schedule(args, model)
    return model, src_tokenizer, tgt_tokenizer, optimizer, scheduler
def create_and_check_encoder_decoder_shared_weights(
        self, config, input_ids, attention_mask, encoder_hidden_states,
        decoder_config, decoder_input_ids, decoder_attention_mask, labels,
        **kwargs):
    """Check that `tie_encoder_decoder=True` shares weights (fewer parameters,
    same outputs as an untied model with copied weights), and that tying
    survives a save/load round-trip.

    Both models are built under `torch.manual_seed(0)` so their initial
    weights match before copying/tying.
    """
    torch.manual_seed(0)
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.to(torch_device)
    model.eval()

    # load state dict copies weights but does not tie them
    decoder_state_dict = model.decoder._modules[
        model.decoder.base_model_prefix].state_dict()
    model.encoder.load_state_dict(decoder_state_dict, strict=False)

    torch.manual_seed(0)
    tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(
        config, decoder_config)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        tied_encoder_model.config,
        tied_decoder_model.config,
        tie_encoder_decoder=True)
    tied_model = EncoderDecoderModel(encoder=tied_encoder_model,
                                     decoder=tied_decoder_model,
                                     config=config)
    tied_model.to(torch_device)
    tied_model.eval()

    model_result = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )
    tied_model_result = tied_model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    )

    # check that models has less parameters
    self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                    sum(p.numel() for p in model.parameters()))
    random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

    # check that outputs are equal
    self.assertTrue(
        torch.allclose(model_result[0][0, :, random_slice_idx],
                       tied_model_result[0][0, :, random_slice_idx],
                       atol=1e-4))

    # check that outputs after saving and loading are equal
    with tempfile.TemporaryDirectory() as tmpdirname:
        tied_model.save_pretrained(tmpdirname)
        tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
        tied_model.to(torch_device)
        tied_model.eval()

        # check that models has less parameters
        self.assertLess(sum(p.numel() for p in tied_model.parameters()),
                        sum(p.numel() for p in model.parameters()))
        random_slice_idx = ids_tensor((1, ), model_result[0].shape[-1]).item()

        tied_model_result = tied_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # check that outputs are equal
        self.assertTrue(
            torch.allclose(model_result[0][0, :, random_slice_idx],
                           tied_model_result[0][0, :, random_slice_idx],
                           atol=1e-4))
def sample_generate(top_k=50,
                    temperature=1.0,
                    model_path='/content/BERT checkpoints/model-9.pth',
                    gpu_id=0):
    """Sample responses from a trained BERT encoder-decoder and report
    BLEU/NIST/METEOR plus entropy/distinct/length statistics.

    NOTE(review): relies on module-level helpers `top_k_logits`,
    `calculate_metrics`, `cal_entropy`, `cal_length` — confirm they are
    defined elsewhere in this file. Token id 102 is treated as the
    end-of-sequence marker (BERT [SEP]).
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model.load_state_dict(torch.load(model_path, map_location='cuda'))
    model = model.to(device)
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    model.eval()
    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("/content/test_data.pth")
    test_dataset = TensorDataset(*test_data)
    # batch_size=1: sentences are generated one at a time.
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=1)
    # ------------------------END LOAD VALIDATE DATA--------------

    # ------------------------START GENERATE-------------------
    update_count = 0
    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    sentences = []
    meteor_scores = 0
    print('start generating....')
    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, _ = batch
            # Encode the context once; reuse for every decoding step.
            past, _ = encoder(encoder_input, mask_encoder_input)

            # Seed the decoder with the first reference token.
            prev_pred = decoder_input[:, :1]
            sentence = prev_pred

            # decoding loop: top-k sampling, up to 100 tokens
            for i in range(100):
                logits = decoder(sentence, encoder_hidden_states=past)
                logits = logits[0][:, -1]  # logits of the last position only
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence = torch.cat([sentence, prev_pred], dim=-1)
                if prev_pred[0][0] == 102:  # stop at [SEP]
                    break

            predict = tokenizer.convert_ids_to_tokens(sentence[0].tolist())

            # Strip padding (id 0) from input and reference before decoding.
            encoder_input = encoder_input.squeeze(dim=0)
            encoder_input_num = (encoder_input != 0).sum()
            inputs = tokenizer.convert_ids_to_tokens(
                encoder_input[:encoder_input_num].tolist())

            decoder_input = decoder_input.squeeze(dim=0)
            decoder_input_num = (decoder_input != 0).sum()
            reference = tokenizer.convert_ids_to_tokens(
                decoder_input[:decoder_input_num].tolist())

            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {' '.join(inputs)}")
            print(f"output: {' '.join(reference)}")
            print(f"predict: {' '.join(predict)}")

            # [1:-1] drops the leading/trailing special tokens before scoring.
            temp_bleu_2, \
            temp_bleu_4, \
            temp_nist_2, \
            temp_nist_4, \
            temp_meteor_scores = calculate_metrics(predict[1:-1],
                                                   reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4
            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(predict[1:-1]))
            update_count += 1

    # Corpus-level diversity and length statistics over all samples.
    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='/content/BERT checkpoints'):
    """Train a BERT encoder-decoder with gradient accumulation; after each
    epoch report validation perplexity and save a checkpoint to
    ``load_dir/model-<epoch>.pth``.

    NOTE(review): relies on module-level names `util`
    (sequence_cross_entropy_with_logits), `max_grad_norm`, `tqdm`, `AdamW`,
    and `get_linear_schedule_with_warmup` — confirm they are defined/imported
    elsewhere in this file.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # ------------------------LOAD MODEL-----------------
    print('load the model....')
    bert_encoder = BertConfig.from_pretrained('bert-base-uncased')
    bert_decoder = BertConfig.from_pretrained('bert-base-uncased',
                                              is_decoder=True)
    config = EncoderDecoderConfig.from_encoder_decoder_configs(
        bert_encoder, bert_decoder)
    model = EncoderDecoderModel(config)
    model = model.to(device)
    print('load success')
    # ------------------------END LOAD MODEL--------------

    # ------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("/content/train_data.pth")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("/content/validate_data.pth")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=True,
                                batch_size=batch_size)
    # ------------------------END LOAD TRAIN DATA--------------

    # ------------------------SET OPTIMIZER-------------------
    # Total optimizer steps = batches per epoch * epochs / accumulation.
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation
    param_optimizer = list(model.named_parameters())
    # No weight decay on biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=lr,
        weight_decay=0.01,
    )
    # Warm up for the first 10% of optimization steps.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_train_optimization_steps // 10,
        num_training_steps=num_train_optimization_steps)
    # ------------------------START TRAINING-------------------
    update_count = 0
    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        # ------------------------training------------------------
        model.train()
        losses = 0
        times = 0
        print('\n' + '-' * 20 + f'epoch {epoch}' + '-' * 20)
        for batch in tqdm(train_dataloader):
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
            logits = model(input_ids=encoder_input,
                           attention_mask=mask_encoder_input,
                           decoder_input_ids=decoder_input,
                           decoder_attention_mask=mask_decoder_input)

            # Shift: predict token t+1 from position t.
            out = logits[0][:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()

            loss = util.sequence_cross_entropy_with_logits(out, target,
                                                           target_mask,
                                                           average="token")
            # NOTE(review): loss is not divided by num_gradients_accumulation
            # before backward — confirm the effective learning rate is
            # intended.
            loss.backward()

            losses += loss.item()
            times += 1
            update_count += 1

            # Step the optimizer once every num_gradients_accumulation batches.
            if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        end = time.time()
        print(f'time: {(end - start)}')
        print(f'loss: {losses / times}')
        start = end

        # ------------------------validate------------------------
        model.eval()
        perplexity = 0
        batch_count = 0
        print('\nstart calculate the perplexity....')
        with torch.no_grad():
            for batch in tqdm(val_dataloader):
                batch = [item.to(device) for item in batch]
                encoder_input, decoder_input, mask_encoder_input, mask_decoder_input = batch
                logits = model(input_ids=encoder_input,
                               attention_mask=mask_encoder_input,
                               decoder_input_ids=decoder_input,
                               decoder_attention_mask=mask_decoder_input)

                out = logits[0][:, :-1].contiguous()
                target = decoder_input[:, 1:].contiguous()
                target_mask = mask_decoder_input[:, 1:].contiguous()
                # print(out.shape,target.shape,target_mask.shape)
                loss = util.sequence_cross_entropy_with_logits(out, target,
                                                               target_mask,
                                                               average="token")
                # Per-batch perplexity, averaged over batches below.
                perplexity += np.exp(loss.item())
                batch_count += 1
        print(f'\nvalidate perplexity: {perplexity / batch_count}')

        # Checkpoint after every epoch.
        torch.save(
            model.state_dict(),
            os.path.join(os.path.abspath('.'), load_dir,
                         "model-" + str(epoch) + ".pth"))
# Decoder and model setup script: builds a BERT-MLM decoder from the
# experiment dict `decparams`, pairs it with the (externally created)
# `encoder`, and prepares the optimizer and loss.
# NOTE(review): `decparams`, `modelparams`, `encoder`, `device`,
# `de_tokenizer`, `train_dataloader`, and `valid_dataloader` are defined
# elsewhere in this file/notebook — confirm they are in scope.
vocabsize = decparams["vocab_size"]
max_length = decparams["max_length"]
decoder_config = BertConfig(
    vocab_size=vocabsize,
    max_position_embeddings=max_length + 64,  # this should be some large value
    num_attention_heads=decparams["num_attn_heads"],
    num_hidden_layers=decparams["num_hidden_layers"],
    hidden_size=decparams["hidden_size"],
    type_vocab_size=1,
    is_decoder=True)  # Very Important

decoder = BertForMaskedLM(config=decoder_config)

# Define encoder decoder model
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.to(device)


def count_parameters(mdl):
    """Return the number of trainable parameters in *mdl*."""
    return sum(p.numel() for p in mdl.parameters() if p.requires_grad)


print(f'The encoder has {count_parameters(encoder):,} trainable parameters')
print(f'The decoder has {count_parameters(decoder):,} trainable parameters')
print(f'The model has {count_parameters(model):,} trainable parameters')

optimizer = optim.Adam(model.parameters(), lr=modelparams['lr'])
# NLL loss ignoring padded target positions.
criterion = nn.NLLLoss(ignore_index=de_tokenizer.pad_token_id)

num_train_batches = len(train_dataloader)
num_valid_batches = len(valid_dataloader)
class PhonetizerModel:
    """French grapheme-to-phoneme model built from a small BERT
    encoder-decoder.

    Token-id convention used throughout (inferred from the tables and the
    generate() arguments below): ids 0-6 are reserved for special tokens
    (0 = decoder start, 1 = pad, 2 = end-of-sequence); real symbols start
    at id 7.
    """

    # phoneme character -> token id
    phon_tokenizer = {
        'e': 7,
        'i': 8,
        'R': 9,
        'a': 10,
        'o': 11,
        't': 12,
        's': 13,
        'l': 14,
        'k': 15,
        'p': 16,
        'm': 17,
        'n': 18,
        'd': 19,
        'y': 20,
        '@': 21,
        'f': 22,
        'z': 23,
        'b': 24,
        '§': 25,
        'v': 26,
        '2': 27,
        '1': 28,
        'Z': 29,
        'g': 30,
        'u': 31,
        'S': 32
    }
    # token id -> phoneme character (inverse of phon_tokenizer)
    phon_untokenizer = {v: k for k, v in phon_tokenizer.items()}

    # orthographic character -> token id
    char_tokenizer = {
        'e': 7,
        'i': 8,
        'a': 9,
        'r': 10,
        'o': 11,
        's': 12,
        't': 13,
        'n': 14,
        'l': 15,
        'é': 16,
        'c': 17,
        'p': 18,
        'u': 19,
        'm': 20,
        'd': 21,
        '-': 22,
        'h': 23,
        'g': 24,
        'b': 25,
        'v': 26,
        'f': 27,
        'k': 28,
        'y': 29,
        'x': 30,
        'è': 31,
        'ï': 32,
        'j': 33,
        'z': 34,
        'w': 35,
        'q': 36
    }

    def __init__(self, device='cpu', model=None):
        """Build the encoder/decoder pair on *device*; if *model* is a path
        to a saved state dict, load those weights."""
        vocabsize = 37  # character vocabulary (ids 0-36)
        max_length = 50
        encoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1)
        encoder = BertModel(config=encoder_config)

        vocabsize = 33  # phoneme vocabulary (ids 0-32)
        max_length = 50
        # NOTE(review): `add_cross_attentions` (trailing "s") is not a
        # recognized BertConfig field; the correctly spelled flag is set
        # explicitly on the line after, so cross-attention is still enabled.
        decoder_config = BertConfig(vocab_size=vocabsize,
                                    max_position_embeddings=max_length + 64,
                                    num_attention_heads=4,
                                    num_hidden_layers=4,
                                    hidden_size=128,
                                    type_vocab_size=1,
                                    add_cross_attentions=True,
                                    is_decoder=True)
        decoder_config.add_cross_attention = True
        decoder = BertLMHeadModel(config=decoder_config)

        # Define encoder decoder model
        self.model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        self.model.to(device)
        self.device = device
        if model is not None:
            self.model.load_state_dict(torch.load(model))

    def phonetize(self, word):
        """Return the predicted phoneme string for *word*.

        Raises AssertionError if, after normalization, *word* contains a
        character outside `char_tokenizer`.
        """
        # Normalize accented/digraph characters that have no id of their own.
        word = word.replace('à', 'a')
        word = word.replace('û', 'u')
        word = word.replace('ù', 'u')
        word = word.replace('î', 'i')
        word = word.replace('ç', 'ss')
        word = word.replace('ô', 'o')
        word = word.replace('â', 'a')
        word = word.replace('qu', 'k')
        word = word.replace('ê', 'e')
        assert set(word).issubset(set(PhonetizerModel.char_tokenizer.keys()))
        # Wrap in start (0) / end (2) markers.
        encoded = torch.tensor(
            [0] + [PhonetizerModel.char_tokenizer[p] for p in word] + [2])
        output = self.model.generate(
            encoded.unsqueeze(0).to(self.device),
            max_length=50,
            decoder_start_token_id=0,
            eos_token_id=2,
            pad_token_id=1,
        ).detach().cpu().numpy()[0]
        # Cut at the first end-of-sequence token, if any was produced.
        bound = np.where(output == 2)[0][0] if 2 in output else 1000
        # Keep only real phoneme ids (> 6), dropping special tokens.
        phon_pred = ''.join([
            PhonetizerModel.phon_untokenizer[c] for c in output[:bound]
            if c > 6
        ])
        return phon_pred

    def check_phonetization_error(self, word, phon):
        """Alignment score between the predicted phonemes of *word* and the
        reference *phon* (first 5 symbols each), normalized by the reference
        length. Uses global alignment (match=2, mismatch=-1, gap open=-1,
        gap extend=-0.5)."""
        prediction = self.phonetize(word)[:5]
        score = pairwise2.align.globalms(list(phon[:5]),
                                         list(prediction),
                                         2,
                                         -1,
                                         -1,
                                         -.5,
                                         score_only=True,
                                         gap_char=['-']) / len(phon[:5])
        return score