Example 1
# These methods assume the following module-level imports; the enclosing class
# and the GPT2KWModel definition are defined elsewhere in the project:
#   import os
#   from datetime import datetime
#   import torch
#   from torch.nn import DataParallel
#   import pytorch_transformers

    def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int, t_total: int):
        """
        Returns the requested learning rate scheduler.
        """
        scheduler = scheduler.lower()
        if scheduler == 'constantlr':
            return pytorch_transformers.ConstantLRSchedule(optimizer)
        elif scheduler == 'warmupconstant':
            return pytorch_transformers.WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        elif scheduler == 'warmupcosine':
            return pytorch_transformers.WarmupCosineSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return pytorch_transformers.WarmupCosineWithHardRestartsSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))
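
    # Illustrative note (not part of the original code): train() below constructs its
    # scheduler directly, but the same choice could be routed through this helper, e.g.
    #   scheduler = self._get_scheduler(optimizer, scheduler='warmupcosine',
    #                                   warmup_steps=self.warmup_steps, t_total=total_steps)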
    def train(self):
        if not self.pretrained_model:
            model = GPT2KWModel(config=self.model_config)
        else:
            self.print_and_log('Loading pretrained model')
            model = GPT2KWModel.from_pretrained(self.pretrained_model)
        model.train()
        model.to(self.device)
        # Count the number of model parameters
        num_parameters = 0
        parameters = model.parameters()
        for parameter in parameters:
            num_parameters += parameter.numel()
        self.print_and_log('Number of model parameters: {}'.format(num_parameters))

        self.print_and_log("开始加载训练集")
        train_loader, valid_loader = self.create_dataloader()
        self.print_and_log("训练集加载完毕")

        # Steps per epoch after gradient accumulation; uses the same
        # self.gradient_accumulation factor as the weight-update logic below.
        epoch_steps = int(train_loader.sampler.num_samples / self.batch_size / self.gradient_accumulation)
        total_steps = epoch_steps * self.epochs
        self.print_and_log('Total samples = {}'.format(train_loader.sampler.num_samples))
        self.print_and_log('Steps per epoch = {}'.format(epoch_steps))
        self.print_and_log('Total steps = {}'.format(total_steps))

        optimizer = pytorch_transformers.AdamW(model.parameters(), lr=self.lr, correct_bias=True)
        # scheduler = pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=self.warmup_steps, t_total=total_steps)
        scheduler = pytorch_transformers.WarmupCosineSchedule(optimizer, warmup_steps=self.warmup_steps, t_total=total_steps)
        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16_opt_level)

        if torch.cuda.device_count() > 1:
            model = DataParallel(model)
            multi_gpu = True
        else:
            multi_gpu = False

        overall_step = 0
        running_loss = 0
        model.train()
        for epoch in range(self.epochs):
            self.print_and_log('epoch {}'.format(epoch + 1))
            now = datetime.now()
            self.print_and_log('time: {}'.format(now))
            optimizer.zero_grad()
            for i, batch_data in enumerate(train_loader):
                if torch.cuda.is_available():
                    keyword_ids = batch_data[0].to(self.device, non_blocking=True)
                    passage_ids = batch_data[1].to(self.device, non_blocking=True)
                    label_ids = passage_ids.clone().to(self.device, non_blocking=True)
                else:
                    keyword_ids = batch_data[0]
                    passage_ids = batch_data[1]
                    label_ids = passage_ids.clone()
                outputs = model(input_ids=passage_ids, keyword_ids=keyword_ids, labels=label_ids)
                loss, logits = outputs[:2]
                # Average the per-GPU losses when running under DataParallel
                if multi_gpu:
                    loss = loss.mean()
                # Scale the loss for gradient accumulation
                if self.gradient_accumulation > 1:
                    loss = loss / self.gradient_accumulation
                # Mixed-precision (apex amp) or standard backward pass
                if self.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # Clip once the scale_loss context has exited and gradients are unscaled
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
                # Update weights once enough gradients have been accumulated
                if (i + 1) % self.gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    overall_step += 1
                # Report the training loss every log_step optimizer steps
                if (overall_step + 1) % self.log_step == 0 and running_loss != 0:
                    self.print_and_log('now time: {}:{}. Step {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        overall_step + 1,
                        epoch + 1,
                        running_loss * self.gradient_accumulation / self.log_step))
                    running_loss = 0

            # Validation at the end of each epoch
            with torch.no_grad():
                valid_start_time = datetime.now()
                model.eval()
                valid_loss = 0
                valid_step = 0
                for i, valid_batch_data in enumerate(valid_loader):
                    if torch.cuda.is_available():
                        keyword_ids = valid_batch_data[0].to(self.device, non_blocking=True)
                        passage_ids = valid_batch_data[1].to(self.device, non_blocking=True)
                        label_ids = passage_ids.clone().to(self.device, non_blocking=True)
                    else:
                        keyword_ids = valid_batch_data[0]
                        passage_ids = valid_batch_data[1]
                        label_ids = passage_ids.clone()
                    outputs = model(input_ids=passage_ids, keyword_ids=keyword_ids, labels=label_ids)
                    loss, logits = outputs[:2]
                    # .mean() covers the per-GPU loss vector returned under DataParallel
                    valid_loss += loss.mean().item()
                    valid_step += 1
                valid_loss = valid_loss / valid_step
                self.print_and_log('valid duration: {}, valid loss: {}'.format(datetime.now() - valid_start_time, valid_loss))

            # Save a checkpoint; the modulus of 1 means a save after every epoch
            if (epoch + 1) % 1 == 0:
                if not os.path.exists(self.output_dir + 'model_epoch{}'.format(epoch + 1)):
                    os.makedirs(self.output_dir + 'model_epoch{}'.format(epoch + 1))
                model_to_save = model.module if hasattr(model, 'module') else model
                model_to_save.save_pretrained(self.output_dir + 'model_epoch{}'.format(epoch + 1))
                # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
                # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))

            then = datetime.now()
            self.print_and_log('time: {}'.format(then))
            self.print_and_log('time for one epoch: {}'.format(then - now))
            model.train()

        self.print_and_log('training finished')
        self.f_log.close()
        if not os.path.exists(self.output_dir + 'final_model'):
            os.makedirs(self.output_dir + 'final_model')
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(self.output_dir + 'final_model')
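        # Illustrative follow-up (an assumption, not in the original script): the final
        # checkpoint written above can later be reloaded just like the pretrained model
        # at the top of train(), e.g.
        #   model = GPT2KWModel.from_pretrained(self.output_dir + 'final_model')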