Example No. 1
    def eval_model(self, test=False, test_idx=0):
        self.model.eval()
        self.preds_list, self.probs_list, self.labels_list, self.loss_list, self.id_list = [], [], [], [], []
        batch_loader = self.config['val_loader'] if not test else self.config['test_loader'][test_idx]
        with torch.no_grad():
            for iters, batch in enumerate(batch_loader):
                batch = self.batch_to_device(batch)
                if batch_loader.dataset.return_ids:
                    self.id_list.append(batch['ids'])
                self.eval_iter_step(iters, batch, test=test)

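            # Flatten the per-batch lists into flat lists over the whole split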
            self.probs_list = [
                prob for batch_prob in self.probs_list for prob in batch_prob
            ]
            self.preds_list = [
                pred for batch_pred in self.preds_list for pred in batch_pred
            ]
            self.labels_list = [
                label for batch_labels in self.labels_list
                for label in batch_labels
            ]
            self.id_list = [
                data_id for batch_id in self.id_list for data_id in batch_id
            ]

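            # Average the loss over all evaluation batches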
            val_loss = sum(self.loss_list) / len(self.loss_list)
            eval_metrics = standard_metrics(torch.tensor(self.probs_list),
                                            torch.tensor(self.labels_list),
                                            add_optimal_acc=True)
            # if test:
            # 	print(classification_report(np.array(self.labels_list), np.array(self.preds_list)))
        return eval_metrics, val_loss
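The standard_metrics helper used above is not shown in this listing. The sketch below is a guess at its behavior, assuming binary classification with 0/1 labels; the function name and keyword arguments are taken from the calls above, but the metric keys and internals are assumptions, not the project's actual implementation.

import torch
from sklearn.metrics import roc_auc_score

def standard_metrics(probs, labels, threshold=0.5, add_optimal_acc=False, add_aucroc=True):
    # Hypothetical reconstruction: binary metrics from probabilities and 0/1 labels
    preds = (probs >= threshold).long()
    metrics = {"accuracy": (preds == labels.long()).float().mean().item()}
    if add_aucroc:
        # Area under the ROC curve, computed on CPU numpy arrays
        metrics["aucroc"] = roc_auc_score(labels.numpy(), probs.numpy())
    if add_optimal_acc:
        # Best accuracy over a sweep of candidate thresholds
        metrics["optimal_accuracy"] = max(
            ((probs >= t).long() == labels.long()).float().mean().item()
            for t in torch.linspace(0.0, 1.0, steps=101))
    return metrics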
Example No. 2
    def end_training(self):
        # Termination message
        print("\n" + "-" * 100)
        if self.terminate_training:
            LOGGER.info(
                "Training terminated early: validation {} did not improve for {} epochs".format(
                    self.config['optimize_for'], self.config['patience']))
        else:
            LOGGER.info(
                "Maximum epochs of {} reached. Finished training!".format(
                    self.config['max_epoch']))

        print_test_stats(self.best_val_metrics, test=False)

        print("-" * 50 + "\n\t\tEvaluating on test set\n" + "-" * 50)
        if not self.config["no_model_checkpoints"]:
            if os.path.isfile(self.model_file):
                self.load_model()
                self.model.to(self.device)
            else:
                raise ValueError(
                    "No saved model state_dict found for model '{}'!\nAborting evaluation on test set...".format(
                        self.config['model_name']))

            self.export_val_predictions()  # Runs evaluation; no need to run it again here
            val_probs = torch.tensor(self.probs_list)
            val_labels = torch.tensor(self.labels_list)
            threshold = 0.5  # default threshold for binary classification
            # Uncomment the line below if you have implemented this optional feature
            # threshold = find_optimal_threshold(val_probs, val_labels, metric="accuracy")
            best_val_metrics = standard_metrics(val_probs,
                                                val_labels,
                                                threshold=threshold,
                                                add_aucroc=False)
            LOGGER.info(
                "Decision threshold on validation dataset: %.4f (accuracy=%4.2f%%)"
                % (threshold, 100.0 * best_val_metrics["accuracy"]))

            # Standard evaluation is not possible when a test set has no labels
            # (standard_metrics would raise an error). For those sets, write the
            # predictions out in leaderboard format instead.
            self.test_metrics = dict()
            for test_idx in range(len(self.config['test_loader'])):
                test_name = self.config['test_loader'][test_idx].dataset.name
                LOGGER.info("Export and testing on %s..." % test_name)
                # Labels of -1 mark an unlabeled (leaderboard) test set: export predictions only
                if hasattr(self.config['test_loader'][test_idx].dataset, "data") and \
                   hasattr(self.config['test_loader'][test_idx].dataset.data, "labels") and \
                   self.config['test_loader'][test_idx].dataset.data.labels[0] == -1:
                    self.export_test_predictions(test_idx=test_idx,
                                                 threshold=threshold)
                    self.test_metrics[test_name] = dict()
                else:
                    test_idx_metrics, _ = self.eval_model(test=True,
                                                          test_idx=test_idx)
                    self.test_metrics[test_name] = test_idx_metrics
                    print_test_stats(test_idx_metrics, test=True)
                    self.export_val_predictions(test=True,
                                                test_idx=test_idx,
                                                threshold=threshold)
        else:
            LOGGER.info(
                "No model checkpoint was saved, so testing will be skipped.")
            self.test_metrics = dict()

        self.export_metrics()

        self.config['writer'].close()

        if self.config['remove_checkpoints']:
            LOGGER.info("Removing checkpoint %s..." % self.model_file)
            os.remove(self.model_file)
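find_optimal_threshold, referenced in the commented-out line above, is an optional feature left to the reader. A minimal sketch, assuming the goal is the accuracy-maximizing decision threshold on the validation set (the name and signature come from the comment; the internals are assumed):

import torch

def find_optimal_threshold(probs, labels, metric="accuracy"):
    # Hypothetical reconstruction: sweep candidate thresholds and keep the best one
    assert metric == "accuracy", "only accuracy is sketched here"
    best_t, best_acc = 0.5, -1.0
    for t in torch.linspace(0.0, 1.0, steps=101):
        acc = ((probs >= t).long() == labels.long()).float().mean().item()
        if acc > best_acc:
            best_t, best_acc = float(t), acc
    return best_t

With a helper like this in place, uncommenting the call in end_training would replace the fixed 0.5 threshold with the tuned one.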
Example No. 3
    def train_epoch_step(self):
        self.model.train()
        lr = self.scheduler.get_last_lr()
        self.total_iters += self.iters + 1
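        # Flatten the per-batch probabilities and labels accumulated over the epoch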
        self.probs_list = [
            pred for batch_pred in self.probs_list for pred in batch_pred
        ]
        self.labels_list = [
            label for batch_labels in self.labels_list
            for label in batch_labels
        ]

        # Evaluate on train set
        self.train_metrics = standard_metrics(torch.tensor(self.probs_list),
                                              torch.tensor(self.labels_list),
                                              add_optimal_acc=True)
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.loss_list,
                        self.train_metrics,
                        lr[0],
                        loss_only=False,
                        val=False)
        self.train_loss = self.loss_list[:]

        # Evaluate on dev set
        val_time = time.time()
        self.val_metrics, self.val_loss = self.eval_model()
        self.config['writer'].add_scalar("Stats/time_validation",
                                         time.time() - val_time,
                                         self.total_iters)

        # print stats
        print_stats(self.config, self.epoch, self.train_metrics,
                    self.train_loss, self.val_metrics, self.val_loss,
                    self.start, lr[0])

        # log validation stats in tensorboard
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.val_loss,
                        self.val_metrics,
                        lr[0],
                        loss_only=False,
                        val=True)

        # Check for early stopping criteria
        self.check_early_stopping()
        self.probs_list = []
        self.preds_list = []
        self.labels_list = []
        self.loss_list = []
        self.id_list = []

        self.train_loss = sum(self.train_loss) / len(self.train_loss)
        del self.val_metrics
        del self.val_loss
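check_early_stopping is called above but not shown. A minimal sketch of what it might do, assuming the metric named in config['optimize_for'] is higher-is-better and assuming a hypothetical epochs_no_improve counter initialized to 0 elsewhere in the trainer:

    def check_early_stopping(self):
        # Hypothetical reconstruction: raise the termination flag once the
        # monitored validation metric stops improving for `patience` epochs
        monitored = self.config['optimize_for']
        best_so_far = self.best_val_metrics.get(monitored, float('-inf'))
        if self.val_metrics[monitored] > best_so_far:
            self.best_val_metrics = dict(self.val_metrics)
            self.epochs_no_improve = 0  # assumed counter, reset on improvement
            # a real implementation would likely also checkpoint to self.model_file here
        else:
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.config['patience']:
                self.terminate_training = True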