예제 #1
0
    def _evaluate_model(self, metric_dict, x_train, y_train, x_test, y_test):
        """Append train and test evaluation metrics to ``metric_dict``.

        For each split, inputs are optionally standardized with
        ``(x - self.x_means) / self.x_stds`` (when ``self.normalize_x`` is
        set). Monte Carlo predictions are then drawn through
        ``self.optimizer.get_mc_predictions`` using
        ``self.train_params['eval_mc_samples']`` samples.

        Training split: appends the predictive log-loss, the predictive
        accuracy and the value of ``metrics.avneg_elbo_categorical``. That
        ELBO call uses the training-set size from ``self.data`` and the KL
        term from ``self.optimizer``.

        Test split: appends the predictive log-loss and accuracy only.

        Args:
            metric_dict: Mapping whose values support ``.append``. One
                Python scalar is appended under each of 'train_pred_logloss',
                'train_pred_accuracy', 'elbo_neg_ave', 'test_pred_logloss'
                and 'test_pred_accuracy'.
            x_train, y_train: Training inputs and labels.
            x_test, y_test: Test inputs and labels.

        Returns:
            None; ``metric_dict`` is updated in place.
        """

        def to_scalar(value):
            # Drop the autograd graph and move to host before pulling out a
            # plain Python number.
            return value.detach().cpu().item()

        def mc_logits(inputs):
            # Standardize (if configured) and draw MC predictions. The
            # ret_numpy=False setting is passed through unchanged so the
            # metric helpers receive the same objects as before.
            if self.normalize_x:
                inputs = (inputs - self.x_means) / self.x_stds
            return self.optimizer.get_mc_predictions(
                self.model.forward,
                inputs=inputs,
                mc_samples=self.train_params['eval_mc_samples'],
                ret_numpy=False)

        # ---- training split ----
        logits = mc_logits(x_train)
        metric_dict['train_pred_logloss'].append(to_scalar(
            metrics.predictive_avneg_loglik_categorical(logits, y_train)))
        metric_dict['train_pred_accuracy'].append(to_scalar(
            metrics.softmax_predictive_accuracy(logits, y_train)))
        metric_dict['elbo_neg_ave'].append(to_scalar(
            metrics.avneg_elbo_categorical(
                logits,
                y_train,
                train_set_size=self.data.get_train_size(),
                kl=self.optimizer.kl_divergence())))

        # ---- test split ----
        # Rebinding the same name releases the training logits, as before.
        logits = mc_logits(x_test)
        metric_dict['test_pred_logloss'].append(to_scalar(
            metrics.predictive_avneg_loglik_categorical(logits, y_test)))
        metric_dict['test_pred_accuracy'].append(to_scalar(
            metrics.softmax_predictive_accuracy(logits, y_test)))
예제 #2
0
 def objective(logits_list, y):
     """Return ``metrics.avneg_elbo_categorical`` for ``logits_list`` and ``y``.

     The training-set size comes from ``self.data.get_train_size()`` and the
     KL term from ``self.model.kl_divergence()``. ``self`` is not a
     parameter. It is presumably captured from an enclosing method that is
     not visible in this snippet (confirm against the caller).
     """
     neg_elbo = metrics.avneg_elbo_categorical
     n_train = self.data.get_train_size()
     kl_term = self.model.kl_divergence()
     return neg_elbo(logits_list, y, train_set_size=n_train, kl=kl_term)