class BERTTrainer:
    """
    BERTTrainer pretrains a BERT model with two language-model training objectives:

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence Prediction : 3.3.2 Task #2: Next Sentence Prediction

    Please check README.md for details and a simple example.
    """

    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 warmup_steps=10000, with_cuda: bool = True, cuda_devices=None,
                 log_freq: int = 10, pad_index=0):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param warmup_steps: warmup steps for the (currently disabled) scheduled optimizer
        :param with_cuda: train on CUDA if available
        :param cuda_devices: device ids for multi-GPU DataParallel training
        :param log_freq: logging frequency of the batch iteration
        :param pad_index: padding token index ignored by the loss and accuracy
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # This BERT model will be saved every epoch
        self.bert = bert
        # Initialize the BERT Language Model with the BERT model
        self.model = BERTLM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader
        self.pad_index = pad_index

        # Optimizer: the original recipe uses Adam with a warmup schedule (commented out);
        # plain SGD with momentum is used here instead.
        # self.optim = Adam(self.model.parameters(), lr=lr,
        #                   betas=betas, weight_decay=weight_decay)
        # self.optim_schedule = ScheduledOptim(
        #     self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
        self.optim = SGD(self.model.parameters(), lr=lr, momentum=0.9)

        # Using Negative Log Likelihood Loss for predicting the masked tokens
        self.criterion = nn.NLLLoss(ignore_index=self.pad_index)

        self.log_freq = log_freq

        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.model.train()
        return self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.model.eval()
        return self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        Loop over the data_loader for training or testing.
        In train mode the backward pass and optimizer updates are performed.

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: average loss over the epoch
        """
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        def calculate_iter(data):
            # forward the masked-LM (and next-sentence) model
            next_sent_output, mask_lm_output = self.model.forward(
                data["bert_input"], data["segment_label"], data["adj_mat"], train)
            # NLLLoss of predicting the masked token words, ignoring padding
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
            return mask_loss, mask_lm_output

        for i, data in data_iter:
            # 0. send the batch data to the device (GPU or CPU)
            data = data[0]  # the loader yields the batch dict as the first element
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward pass and masked-LM loss (no gradients in test mode)
            if train:
                loss, mask_lm_output = calculate_iter(data)
            else:
                with torch.no_grad():
                    loss, mask_lm_output = calculate_iter(data)

            # NOTE: the next-sentence prediction loss from 3.4 Pre-training Procedure,
            # next_loss = self.criterion(next_sent_output, data["is_next"]),
            # is disabled; only the masked-LM loss is used here.

            # 2. backward and optimization only in train
            if train:
                self.optim.zero_grad()
                loss.backward()
                # self.optim.step_and_update_lr()  # used with the scheduled Adam optimizer
                self.optim.step()

            # 3. masked-LM prediction accuracy over non-padding positions
            non_pad = data["bert_label"].ne(self.pad_index)
            correct = (mask_lm_output.argmax(dim=-1).eq(data["bert_label"]) & non_pad).sum().item()
            elements = non_pad.sum().item()

            avg_loss += loss.item()
            total_correct += correct
            total_element += elements

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0 and i != 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),
              "total_acc=", total_correct * 100.0 / total_element)
        return avg_loss / len(data_iter)

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path (the checkpoint is written directly to this path)
        :return: final_output_path
        """
        # An earlier version saved the whole BERT module once per epoch:
        # output_path = file_path + ".ep%d" % epoch
        # torch.save(self.bert.cpu(), output_path)
        # self.bert.to(self.device)
        output_path = file_path  # + ".ep%d" % epoch
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict()
                # 'optimizer_state_dict': optimizer.state_dict(),
                # 'loss': loss,
                # ...
            }, output_path)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
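# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the surrounding repo
# provides the `BERT` model and a dataset whose DataLoader yields batch dicts
# with the keys consumed by `iteration()` ("bert_input", "segment_label",
# "adj_mat", "bert_label"); `bert_model`, `train_dataset`, `test_dataset`,
# `vocab` and `epochs` below are hypothetical placeholders.
#
# from torch.utils.data import DataLoader
#
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32)
#
# trainer = BERTTrainer(bert_model, vocab_size=len(vocab),
#                       train_dataloader=train_loader,
#                       test_dataloader=test_loader,
#                       lr=1e-4, with_cuda=True, log_freq=10, pad_index=0)
#
# for epoch in range(epochs):
#     trainer.train(epoch)
#     trainer.save(epoch, "output/bert_trained.model")
#     if test_loader is not None:
#         trainer.test(epoch)
# ---------------------------------------------------------------------------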
        # masked-frame reconstruction loss: mean squared error between the
        # reconstructed halves and the target visual-word features
        mask_loss = (torch.sum((frame1 - data["visual_word"][:, :max_frames, :]) ** 2)
                     + torch.sum((frame2 - data["visual_word"][:, max_frames:, :]) ** 2)) \
            / (2 * max_frames * feature_size * batchsize)
        # match the temporal mean of each hidden sequence to its neighbour embedding
        mu_loss = (torch.sum((torch.mean(hid1, 1) - data['n1']) ** 2)
                   + torch.sum((torch.mean(hid2, 1) - data['n2']) ** 2)) / (hidden_size * batchsize)
        # weighted combination of reconstruction, neighbour and mean-matching losses
        loss = 0.92 * mask_loss + 0.08 * nei_loss + 0.8 * mu_loss
        # loss = mask_loss
        loss.backward()
        optimizer.step()

        itera += 1
        infos['iter'] = itera
        infos['epoch'] = epoch
        if itera % 10 == 0 or batchsize < batch_size:
            print('Epoch:%d Step:[%d/%d] neiloss: %.2f maskloss: %.2f mu_loss: %.2f'
                  % (epoch, i, total_len, nei_loss.data.cpu().numpy(),
                     mask_loss.data.cpu().numpy(), mu_loss.data.cpu().numpy()))
            # checkpoint the model, the optimizer state and the bookkeeping dicts
            torch.save(model.state_dict(), file_path + '/9288.pth')
            torch.save(optimizer.state_dict(), optimizer_pth_path)
            with open(os.path.join(file_path, 'infos.pkl'), 'wb') as f:
                pickle.dump(infos, f)
            with open(os.path.join(file_path, 'histories.pkl'), 'wb') as f:
                pickle.dump(histories, f)

    epoch += 1
    if epoch > num_epochs:
        break
    model.train()
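# ---------------------------------------------------------------------------
# Resuming from the checkpoints written above (a sketch, not part of the loop):
# the model weights, optimizer state and bookkeeping dict can be reloaded as
# follows, assuming `model`, `optimizer`, `file_path` and `optimizer_pth_path`
# are constructed the same way as during training.
#
# model.load_state_dict(torch.load(file_path + '/9288.pth'))
# optimizer.load_state_dict(torch.load(optimizer_pth_path))
# with open(os.path.join(file_path, 'infos.pkl'), 'rb') as f:
#     infos = pickle.load(f)
# itera, epoch = infos['iter'], infos['epoch']
# ---------------------------------------------------------------------------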