def train_one_epoch(self, model, dataloader, optimizer, scheduler, num_epochs, max_grad_norm=None, debugging=False):
    """Train the model for one epoch.

    For every batch: move it to ``self.device``, run the forward pass,
    merge the batch metrics into the loss dict, show the l1/l2/l3
    classification losses plus ``self.early_stopping_metrics`` in the tqdm
    description, then back-propagate ``losses["total_loss"]`` and step the
    optimizer (and the scheduler, when given).

    Args:
        model: Module to train. It is called as ``model(**data)`` and its
            output must expose ``output["losses"]``, a dict containing
            ``"total_loss"`` and the ``l{1,2,3}_cls_loss`` entries.
        dataloader: Iterable of batches. ``len(dataloader)`` is used for the
            progress-bar total unless ``debugging`` is set.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once
            per batch.
        scheduler: Optional scheduler stepped once per batch; ``None``
            disables scheduling.
        num_epochs: Total number of epochs, used only for the progress
            description. ``None`` yields a plain ``"Training"`` label.
        max_grad_norm: When not ``None``, gradients are clipped to this norm
            before the optimizer step.
        debugging: When true, stop after 10 batches.

    Returns:
        None.
    """

    def _format_cell(value):
        # Render one progress-bar cell. Tensors are unwrapped with .item();
        # anything else is coerced with float() so plain Python / NumPy
        # scalars returned by the metric function are displayed instead of
        # raising AttributeError on a missing .item().
        if isinstance(value, torch.Tensor):
            value = value.item()
        else:
            try:
                value = float(value)
            except (TypeError, ValueError):
                return "nan"
        # A NaN value formats as "nan" with this spec, matching the
        # explicit "nan" marker used previously.
        return f"{value:.3f}"

    model.train()
    timer = Timer()

    # Loop-invariant pieces of the progress display.
    metric_names = list(self.early_stopping_metrics)
    row_format = "{:25}" + "|" + "{:^15}" * (3 + len(metric_names)) + "|"
    displayed_keys = ["l1_cls_loss", "l2_cls_loss", "l3_cls_loss", *metric_names]

    print(row_format.format("", "l1_loss", "l2_loss", "l3_loss", *metric_names))

    total = 10 if debugging else len(dataloader)
    with tqdm(dataloader, total=total) as t:
        if num_epochs is not None:
            description = f"Training ({self.epoch}/{num_epochs})"
        else:
            description = "Training"
        t.set_description(description)

        for i, data in enumerate(t):
            timer.start()
            data = to_device(data, self.device)
            optimizer.zero_grad()

            # Forward
            output = model(**data)
            losses = output["losses"]

            # Calculate batch metrics and merge them into the loss dict so
            # they can be displayed alongside the losses.
            metric = compute_metrics_from_inputs_and_outputs(
                inputs=data,
                outputs=output,
                tokenizer=self.tokenizer,
                save_csv_path=None,
            )
            losses.update(metric)

            # Update tqdm with training information
            cells = [_format_cell(losses[key]) for key in displayed_keys]
            t.set_description(row_format.format(description, *cells))

            # Backward
            losses["total_loss"].backward()
            if max_grad_norm is not None:
                # Clip the gradients of the model actually being trained
                # (the ``model`` argument), not whatever ``self.model`` is.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            timer.end()

            # Break when reaching 10 iterations when debugging
            if debugging and i == 9:
                break

    logger.info(f"{description} took {timer.get_total_time():.2f}s.")
def train_one_epoch(self, model, dataloader, optimizer, scheduler, num_epochs, max_grad_norm=None, debugging=False):
    """Train the model for one epoch.

    For every batch: move it to ``self.device``, run the forward pass with
    ``is_training=True``, merge the batch accuracies into the loss dict,
    show the per-aspect (food / service / price) score loss, existence loss
    and accuracy in the tqdm description, then back-propagate
    ``losses["total_loss"]`` and step the optimizer (and the scheduler, when
    given).

    Args:
        model: Module to train. It is called as
            ``model(**data, is_training=True)`` and its output must expose
            ``output["losses"]``, a dict containing ``"total_loss"`` and the
            ``{aspect}_score_loss`` / ``{aspect}_existence_loss`` entries.
        dataloader: Iterable of batches. ``len(dataloader)`` is used for the
            progress-bar total unless ``debugging`` is set.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once
            per batch.
        scheduler: Optional scheduler stepped once per batch; ``None``
            disables scheduling.
        num_epochs: Total number of epochs, used only for the progress
            description. ``None`` yields a plain ``"Training"`` label.
        max_grad_norm: When not ``None``, gradients are clipped to this norm
            before the optimizer step.
        debugging: When true, stop after 10 batches.

    Returns:
        None.
    """

    def _format_cell(value):
        # Render one progress-bar cell. Previously every non-tensor value
        # was shown as "nan", which hid metrics returned as plain Python /
        # NumPy scalars. Now any value convertible to float is displayed;
        # only missing or non-numeric values fall back to "nan".
        if isinstance(value, torch.Tensor):
            value = value.item()
        else:
            try:
                value = float(value)
            except (TypeError, ValueError):
                return "nan"
        # A NaN value formats as "nan" with this spec, matching the
        # explicit "nan" marker used previously.
        return f"{value:.3f}"

    model.train()
    timer = Timer()

    # Loop-invariant pieces of the progress display.
    aspects = ["food", "service", "price"]
    columns = ["score_loss", "existence_loss", "acc"]
    displayed_keys = [f"{aspect}_{column}" for aspect in aspects for column in columns]
    row_format = (
        "{:25}" + "|" + "{:^15}" * 3 + "|" + "{:^15}" * 3 + "|" + "{:^15}" * 3 + "|"
    )

    # Two-line table header: aspect names, then the per-aspect column names.
    print(
        ("{:25}" + "|" + "{:^45}" + "|" + "{:^45}" + "|" + "{:^45}" + "|").format(
            "", "food", "service", "price"
        )
    )
    print(
        row_format.format(
            "",
            "mse_loss", "existence_loss", "acc",
            "mse_loss", "existence_loss", "acc",
            "mse_loss", "existence_loss", "acc",
        )
    )

    total = 10 if debugging else len(dataloader)
    with tqdm(dataloader, total=total) as t:
        if num_epochs is not None:
            description = f"Training ({self.epoch}/{num_epochs})"
        else:
            description = "Training"
        t.set_description(description)

        for i, data in enumerate(t):
            timer.start()
            data = to_device(data, self.device)
            optimizer.zero_grad()

            # Forward
            output = model(**data, is_training=True)
            losses = output["losses"]

            # Calculate batch accuracy and merge it into the loss dict so it
            # can be displayed alongside the losses.
            acc = compute_metrics_from_inputs_and_outputs(
                inputs=data,
                outputs=output,
                output_acc=True,
                confidence_threshold=self.config["evaluation"]["confidence_threshold"],
            )
            losses.update(acc)

            # Update tqdm with training information
            cells = [_format_cell(losses[key]) for key in displayed_keys]
            t.set_description(row_format.format(description, *cells))

            # Backward
            losses["total_loss"].backward()
            if max_grad_norm is not None:
                # Clip the gradients of the model actually being trained
                # (the ``model`` argument), not whatever ``self.model`` is.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            timer.end()

            # Break when reaching 10 iterations when debugging
            if debugging and i == 9:
                break

    logger.info(f"{description} took {timer.get_total_time():.2f}s.")