def sample_generate(top_k=50,
                    temperature=1.0,
                    decoder_path='decoder.pth',
                    batch_size=1,
                    gpu_id=0):
    """Sample one reply per test batch and print automatic evaluation scores.

    Loads the encoder from ``encoder.pth`` and the decoder from
    ``decoder_path``, then for every batch of ``../test_data.pth`` samples a
    reply token by token (``top_k_logits`` filtering followed by multinomial
    sampling) and scores it against the reference with ``calculate_metrics``.
    Per-example text and the averaged scores are printed; nothing is
    returned.

    Only the first sequence of each batch is decoded and scored, so the
    default ``batch_size=1`` is what evaluates every example.

    Args:
        top_k: value forwarded as ``k`` to ``top_k_logits``.
        temperature: divisor applied to the logits before filtering.
        decoder_path: path of the decoder state dict to load.
        batch_size: batch size of the test ``DataLoader``.
        gpu_id: index of the CUDA device to run on.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # Sampling this id ends decoding early.
    # NOTE(review): presumably "[SEP]" in the bert-base-chinese vocab - confirm.
    sep_token_id = 102
    max_decode_steps = 100

    print('load model')
    #------------------------LOAD MODEL-----------------
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from; the modules are moved to `device` afterwards.
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth", map_location="cpu"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    decoder = decoder.to(device)
    decoder.eval()
    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD VALIDATE DATA------------------
    test_data = torch.load("../test_data.pth", map_location="cpu")
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 shuffle=False,
                                 batch_size=batch_size)
    #------------------------END LOAD VALIDATE DATA--------------

    #------------------------START SAMPLE GENERETE-------------------
    update_count = 0

    bleu_2scores = 0
    bleu_4scores = 0
    nist_2scores = 0
    nist_4scores = 0
    meteor_scores = 0
    sentences = []

    print('start generate....')
    for batch in test_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask, _ = batch

            _, past = encoder(encoder_input, mask)

            # Decoding is seeded with the first token of the reference.
            prev_pred = decoder_input[:, :1]
            sentence = [prev_pred]

            length = 1
            # decoding loop
            for _ in range(max_decode_steps):
                # Grow the attention mask by one position per decoded token.
                mask = F.pad(mask, (0, 1), "constant", 1.0)
                logits, past = decoder(prev_pred,
                                       mask,
                                       past=past,
                                       past_length=length)
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence.append(prev_pred)
                if prev_pred[0][0] == sep_token_id:
                    break
                length += 1

            sentence = torch.cat(sentence, dim=-1)

            predict_ids = sentence[0].tolist()
            predict = tokenizer.convert_ids_to_tokens(predict_ids)

            # Index row 0 explicitly (instead of squeeze(dim=0)) so the
            # reference and input stay 1-D for any batch size, matching the
            # prediction which is also taken from row 0.
            target = decoder_input[0]
            target_num = (target != 0).sum().item()
            inputs = encoder_input[0]
            input_num = (inputs != 0).sum().item()

            # Id 0 is treated as padding: keep only the leading non-zero count.
            inputs = tokenizer.convert_ids_to_tokens(
                inputs[:input_num].tolist())
            reference = tokenizer.convert_ids_to_tokens(
                target[:target_num].tolist())

            print('-' * 20 + f"example {update_count}" + '-' * 20)
            print(f"input: {''.join(inputs)}")
            print(f"output: {''.join(reference)}")
            print(f"predict: {''.join(predict)}")

            # Drop the seed token, and the terminator only if one was really
            # sampled; when the step limit is hit the last token is content
            # and must be kept.
            if predict_ids[-1] == sep_token_id:
                hypothesis = predict[1:-1]
            else:
                hypothesis = predict[1:]

            (temp_bleu_2,
             temp_bleu_4,
             temp_nist_2,
             temp_nist_4,
             temp_meteor_scores) = calculate_metrics(hypothesis,
                                                     reference[1:-1])

            bleu_2scores += temp_bleu_2
            bleu_4scores += temp_bleu_4
            nist_2scores += temp_nist_2
            nist_4scores += temp_nist_4
            meteor_scores += temp_meteor_scores
            sentences.append(" ".join(hypothesis))
            update_count += 1

    if update_count == 0:
        # Nothing was generated: averaging would divide by zero.
        print('no test examples found, nothing to evaluate')
        return

    entro, dist = cal_entropy(sentences)
    mean_len, var_len = cal_length(sentences)
    print(f'avg: {mean_len}, var: {var_len}')
    print(f'entro: {entro}')
    print(f'dist: {dist}')
    print(f'test bleu_2scores: {bleu_2scores / update_count}')
    print(f'test bleu_4scores: {bleu_4scores / update_count}')
    print(f'test nist_2scores: {nist_2scores / update_count}')
    print(f'test nist_4scores: {nist_4scores / update_count}')
    print(f'test meteor_scores: {meteor_scores / update_count}')
def _mean_batch_perplexity(encoder, decoder, dataloader, device):
    """Return the mean over batches of exp(token-averaged cross-entropy).

    Each batch is unpacked as (encoder_input, decoder_input,
    mask_encoder_input, mask_decoder_input). The decoder is fed the whole
    ``decoder_input`` and the logits at position ``t`` are scored against the
    token at ``t + 1``. Returns ``nan`` when ``dataloader`` yields no batch.

    NOTE(review): this averages per-batch perplexities instead of
    exponentiating one corpus-level mean loss; kept as-is so the printed
    numbers stay comparable with earlier runs.
    """
    perplexity = 0
    batch_count = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = [item.to(device) for item in batch]
            (encoder_input, decoder_input,
             mask_encoder_input, mask_decoder_input) = batch

            _, past = encoder(encoder_input, mask_encoder_input)

            # The decoder attends over encoder and decoder positions, so its
            # mask is the two masks joined along dim 1.
            mask = torch.cat([mask_encoder_input, mask_decoder_input], dim=1)
            logits, _ = decoder(decoder_input, mask, past=past, past_length=0)

            # Shift by one: prediction at t is compared with token t + 1.
            out = logits[:, :-1].contiguous()
            target = decoder_input[:, 1:].contiguous()
            target_mask = mask_decoder_input[:, 1:].contiguous()

            loss = util.sequence_cross_entropy_with_logits(out,
                                                           target,
                                                           target_mask,
                                                           average="token")
            perplexity += np.exp(loss.item())
            batch_count += 1

    if batch_count == 0:
        return float('nan')
    return perplexity / batch_count


def calculate_perplexity(batch_size=1, gpu_id=0, decoder_path='decoder.pth'):
    """Print the perplexity of the model on the train, validate and test sets.

    Loads the encoder from ``encoder.pth`` and the decoder from
    ``decoder_path``, reads ``train_data.pth``, ``validate_data.pth`` and
    ``test_data.pth`` from the working directory, and prints one perplexity
    figure per split. Nothing is returned.

    Args:
        batch_size: batch size used for all three ``DataLoader`` objects.
        gpu_id: index of the CUDA device to run on.
        decoder_path: path of the decoder state dict to load.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    #------------------------LOAD MODEL-----------------
    print('load the model....')

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from; the modules are moved to `device` afterwards.
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth", map_location="cpu"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    decoder = decoder.to(device)
    decoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD VAL DATA------------------
    # All three files are loaded up front so a missing file fails before any
    # (slow) evaluation starts. The order below is the reporting order.
    split_files = (
        ("train", "train_data.pth"),
        ("validate", "validate_data.pth"),
        ("test", "test_data.pth"),
    )
    dataloaders = []
    for split_name, data_path in split_files:
        data = torch.load(data_path, map_location="cpu")
        dataloaders.append(
            (split_name,
             DataLoader(dataset=TensorDataset(*data),
                        shuffle=False,
                        batch_size=batch_size)))
    #------------------------END LOAD VAL DATA--------------

    #------------------------START VAL-------------------
    for split_name, dataloader in dataloaders:
        print(f'start calculate the {split_name} perplexity....')
        perplexity = _mean_batch_perplexity(encoder, decoder, dataloader,
                                            device)
        print(f'{split_name} perplexity: {perplexity}')
def _batch_token_loss(encoder, decoder, batch, device):
    """Return the token-averaged next-token loss of ``decoder`` on one batch.

    ``batch`` is unpacked as (encoder_input, decoder_input,
    mask_encoder_input, mask_decoder_input). The encoder forward pass runs
    under ``torch.no_grad()`` because the caller never optimises the encoder;
    the decoder forward pass runs in whatever grad mode the caller is in.
    """
    batch = [item.to(device) for item in batch]
    (encoder_input, decoder_input,
     mask_encoder_input, mask_decoder_input) = batch

    with torch.no_grad():
        _, past = encoder(encoder_input, mask_encoder_input)

    # The decoder attends over encoder and decoder positions, so its mask is
    # the two masks joined along dim 1.
    mask = torch.cat([mask_encoder_input, mask_decoder_input], dim=1)
    logits, _ = decoder(decoder_input, mask, past=past, past_length=0)

    # Shift by one: prediction at t is compared with token t + 1.
    out = logits[:, :-1].contiguous()
    target = decoder_input[:, 1:].contiguous()
    target_mask = mask_decoder_input[:, 1:].contiguous()

    return util.sequence_cross_entropy_with_logits(out,
                                                   target,
                                                   target_mask,
                                                   average="token")


def train_model(epochs=10,
                num_gradients_accumulation=4,
                batch_size=4,
                gpu_id=0,
                lr=1e-5,
                load_dir='decoder_model'):
    """Fine-tune the decoder and save one checkpoint per epoch.

    Loads ``encoder.pth`` and ``decoder.pth``, trains only the decoder's
    parameters on ``train_data.pth`` with ``OpenAIAdam``, and after every
    epoch prints the mean training loss, the validation perplexity on
    ``validate_data.pth`` and writes the decoder state dict to
    ``<cwd>/<load_dir>/<epoch>decoder.pth``. Nothing is returned.

    Args:
        epochs: number of passes over the training set.
        num_gradients_accumulation: number of batches whose gradients are
            accumulated before each optimizer step.
        batch_size: batch size of both data loaders.
        gpu_id: index of the CUDA device to run on.
        lr: learning rate passed to the optimizer.
        load_dir: directory (relative to the working directory) that receives
            the per-epoch checkpoints; created if missing.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # Create the checkpoint directory up front so a missing directory cannot
    # make torch.save fail after a full epoch of training.
    save_dir = os.path.join(os.path.abspath('.'), load_dir)
    os.makedirs(save_dir, exist_ok=True)

    #------------------------LOAD MODEL-----------------
    print('load the model....')
    encoder = TransformerEncoder()
    decoder = TransformerDecoderLM()

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from; the modules are moved to `device` afterwards.
    encoder.load_state_dict(torch.load("encoder.pth", map_location="cpu"))
    decoder.load_state_dict(torch.load("decoder.pth", map_location="cpu"))

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # The encoder is neither optimised nor saved below, so keep it in eval
    # mode for the whole run: its outputs are then deterministic in both
    # training and validation.
    encoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD TRAIN DATA------------------
    train_data = torch.load("train_data.pth", map_location="cpu")
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=batch_size)
    val_data = torch.load("validate_data.pth", map_location="cpu")
    val_dataset = TensorDataset(*val_data)
    # Validation is not shuffled: the reported figure averages per-batch
    # perplexities, so a fixed batch composition keeps epochs comparable.
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=False,
                                batch_size=batch_size)
    #------------------------END LOAD TRAIN DATA--------------

    #------------------------SET OPTIMIZER-------------------
    num_train_optimization_steps = len(
        train_dataset) * epochs // batch_size // num_gradients_accumulation

    param_optimizer = list(decoder.named_parameters())
    # Parameters whose name contains one of these substrings get no weight
    # decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.01
    }, {
        'params': [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]

    optimizer = OpenAIAdam(optimizer_grouped_parameters,
                           lr=lr,
                           warmup=0.01,
                           max_grad_norm=1.0,
                           weight_decay=0.01,
                           t_total=num_train_optimization_steps)
    #------------------------END SET OPTIMIZER--------------

    #------------------------START TRAINING-------------------
    update_count = 0

    start = time.time()
    print('start training....')
    for epoch in range(epochs):
        #------------------------training------------------------
        decoder.train()
        losses = 0
        times = 0
        for batch in train_dataloader:
            loss = _batch_token_loss(encoder, decoder, batch, device)
            # NOTE(review): the loss is not divided by
            # num_gradients_accumulation, so accumulated gradients are summed
            # rather than averaged; left unchanged to keep the effective
            # learning rate.
            loss.backward()

            losses += loss.item()
            times += 1

            update_count += 1

            # Step once every `num_gradients_accumulation` batches. The
            # counter has already been incremented, so the test is `== 0`;
            # comparing against N - 1 would fire the first step one batch
            # early.
            if update_count % num_gradients_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
        end = time.time()
        print('-' * 20 + f'epoch {epoch}' + '-' * 20)
        print(f'time: {(end - start)}')
        mean_loss = losses / times if times else float('nan')
        print(f'loss: {mean_loss}')
        start = end

        #------------------------validate------------------------
        decoder.eval()

        perplexity = 0
        batch_count = 0
        print('start calculate the perplexity....')

        with torch.no_grad():
            for batch in val_dataloader:
                loss = _batch_token_loss(encoder, decoder, batch, device)
                perplexity += np.exp(loss.item())
                batch_count += 1

        mean_perplexity = (perplexity / batch_count
                           if batch_count else float('nan'))
        print(f'validate perplexity: {mean_perplexity}')

        torch.save(decoder.state_dict(),
                   os.path.join(save_dir, str(epoch) + "decoder.pth"))
# NOTE(review): an earlier function in this file is also named
# `sample_generate`; if both really live in one module, this later definition
# replaces the earlier one at import time - confirm which one is intended.
def sample_generate(top_k=50,
                    temperature=1.0,
                    decoder_path='decoder.pth',
                    batch_size=1,
                    show_num=10,
                    gpu_id=0):
    """Sample replies for a few examples of ``test_data.pth`` and print them.

    Loads the encoder from ``encoder.pth`` and the decoder from
    ``decoder_path``, then for each batch samples a reply token by token
    (``top_k_logits`` filtering followed by multinomial sampling) and prints
    the input, the prediction and the target of the first sequence in the
    batch. Stops after ``show_num`` batches. Nothing is returned.

    Args:
        top_k: value forwarded as ``k`` to ``top_k_logits``.
        temperature: divisor applied to the logits before filtering.
        decoder_path: path of the decoder state dict to load.
        batch_size: batch size of the ``DataLoader``.
        show_num: number of batches to display before stopping.
        gpu_id: index of the CUDA device to run on.
    """
    # make sure your model is on GPU
    device = torch.device(f"cuda:{gpu_id}")

    # Sampling this id ends decoding early.
    # NOTE(review): presumably "[SEP]" in the bert-base-chinese vocab - confirm.
    sep_token_id = 102
    max_decode_steps = 100

    print('load model')
    #------------------------LOAD MODEL-----------------
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from; the modules are moved to `device` afterwards.
    encoder = TransformerEncoder()
    encoder.load_state_dict(torch.load("encoder.pth", map_location="cpu"))
    encoder = encoder.to(device)
    encoder.eval()

    decoder = TransformerDecoderLM()
    decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))
    decoder = decoder.to(device)
    decoder.eval()

    print('load success')
    #------------------------END LOAD MODEL--------------

    #------------------------LOAD VALIDATE DATA------------------
    val_data = torch.load("test_data.pth", map_location="cpu")
    val_dataset = TensorDataset(*val_data)
    val_dataloader = DataLoader(dataset=val_dataset,
                                shuffle=False,
                                batch_size=batch_size)
    #------------------------END LOAD VALIDATE DATA--------------

    #------------------------START SAMPLE GENERETE-------------------
    update_count = 0
    print('start validate....')
    for batch in val_dataloader:
        with torch.no_grad():
            batch = [item.to(device) for item in batch]
            encoder_input, decoder_input, mask, _ = batch

            _, past = encoder(encoder_input, mask)

            # Decoding is seeded with the first token of the target.
            prev_pred = decoder_input[:, :1]
            sentence = [prev_pred]

            length = 1
            # decoding loop
            for _ in range(max_decode_steps):
                # Grow the attention mask by one position per decoded token.
                mask = F.pad(mask, (0, 1), "constant", 1.0)
                logits, past = decoder(prev_pred,
                                       mask,
                                       past=past,
                                       past_length=length)
                logits = logits.squeeze(1) / temperature
                logits = top_k_logits(logits, k=top_k)
                probs = F.softmax(logits, dim=-1)
                prev_pred = torch.multinomial(probs, num_samples=1)
                sentence.append(prev_pred)
                # Stop once the displayed (first) sequence is terminated, so
                # no decoder steps are spent on text after the end marker.
                if prev_pred[0][0] == sep_token_id:
                    break
                length += 1

            sentence = torch.cat(sentence, dim=-1)

            # Only the first sequence of the batch is shown.
            res = "".join(
                tokenizer.convert_ids_to_tokens(sentence[0].tolist()))
            inputs = "".join(
                tokenizer.convert_ids_to_tokens(encoder_input[0].tolist()))
            target = "".join(
                tokenizer.convert_ids_to_tokens(decoder_input[0].tolist()))

            print('-' * 20 + f'Case {update_count}' + '-' * 20)
            print('-' * 20 + 'Input' + '-' * 20)
            print(inputs)
            print('')

            print('-' * 20 + 'Predcit' + '-' * 20)
            # Display is capped at 100 characters of the joined token string.
            print(res[:100])
            print('')

            print('-' * 20 + 'Target' + '-' * 20)
            print(target)
            print('')

            update_count += 1
            if update_count == show_num:
                break