def run_epoch(self, iteration_loss=False):
    """Runs an epoch of training.

    Keyword arguments:
    - iteration_loss (``bool``, optional): Prints loss at every step.

    Returns:
    - A tuple ``(epoch_loss, metric_value)``: the mean of the per-batch
      losses (criterion loss plus ``self.compression_algo.loss()``) over
      all batches of ``self.data_loader``, and ``self.metric.value()``
      after every batch has been added to the metric.
    """
    compression_scheduler = self.compression_algo.scheduler
    self.model.train()
    epoch_loss = 0.0
    self.metric.reset()
    for step, batch_data in enumerate(self.data_loader):
        # Get the inputs and labels
        inputs = batch_data[0].to(self.device)
        labels = batch_data[1].to(self.device)

        # Forward propagation
        outputs = self.model(inputs)

        labels, loss_outputs, metric_outputs = do_model_specific_postprocessing(
            self.model_name, labels, outputs)

        # Loss computation: criterion loss plus the compression algorithm's loss.
        # Use out-of-place addition so the criterion's output tensor is not
        # mutated in place while it is part of the autograd graph.
        loss = self.criterion(loss_outputs, labels)
        compression_loss = self.compression_algo.loss()
        loss = loss + compression_loss

        # Backpropagation
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        compression_scheduler.step()

        # Read the scalar once; each .item() call copies the value off the
        # device (and may force a synchronization on accelerators).
        loss_value = loss.item()

        # Keep track of loss for current epoch
        epoch_loss += loss_value

        # Keep track of the evaluation metric
        self.metric.add(metric_outputs.detach(), labels.detach())

        if iteration_loss:
            print("[Step: %d] Iteration loss: %.4f" % (step, loss_value))

    return epoch_loss / len(self.data_loader), self.metric.value()
def run_epoch(self, iteration_loss=False):
    """Runs an epoch of validation.

    Keyword arguments:
    - iteration_loss (``bool``, optional): Logs the loss at every step.

    Returns:
    - A tuple ``(epoch_loss, metric_value)``: the mean of the per-batch
      losses over all batches of ``self.data_loader``, and
      ``self.metric.value()`` after every batch has been added to the metric.
    """
    self.model.eval()
    epoch_loss = 0.0
    self.metric.reset()
    for step, batch_data in tqdm(enumerate(self.data_loader), total=len(self.data_loader)):
        # Get the inputs and labels
        inputs = batch_data[0].to(self.device)
        labels = batch_data[1].to(self.device)

        # No gradients are needed during validation.
        with torch.no_grad():
            # Forward propagation
            outputs = self.model(inputs)

            labels, loss_outputs, metric_outputs = do_model_specific_postprocessing(
                self.model_name, labels, outputs)

            # Loss computation
            loss = self.criterion(loss_outputs, labels)

            # Read the scalar once; each .item() call copies the value off the device.
            loss_value = loss.item()

            # Keep track of loss for current epoch
            epoch_loss += loss_value

            self.metric.add(metric_outputs.detach(), labels.detach())

            if iteration_loss:
                # Lazy %-style arguments: the message is only formatted when the
                # INFO level is enabled; the output text is unchanged.
                logger.info("[Step: %d] Iteration loss: %.4f", step, loss_value)

    return epoch_loss / len(self.data_loader), self.metric.value()
def criterion_fn(model_outputs, target, criterion_):
    """Apply ``criterion_`` to model outputs after model-specific postprocessing.

    ``loss_funcs.do_model_specific_postprocessing`` is called with
    ``config.model``, the target and the raw model outputs. It returns three
    values: the first two are the labels and the loss inputs, and the third
    is not needed here.

    NOTE(review): ``config`` is a free name, presumably the run configuration
    from the enclosing scope; confirm at the definition site.

    Returns:
        Whatever ``criterion_(loss_outputs, labels)`` returns.
    """
    postprocess = loss_funcs.do_model_specific_postprocessing
    labels, loss_outputs, _ = postprocess(config.model, target, model_outputs)
    return criterion_(loss_outputs, labels)