def _init_model(self):
    """(Re)build the Bayesian MLP and store it on ``self.model``.

    Layer sizes come from ``self.data``; architecture and prior settings
    come from ``self.model_params``; the initial precision comes from
    ``self.optim_params``. When ``self.use_cuda`` is truthy, the model is
    then replaced by its ``.cuda()`` copy.
    """
    bnn_kwargs = dict(
        input_size=self.data.num_features,
        hidden_sizes=self.model_params['hidden_sizes'],
        output_size=self.data.num_classes,
        act_func=self.model_params['act_func'],
        prior_prec=self.model_params['prior_prec'],
        prec_init=self.optim_params['prec_init'],
    )
    # Assign the CPU model first so ``self.model`` is set even if .cuda() fails.
    self.model = BNN(**bnn_kwargs)
    if not self.use_cuda:
        return
    self.model = self.model.cuda()
def __init__(self, data_set, model_params, train_params, optim_params,
             evals_per_epoch=1, normalize_x=False, normalize_y=False,
             results_folder="./results", data_folder=DEFAULT_DATA_FOLDER,
             use_cuda=torch.cuda.is_available()):
    """Set up a Bayes-by-Backprop MLP regression experiment ("bbb_mlp_reg").

    Builds the ``BNN`` model, a Monte Carlo prediction function, the
    negative-ELBO objective, an Adam optimizer and empty metric containers.

    Args:
        data_set: Dataset identifier; forwarded to the base class and used
            when naming the results folder.
        model_params: Dict read here for ``'hidden_sizes'``, ``'act_func'``
            and ``'prior_prec'``, and by the objective for ``'noise_prec'``
            (via ``self.model_params``, presumably stored by the base class).
        train_params: Forwarded to the base class; the prediction function
            reads ``'train_mc_samples'`` from ``self.train_params``.
        optim_params: Dict read for ``'prec_init'``, ``'learning_rate'`` and
            ``'betas'``.
        evals_per_epoch, normalize_x, normalize_y, results_folder,
        data_folder: Forwarded unchanged to the base-class constructor
            (``results_folder`` is also used for ``self.folder_name``).
        use_cuda: Move the model to the GPU when truthy. NOTE: the default
            is evaluated once, when this function is defined.
    """
    # Zero-argument super() resolves relative to the class that defines this
    # method. The former ``super(type(self), self)`` resolved relative to the
    # *runtime* class, so a subclass inheriting this __init__ would re-enter
    # it endlessly (RecursionError).
    super().__init__(data_set, model_params, train_params, optim_params,
                     evals_per_epoch, normalize_x, normalize_y,
                     results_folder, data_folder, use_cuda)

    # Name used to build the results folder for this experiment type.
    experiment_name = "bbb_mlp_reg"
    self.folder_name = folder_name(experiment_name, data_set, model_params,
                                   train_params, optim_params, results_folder)

    self.model = BNN(input_size=self.data.num_features,
                     hidden_sizes=model_params['hidden_sizes'],
                     output_size=self.data.num_classes,
                     act_func=model_params['act_func'],
                     prior_prec=model_params['prior_prec'],
                     prec_init=optim_params['prec_init'])
    if use_cuda:
        self.model = self.model.cuda()

    def prediction(x):
        # One forward pass per MC sample; the model is called repeatedly,
        # presumably drawing fresh weights each time (TODO confirm in BNN).
        mu_list = [self.model(x)
                   for _ in range(self.train_params['train_mc_samples'])]
        return mu_list

    self.prediction = prediction

    def objective(mu_list, y):
        # Average negative ELBO under a Gaussian likelihood with precision
        # ``noise_prec``; the KL term is recomputed on every call.
        return metrics.avneg_elbo_gaussian(
            mu_list, y,
            tau=self.model_params['noise_prec'],
            train_set_size=self.data.get_train_size(),
            kl=self.model.kl_divergence())

    self.objective = objective

    self.optimizer = Adam(self.model.parameters(),
                          lr=optim_params['learning_rate'],
                          betas=optim_params['betas'],
                          eps=1e-8)

    # Both containers share one key list (same order as before); each entry
    # gets its own fresh list so the two dicts never alias each other.
    metric_names = ("elbo_neg_ave", "train_pred_logloss", "train_pred_rmse",
                    "test_pred_logloss", "test_pred_rmse")
    self.metric_history = {name: [] for name in metric_names}
    self.final_metric = {name: [] for name in metric_names}