def eval(self, dev_loader):
    self.model.eval()
    eval_loss = 0
    for step, (_, features, labels) in enumerate(dev_loader):
        if self.ngpu > 0:
            features = map_to_cuda(features)
            labels = map_to_cuda(labels)
        loss = self.model(features, labels)
        eval_loss += loss.item()
    return eval_loss / (step + 1)
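# Note: `eval` switches the model to eval mode but does not disable autograd, so the
# validation forward pass still builds a gradient graph. A common memory-saving variant
# (an assumption, not necessarily how the rest of this repository handles it) wraps the
# loop in `torch.no_grad()`:
#
#     with torch.no_grad():
#         for step, (_, features, labels) in enumerate(dev_loader):
#             ...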
def train_one_epoch(self, epoch, train_loader):
    self.model.train()
    batch_steps = len(train_loader)
    step_loss = AverageMeter()
    for step, (_, batch) in enumerate(train_loader):
        if self.ngpu > 0:
            batch = map_to_cuda(batch)
        start = time.process_time()

        # Forward pass; average the per-GPU losses and normalize for gradient accumulation.
        loss = self.model(**batch)
        loss = torch.mean(loss) / self.accum_steps

        if self.mixed_precision:
            # Scale the loss through apex so fp16 gradients do not underflow.
            import apex.amp as amp
            with amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.grad_noise:
            raise NotImplementedError

        if self.get_rank() == 0:
            step_loss.update(loss.item() * self.accum_steps, batch['inputs'].size(0))

        # Update the parameters once every accum_steps batches.
        if step % self.accum_steps == 0:
            if self.local_rank == 0:
                self.mean_loss.update(step_loss.avg)

            # Clip the gradients and skip the update if the norm is NaN.
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            if math.isnan(grad_norm):
                self.logger.warning('Grad norm is NAN. DO NOT UPDATE MODEL!')
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

            if self.is_visual and self.local_rank == 0:
                self.visulizer.add_scalar('train_loss', loss.item(), self.global_step)
                self.visulizer.add_scalar('lr', self.optimizer.lr, self.global_step)

            if self.global_step % self.log_interval == 0 and self.local_rank == 0:
                end = time.process_time()
                process = step * self.world_size / batch_steps * 100
                self.logger.info('-Training-Epoch-%d(%.5f%%), Global Step:%d, lr:%.8f, Loss:%.5f, AvgLoss: %.5f, '
                                 'Run Time:%.3f' % (epoch, process, self.global_step, self.optimizer.lr,
                                                    step_loss.avg, self.mean_loss.mean(), end - start))

            self.global_step += 1
            step_loss.reset()

    return self.mean_loss.mean()
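# The trainer above relies on two helpers that are not shown in this section:
# `map_to_cuda`, which moves a batch onto the GPU, and `AverageMeter`, which tracks a
# running average of the step loss. The sketches below are assumptions about their
# behavior, not the repository's actual implementations.

import torch

def map_to_cuda(data, device='cuda'):
    # Hypothetical sketch: recursively move tensors (or dicts of tensors) to the GPU.
    if isinstance(data, dict):
        return {k: map_to_cuda(v, device) for k, v in data.items()}
    if isinstance(data, torch.Tensor):
        return data.to(device)
    return data

class AverageMeter(object):
    # Hypothetical sketch: weighted running average of recent values.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count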
def eval(self, dev_loader):
    self.model.eval()
    eval_loss = 0
    for step, (_, batch) in enumerate(dev_loader):
        if self.ngpu > 0:
            batch = map_to_cuda(batch)
        loss = self.model(**batch)
        eval_loss += loss.item()
    return eval_loss / (step + 1)
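# A minimal driver sketch showing how the two methods might be combined. The `trainer`,
# `num_epochs`, and loader names are hypothetical; the actual repository may drive
# training differently (checkpointing, learning-rate scheduling, early stopping, etc.).
def run_training(trainer, train_loader, dev_loader, num_epochs):
    for epoch in range(num_epochs):
        train_loss = trainer.train_one_epoch(epoch, train_loader)
        dev_loss = trainer.eval(dev_loader)
        trainer.logger.info('Epoch %d finished: train loss %.5f, dev loss %.5f'
                            % (epoch, train_loss, dev_loss))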