예제 #1
0
    def test_on_loader(self, data_loader, max_iter=None):
        """Iterate over the test set and report loss, accuracy and a 95% CI.

        Args:
            data_loader: iterable data loader; each item is indexed with
                ``[0]`` to obtain the batch passed to ``self.val_on_batch``.
            max_iter: max number of iterations to perform if the end of the
                dataset is not reached. ``None`` (or any value < 1) consumes
                the whole loader.

        Returns:
            dict with ``test_loss`` and ``test_accuracy`` (means of the
            corresponding meters) and ``test_confidence``: the half-width of
            the two-sided 95% Student-t confidence interval of the per-batch
            accuracies (NaN when fewer than two batches were evaluated).
        """
        from scipy.stats import sem, t

        self.model.eval()
        test_loss_meter = BasicMeter.get("test_loss").reset()
        test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
        test_accuracy = []
        # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
        for batch_idx, _data in enumerate(data_loader):
            batch = _data[0]
            loss, accuracy = self.val_on_batch(batch)
            accuracy = float(accuracy)
            test_loss_meter.update(float(loss), 1)
            test_accuracy_meter.update(accuracy, 1)
            test_accuracy.append(accuracy)
            # Stop after max_iter batches; None/0/negative never matches.
            if batch_idx + 1 == max_iter:
                break
        confidence = 0.95
        n = len(test_accuracy)
        if n < 2:
            # A t-interval needs at least 2 samples (n - 1 degrees of freedom);
            # return NaN directly instead of letting sem/t.ppf warn.
            h = float("nan")
        else:
            std_err = sem(np.array(test_accuracy))
            h = std_err * t.ppf((1 + confidence) / 2, n - 1)
        return {
            "test_loss": test_loss_meter.mean(),
            "test_accuracy": test_accuracy_meter.mean(),
            "test_confidence": h
        }
예제 #2
0
    def train_on_loader(self,
                        data_loader,
                        max_iter=None,
                        debug_plot_path=None):
        """Iterate over the training set with gradient accumulation.

        Gradients of ``self.exp_dict["tasks_per_batch"]`` consecutive batches
        are accumulated before each optimizer step. Each batch loss is divided
        by that count, so the accumulated gradient is the mean over the window.

        Args:
            data_loader: iterable training data loader
            max_iter: max number of iterations to perform if the end of the
                dataset is not reached (``None`` consumes the whole loader).
            debug_plot_path: not used by this method; kept for interface
                compatibility.

        Returns:
            dict with ``train_loss``: mean of the scaled per-batch losses
            recorded in the ``train_loss`` meter.
        """
        self.model.train()
        train_loss_meter = BasicMeter.get("train_loss").reset()
        tasks_per_batch = self.exp_dict["tasks_per_batch"]
        pending = 0  # batches whose gradients have not been applied yet
        # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
        self.optimizer.zero_grad()
        for batch_idx, batch in enumerate(data_loader):
            loss = self.train_on_batch(batch) / tasks_per_batch
            train_loss_meter.update(float(loss), 1)
            loss.backward()
            pending += 1
            if ((batch_idx + 1) % tasks_per_batch) == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                pending = 0
            if batch_idx + 1 == max_iter:
                break
        # Apply gradients left over from an incomplete accumulation window
        # (loader length or max_iter not a multiple of tasks_per_batch);
        # otherwise that backward work would never be applied by this method.
        if pending:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return {"train_loss": train_loss_meter.mean()}
예제 #3
0
 def train_on_loader(self,
                     data_loader,
                     max_iter=None,
                     debug_plot_path=None):
     """Run one pass of plain (non-accumulated) training over ``data_loader``.

     Every batch gets its own optimizer step: gradients are cleared, the loss
     returned by ``self.train_on_batch`` is back-propagated and then applied.

     Args:
         data_loader (torch.utils.data.DataLoader): a pytorch dataloader
         max_iter (int, optional): stop after this many batches if the loader
             is not exhausted first. Defaults to None (no limit).
         debug_plot_path: not used by this method.

     Returns:
         metrics: dictionary with the mean training loss under "train_loss"
     """
     self.model.train()
     loss_meter = BasicMeter.get("train_loss").reset()
     # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
     for step, batch in enumerate(data_loader, start=1):
         self.optimizer.zero_grad()
         batch_loss = self.train_on_batch(batch)
         loss_meter.update(float(batch_loss), 1)
         batch_loss.backward()
         self.optimizer.step()
         # step is 1-based, so this matches "batch_idx + 1 == max_iter".
         if step == max_iter:
             break
     return {"train_loss": loss_meter.mean()}
예제 #4
0
    def test_on_loader(self, data_loader, max_iter=None):
        """Iterate over the test set and report mean loss and accuracy.

        Args:
            data_loader: iterable data loader; each item is indexed with
                ``[0]`` to obtain the batch passed to ``self.val_on_batch``.
            max_iter: max number of iterations to perform if the end of the
                dataset is not reached. ``None`` (or any value < 1) consumes
                the whole loader.

        Returns:
            dict with the means of the ``test_loss`` and ``test_accuracy``
            meters.
        """
        self.model.eval()
        test_loss_meter = BasicMeter.get("test_loss").reset()
        test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
        # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
        for batch_idx, _data in enumerate(data_loader):
            batch = _data[0]
            loss, accuracy = self.val_on_batch(batch)
            test_loss_meter.update(float(loss), 1)
            test_accuracy_meter.update(float(accuracy), 1)
            # Stop after max_iter batches; None/0/negative never matches.
            if batch_idx + 1 == max_iter:
                break
        return {"test_loss": test_loss_meter.mean(), "test_accuracy": test_accuracy_meter.mean()}
예제 #5
0
    def val_on_loader(self, data_loader, max_iter=None):
        """Iterate over the validation set, then step the LR scheduler.

        Args:
            data_loader: iterable data loader; each item is indexed with
                ``[0]`` to obtain the batch passed to ``self.val_on_batch``.
            max_iter: max number of iterations to perform if the end of the
                dataset is not reached. ``None`` (or any value < 1) consumes
                the whole loader.

        Returns:
            dict with the means of the ``val_loss`` and ``val_accuracy``
            meters.
        """
        self.model.eval()
        val_loss_meter = BasicMeter.get("val_loss").reset()
        val_accuracy_meter = BasicMeter.get("val_accuracy").reset()
        # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
        for batch_idx, _data in enumerate(data_loader):
            batch = _data[0]
            loss, accuracy = self.val_on_batch(batch)
            val_loss_meter.update(float(loss), 1)
            val_accuracy_meter.update(float(accuracy), 1)
            # Stop after max_iter batches; None/0/negative never matches.
            if batch_idx + 1 == max_iter:
                break
        # The scheduler is driven by the meter named in exp_dict["target_loss"]
        # (NOTE(review): presumably "val_loss" or a related meter — confirm).
        target_loss = BasicMeter.get(self.exp_dict["target_loss"], recursive=True, force=False).mean()
        self.scheduler.step(target_loss)  # update the learning rate monitor
        return {"val_loss": val_loss_meter.mean(), "val_accuracy": val_accuracy_meter.mean()}
예제 #6
0
    def test_on_loader(self, data_loader, max_iter=None):
        """Evaluate on the test set, showing running accuracy in a progress bar.

        Args:
            data_loader: iterable data loader supporting ``len()``; each item
                is indexed with ``[0]`` to obtain the batch passed to
                ``self.val_on_batch``.
            max_iter: max number of iterations to perform if the end of the
                dataset is not reached. ``None`` (or any value < 1) consumes
                the whole loader.

        Returns:
            dict with ``test_loss`` fixed to -1 (no loss is tracked here),
            ``ssl_accuracy`` (mean of the ``test_accuracy`` meter),
            ``ssl_confidence`` (half-width of the two-sided 95% Student-t
            confidence interval of the per-batch accuracies; NaN with fewer
            than two batches) and ``finetuned_accuracy``
            (``self.best_accuracy``).
        """
        # Imported locally so the method does not rely on module-level names.
        from scipy.stats import sem, t

        self.model.eval()

        test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
        test_accuracy = []
        dirname = os.path.split(self.exp_dict["pretrained_weights_root"])[-1]
        total = len(data_loader)
        if max_iter is not None and 0 < max_iter < total:
            # Keep the progress bar consistent with the early stop below.
            total = max_iter
        # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
        with tqdm.tqdm(total=total) as pbar:
            for batch_idx, batch_all in enumerate(data_loader):
                batch = batch_all[0]
                _, accuracy = self.val_on_batch(batch)
                accuracy = float(accuracy)

                test_accuracy_meter.update(accuracy, 1)
                test_accuracy.append(accuracy)

                string = ("'%s' - %s - finetuned: %.3f -  ssl: %.3f" %
                          (self.label, dirname, self.best_accuracy,
                           test_accuracy_meter.mean()))
                pbar.update(1)
                pbar.set_description(string)
                # Stop after max_iter batches; None/0/negative never matches.
                if batch_idx + 1 == max_iter:
                    break

        confidence = 0.95
        n = len(test_accuracy)
        if n < 2:
            # A t-interval needs at least 2 samples (n - 1 degrees of freedom).
            h = float("nan")
        else:
            std_err = sem(np.array(test_accuracy))
            h = std_err * t.ppf((1 + confidence) / 2, n - 1)
        return {
            "test_loss": -1,
            "ssl_accuracy": test_accuracy_meter.mean(),
            "ssl_confidence": h,
            'finetuned_accuracy': self.best_accuracy
        }