def run_epoch(self, iteration_loss=False):
        """Train the model for one full pass over ``self.data_loader``.

        Every batch goes through a forward pass, the criterion loss plus the
        loss reported by ``self.compression_algo``, a backward pass, an
        optimizer step and finally a compression scheduler step.

        Keyword arguments:
        - iteration_loss (``bool``, optional): When true, the loss of every
        step is printed. Default: False.

        Returns:
        - A tuple ``(loss, metric)``: the per-step losses summed and divided
        by ``len(self.data_loader)``, and the result of ``self.metric.value()``.

        """
        scheduler = self.compression_algo.scheduler

        self.model.train()
        self.metric.reset()
        running_loss = 0.0

        for step, batch in enumerate(self.data_loader):
            # Move the batch (inputs at index 0, labels at index 1) to the device
            model_inputs = batch[0].to(self.device)
            targets = batch[1].to(self.device)

            # Forward pass followed by the model-specific post-processing,
            # which yields what the criterion and the metric should consume
            predictions = self.model(model_inputs)
            targets, criterion_inputs, metric_inputs = do_model_specific_postprocessing(
                self.model_name, targets, predictions)

            # Task loss, extended in place with the compression loss
            loss = self.criterion(criterion_inputs, targets)
            loss += self.compression_algo.loss()

            # Backward pass and parameter update
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # The compression scheduler advances once per optimizer step
            scheduler.step()

            # Accumulate the scalar loss for the epoch average
            step_loss = loss.item()
            running_loss += step_loss

            # Feed the metric with tensors cut off from the autograd graph
            self.metric.add(metric_inputs.detach(), targets.detach())

            if iteration_loss:
                print("[Step: %d] Iteration loss: %.4f" % (step, step_loss))

        mean_loss = running_loss / len(self.data_loader)
        return mean_loss, self.metric.value()
Ejemplo n.º 2
0
    def run_epoch(self, iteration_loss=False):
        """Evaluate the model for one full pass over ``self.data_loader``.

        The model is switched to evaluation mode; the forward pass, the
        model-specific post-processing and the loss computation all run under
        ``torch.no_grad()``. Progress is displayed through ``tqdm``.

        Keyword arguments:
        - iteration_loss (``bool``, optional): When true, the loss of every
        step is logged at INFO level. Default: False.

        Returns:
        - A tuple ``(loss, metric)``: the per-step losses summed and divided
        by ``len(self.data_loader)``, and the result of ``self.metric.value()``.

        """
        self.model.eval()
        self.metric.reset()
        running_loss = 0.0

        num_batches = len(self.data_loader)
        progress = tqdm(enumerate(self.data_loader), total=num_batches)
        for step, batch in progress:
            # Move the batch (inputs at index 0, labels at index 1) to the device
            model_inputs = batch[0].to(self.device)
            targets = batch[1].to(self.device)

            with torch.no_grad():
                # Forward pass
                predictions = self.model(model_inputs)

                # Split the raw outputs into what the criterion and the metric consume
                targets, criterion_inputs, metric_inputs = do_model_specific_postprocessing(
                    self.model_name, targets, predictions)

                # Validation loss for this batch
                loss = self.criterion(criterion_inputs, targets)

            # Accumulate the scalar loss for the epoch average
            step_loss = loss.item()
            running_loss += step_loss

            # Update the evaluation metric
            self.metric.add(metric_inputs.detach(), targets.detach())

            if iteration_loss:
                logger.info("[Step: {}] Iteration loss: {:.4f}".format(step, step_loss))

        mean_loss = running_loss / num_batches
        return mean_loss, self.metric.value()
Ejemplo n.º 3
0
 def criterion_fn(model_outputs, target, criterion_):
     """Apply ``criterion_`` to post-processed model outputs.

     ``model_outputs`` and ``target`` are first passed through
     ``loss_funcs.do_model_specific_postprocessing`` for ``config.model``
     (``loss_funcs`` and ``config`` come from the surrounding scope). Of the
     three values it returns, the third is ignored.

     Returns:
     - Whatever ``criterion_(loss_outputs, labels)`` returns.

     """
     processed_target, criterion_inputs, _ = (
         loss_funcs.do_model_specific_postprocessing(
             config.model, target, model_outputs))
     return criterion_(criterion_inputs, processed_target)