Example #1
import torch
import torch.nn as nn
import tqdm
from torch.optim import SGD
from torch.utils.data import DataLoader

# BERT and BERTLM are the project's encoder / language-model head classes;
# they are imported from the project's own model module (path not shown here).
class BERTTrainer:
    """
    BERTTrainer make the pretrained BERT model with two LM training method.

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction

    please check the details on README.md with simple example.

    """
    def __init__(self,
                 bert: BERT,
                 vocab_size: int,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader = None,
                 lr: float = 1e-4,
                 betas=(0.9, 0.999),
                 weight_decay: float = 0.01,
                 warmup_steps=10000,
                 with_cuda: bool = True,
                 cuda_devices=None,
                 log_freq: int = 10,
                 pad_index=0):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: traning with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # The underlying BERT encoder; the full BERTLM model is checkpointed by save()
        self.bert = bert
        # Initialize the BERT language-model heads on top of the BERT encoder
        self.model = BERTLM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader
        self.pad_index = pad_index
        # The original Adam optimizer with warmup schedule is kept here for reference;
        # plain SGD with momentum is used instead.
        # self.optim = Adam(self.model.parameters(), lr=lr,
        #                   betas=betas, weight_decay=weight_decay)
        # self.optim_schedule = ScheduledOptim(
        #     self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
        self.optim = SGD(self.model.parameters(), lr=lr, momentum=0.9)
        # Negative log-likelihood loss for masked-token prediction; padding positions are ignored
        self.criterion = nn.NLLLoss(ignore_index=self.pad_index)

        self.log_freq = log_freq

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.model.train()
        return self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.model.eval()
        return self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        Loop over the data_loader for training or testing.
        In train mode the backward pass and optimizer step are run;
        the model itself is saved separately via save().

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean flag, True for training and False for testing
        :return: average loss over the epoch
        """
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        def calculate_iter(data):
            # 1. forward the next-sentence and masked-LM heads
            next_sent_output, mask_lm_output = self.model.forward(
                data["bert_input"], data["segment_label"], data["adj_mat"],
                train)
            # 2. NLL loss of predicting the masked token words
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2),
                                       data["bert_label"])
            return mask_loss, mask_lm_output

        for i, data in data_iter:
            # 0. send the batch data to the device (GPU or CPU)
            data = data[0]
            data = {key: value.to(self.device) for key, value in data.items()}

            # 1-2. forward pass and masked-LM loss (see calculate_iter above).
            # The next-sentence loss (next_loss = self.criterion(next_sent_output,
            # data["is_next"]), added to mask_loss per 3.4 Pre-training Procedure)
            # is disabled; only the masked-LM loss is used.
            if train:
                loss, mask_lm_output = calculate_iter(data)
            else:
                with torch.no_grad():
                    loss, mask_lm_output = calculate_iter(data)

            # 3. backward and optimization only in train
            if train:
                self.optim.zero_grad()
                loss.backward()
                # self.optim.step_and_update_lr()
                self.optim.step()

            # masked-LM prediction accuracy, counted over non-padding positions only
            pred = mask_lm_output.argmax(dim=-1)
            non_pad = data["bert_label"] != self.pad_index
            correct = ((pred == data["bert_label"]) & non_pad).sum().item()
            elements = non_pad.sum().item()
            # (next-sentence accuracy, next_sent_output.argmax(dim=-1).eq(data["is_next"]),
            #  is not tracked while the NSP loss is disabled)

            avg_loss += loss.item()
            total_correct += correct
            total_element += elements

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0 and i != 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code),
              avg_loss / len(data_iter), "total_acc=",
              total_correct * 100.0 / total_element)
        return avg_loss / len(data_iter)

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Save the current model checkpoint to file_path.

        :param epoch: current epoch number
        :param file_path: checkpoint output path (overwritten on every call)
        :return: final_output_path
        """
        # The original version saved only the bare BERT encoder to
        # file_path + ".ep%d" % epoch via torch.save(self.bert.cpu(), ...);
        # here the full BERTLM state_dict is saved to a single fixed path instead.
        output_path = file_path
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict()
                # 'optimizer_state_dict': self.optim.state_dict(),
            },
            output_path)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
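
A minimal usage sketch for the trainer above. The dataset objects, the vocab, and the BERT hyper-parameters are assumptions standing in for the project's own classes; only the BERTTrainer API (train / test / save) comes from the code in Example #1.

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # assumed dataset object
test_loader = DataLoader(test_dataset, batch_size=64)                  # assumed dataset object

bert = BERT(vocab_size=len(vocab), hidden=256, n_layers=8, attn_heads=8)  # assumed hyper-parameters
trainer = BERTTrainer(bert, vocab_size=len(vocab),
                      train_dataloader=train_loader,
                      test_dataloader=test_loader,
                      lr=1e-4, with_cuda=True, log_freq=10)

for epoch in range(10):
    trainer.train(epoch)
    trainer.save(epoch)  # overwrites output/bert_trained.model each epoch
    trainer.test(epoch)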
Example #2
        # mask loss: mean squared error between frame1/frame2 and the first/second
        # half of the visual features
        mask_loss = ((torch.sum((frame1 - data["visual_word"][:, :max_frames, :]) ** 2)
                      + torch.sum((frame2 - data["visual_word"][:, max_frames:, :]) ** 2))
                     / (2 * max_frames * feature_size * batchsize))

        # mu loss: squared error between the mean hidden states and the targets n1 / n2
        mu_loss = ((torch.sum((torch.mean(hid1, 1) - data['n1']) ** 2)
                    + torch.sum((torch.mean(hid2, 1) - data['n2']) ** 2))
                   / (hidden_size * batchsize))
        # weighted sum of the neighborhood, mask and mu losses
        loss = 0.92 * mask_loss + 0.08 * nei_loss + 0.8 * mu_loss

        loss.backward()
        optimizer.step()
        itera += 1
        infos['iter'] = itera
        infos['epoch'] = epoch
        if itera % 10 == 0 or batchsize < batch_size:
            print('Epoch:%d Step:[%d/%d] neiloss: %.2f maskloss: %.2f mu_loss: %.2f'
                  % (epoch, i, total_len, nei_loss.data.cpu().numpy(),
                     mask_loss.data.cpu().numpy(), mu_loss.data.cpu().numpy()))

    torch.save(model.state_dict(), file_path + '/9288.pth')
    torch.save(optimizer.state_dict(), optimizer_pth_path)

    with open(os.path.join(file_path, 'infos.pkl'), 'wb') as f:
        pickle.dump(infos, f)
    with open(os.path.join(file_path, 'histories.pkl'), 'wb') as f:
        pickle.dump(histories, f)
    epoch += 1
    if epoch > num_epochs:
        break
    model.train()
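
A minimal sketch of resuming from the files written above. The model and optimizer objects are constructed by the surrounding (not shown) code, so their construction is assumed here; the file names and checkpoint contents match the save calls in Example #2.

import os
import pickle

import torch

model.load_state_dict(torch.load(os.path.join(file_path, '9288.pth'), map_location='cpu'))
optimizer.load_state_dict(torch.load(optimizer_pth_path, map_location='cpu'))

with open(os.path.join(file_path, 'infos.pkl'), 'rb') as f:
    infos = pickle.load(f)
itera = infos['iter']    # resume iteration/epoch bookkeeping from the saved state
epoch = infos['epoch']
model.train()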