def run(dataset, net_type, train=True):
    """Train and evaluate a Bayesian heteroscedastic regression model.

    Args:
        dataset: dataset name understood by ``data.getDataset_regression``.
        net_type: architecture key understood by ``getModel``.
        train: when True, train from scratch; otherwise only evaluate the
            checkpoint left by a previous run.
    """
    # Hyper-parameter settings (read from the shared config module)
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    test_ens = cfg.test_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset_regression(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)
    print(len(train_loader))
    print(len(valid_loader))
    print(len(test_loader))

    ckpt_dir = f'checkpoints/regression/{dataset}/bayesian'
    ckpt_name = f'checkpoints/regression/{dataset}/bayesian/model_{net_type}.pt'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO_regression_hetero(len(trainset)).to(device)

    if train:
        optimizer = Adam(net.parameters(), lr=lr_start)
        # np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
        valid_loss_max = np.inf
        for epoch in range(n_epochs):  # loop over the dataset multiple times
            cfg.curr_epoch_no = epoch
            utils.adjust_learning_rate(
                optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))
            train_loss, train_mse, train_kl = train_model(
                net, optimizer, criterion, train_loader, num_ens=train_ens)
            valid_loss, valid_mse = validate_model(
                net, criterion, valid_loader, num_ens=valid_ens)
            print('Epoch: {} \tTraining Loss: {:.4f} \tTraining MSE: {:.4f} \tValidation Loss: {:.4f} \tValidation MSE: {:.4f} \ttrain_kl_div: {:.4f}'.format(
                epoch, train_loss, train_mse, valid_loss, valid_mse, train_kl))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_max:
                # BUGFIX: the original string literal contained a raw line
                # break (a syntax error); joined into one literal, matching
                # the message used by the sibling run() functions.
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
                    valid_loss_max, valid_loss))
                torch.save(net.state_dict(), ckpt_name)
                valid_loss_max = valid_loss

    # test saved (best) model
    best_model = getModel(net_type, inputs, outputs).to(device)
    best_model.load_state_dict(torch.load(ckpt_name))
    test_loss, test_mse = test_model(
        best_model, criterion, test_loader, num_ens=test_ens)
    print('Test Loss: {:.4f} \tTest MSE: {:.4f} '.format(test_loss, test_mse))
    test_uncertainty(best_model, testset[:100], data='ccpp')
def run(dataset, net_type):
    """Train a Bayesian classifier on `dataset`, checkpointing whenever the
    validation loss reaches a new minimum."""
    # Shared-config hyper-parameters, grouped for brevity.
    train_ens, valid_ens = cfg.train_ens, cfg.valid_ens
    n_epochs, lr_start = cfg.n_epochs, cfg.lr_start
    num_workers, valid_size, batch_size = (
        cfg.num_workers, cfg.valid_size, cfg.batch_size)

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}.pt'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)

    best_valid_loss = np.inf
    for epoch in range(n_epochs):  # repeated passes over the dataset
        utils.adjust_learning_rate(
            optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

        train_loss, train_acc, train_kl = train_model(
            net, optimizer, criterion, train_loader, num_ens=train_ens)
        valid_loss, valid_acc = validate_model(
            net, criterion, valid_loader, num_ens=valid_ens)

        print(
            'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
            .format(epoch, train_loss, train_acc, valid_loss, valid_acc,
                    train_kl))

        # Checkpoint on every improvement of the validation loss.
        if valid_loss <= best_valid_loss:
            print(
                'Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'
                .format(best_valid_loss, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            best_valid_loss = valid_loss
def train_splitted(num_tasks, bayesian=True, net_type='lenet'):
    """Sequentially train one model per Split-MNIST task.

    Args:
        num_tasks: number of tasks the 10 MNIST classes are split into;
            must divide 10 evenly.
        bayesian: train Bayesian networks (ELBO loss) when True, otherwise
            frequentist networks (cross-entropy loss).
        net_type: architecture key for ``mix_utils.get_splitmnist_models``.

    Raises:
        ValueError: if ``num_tasks`` does not divide 10 evenly.
    """
    # BUGFIX: validation previously used `assert`, which is stripped when
    # Python runs with -O; raise explicitly so the check always fires.
    if 10 % num_tasks != 0:
        raise ValueError(
            "num_tasks must divide 10 evenly, got {}".format(num_tasks))

    # Hyper-parameter settings (read from the shared config module)
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start

    if bayesian:
        ckpt_dir = f"checkpoints/MNIST/bayesian/splitted/{num_tasks}-tasks/"
    else:
        ckpt_dir = f"checkpoints/MNIST/frequentist/splitted/{num_tasks}-tasks/"
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    loaders, datasets = mix_utils.get_splitmnist_dataloaders(
        num_tasks, return_datasets=True)
    models = mix_utils.get_splitmnist_models(
        num_tasks, bayesian=bayesian, pretrained=False, net_type=net_type)

    for task in range(1, num_tasks + 1):
        print(f"Training task-{task}..")
        trainset, testset, _, _ = datasets[task - 1]
        train_loader, valid_loader, _ = loaders[task - 1]
        net = models[task - 1]
        net = net.to(device)
        ckpt_name = ckpt_dir + f"model_{net_type}_{num_tasks}.{task}.pt"

        criterion = (metrics.ELBO(len(trainset)) if bayesian
                     else nn.CrossEntropyLoss()).to(device)
        optimizer = Adam(net.parameters(), lr=lr_start)
        # np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
        valid_loss_max = np.inf

        for epoch in range(n_epochs):  # loop over the dataset multiple times
            utils.adjust_learning_rate(
                optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

            if bayesian:
                train_loss, train_acc, train_kl = train_bayesian(
                    net, optimizer, criterion, train_loader, num_ens=train_ens)
                valid_loss, valid_acc = validate_bayesian(
                    net, criterion, valid_loader, num_ens=valid_ens)
                print(
                    'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
                    .format(epoch, train_loss, train_acc, valid_loss,
                            valid_acc, train_kl))
            else:
                train_loss, train_acc = train_frequentist(
                    net, optimizer, criterion, train_loader)
                valid_loss, valid_acc = validate_frequentist(
                    net, criterion, valid_loader)
                print(
                    'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'
                    .format(epoch, train_loss, train_acc, valid_loss,
                            valid_acc))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_max:
                print(
                    'Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'
                    .format(valid_loss_max, valid_loss))
                torch.save(net.state_dict(), ckpt_name)
                valid_loss_max = valid_loss

        print(f"Done training task-{task}")
def run(dataset, net_type, train=True):
    """Train/evaluate a Bayesian regression model and save diagnostic plots.

    Saves prediction-cost, KL-cost and MSE curves under the checkpoint
    directory, then evaluates the best checkpoint on the test set.

    Args:
        dataset: dataset name understood by ``data.getDataset_regression``.
        net_type: architecture key understood by ``getModel``.
        train: when True, train from scratch; otherwise only plot/evaluate
            using the checkpoint left by a previous run.
    """
    # Hyper-parameter settings (read from the shared config module)
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    test_ens = cfg.test_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset_regression(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)
    print(len(train_loader))
    print(len(valid_loader))
    print(len(test_loader))

    # `name` is a module-level variable selecting the checkpoint sub-folder.
    ckpt_dir = f'checkpoints/regression/{dataset}/' + name
    # BUGFIX: the filename component was a plain string, so the literal text
    # "{net_type}" appeared in the checkpoint path; it is now an f-string.
    ckpt_name = f'checkpoints/regression/{dataset}/' + name + f'/model_{net_type}.pt'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO_regression(len(trainset)).to(device)

    # Per-epoch history buffers for the diagnostic plots below.
    kl_cost_train = np.zeros(n_epochs)
    pred_cost_train = np.zeros(n_epochs)
    mse_train = np.zeros(n_epochs)
    kl_cost_val = np.zeros(n_epochs)
    pred_cost_val = np.zeros(n_epochs)
    mse_val = np.zeros(n_epochs)

    if train:
        optimizer = Adam(net.parameters(), lr=lr_start)
        # np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
        valid_loss_max = np.inf
        for epoch in range(n_epochs):  # loop over the dataset multiple times
            cfg.curr_epoch_no = epoch
            utils.adjust_learning_rate(
                optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))
            train_loss, train_mse, train_kl, train_pred = train_model(
                net, optimizer, criterion, train_loader, num_ens=train_ens)
            valid_loss, valid_mse, valid_kl, valid_pred = validate_model(
                net, criterion, valid_loader, num_ens=valid_ens)

            kl_cost_train[epoch] = train_kl
            pred_cost_train[epoch] = train_pred
            mse_train[epoch] = train_mse
            kl_cost_val[epoch] = valid_kl
            pred_cost_val[epoch] = valid_pred
            mse_val[epoch] = valid_mse

            print('Epoch: {} \ttra loss: {:.4f} \ttra_kl: {:.4f} \ttra_pred: {:.4f} \ttra MSE: {:.4f} \nval loss: {:.4f} \tVal kl: {:.4f} \tval_pred: {:.4f} \tval MSE: {:.4f} '.format(
                epoch, train_loss, train_kl, train_pred, train_mse,
                valid_loss, valid_kl, valid_pred, valid_mse))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_max:
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
                    valid_loss_max, valid_loss))
                torch.save(net.state_dict(), ckpt_name)
                valid_loss_max = valid_loss

    # --- figure: prediction cost vs. epoch (first 20 epochs skipped) ---
    textsize = 15
    marker = 5
    plt.figure(dpi=100)
    fig, ax1 = plt.subplots()
    ax1.plot(pred_cost_train[20:], 'r--')
    ax1.plot(pred_cost_val[20:], 'b-')
    ax1.set_ylabel('Pred_loss')
    plt.xlabel('epoch')
    # BUGFIX: `b=` was removed in Matplotlib >=3.5 (renamed `visible=`);
    # passing the flag positionally works on every Matplotlib version.
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    lgd = plt.legend(['train error', 'test error'], markerscale=marker,
                     prop={'size': textsize, 'weight': 'normal'})
    ax = plt.gca()
    plt.title('Regression costs')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]
                 + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    plt.savefig(ckpt_dir + '/pred_cost.png',
                bbox_extra_artists=(lgd,), bbox_inches='tight')

    # --- figure: KL divergence vs. epoch ---
    plt.figure()
    fig, ax1 = plt.subplots()
    ax1.plot(kl_cost_train, 'r')
    ax1.plot(kl_cost_val, 'b')
    ax1.set_ylabel('nats?')
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    ax = plt.gca()
    plt.title('DKL (per sample)')
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]
                 + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    # NOTE(review): `lgd` here is still the legend of the previous figure;
    # this figure defines no legend of its own. Kept as-is (harmless).
    plt.savefig(ckpt_dir + '/KL_cost.png',
                bbox_extra_artists=(lgd,), bbox_inches='tight')

    # --- figure: MSE vs. epoch (first 20 epochs skipped) ---
    plt.figure(dpi=100)
    fig2, ax2 = plt.subplots()
    ax2.set_ylabel('% error')
    ax2.plot(mse_val[20:], 'b-')
    ax2.plot(mse_train[20:], 'r--')
    plt.xlabel('epoch')
    plt.grid(True, which='major', color='k', linestyle='-')
    plt.grid(True, which='minor', color='k', linestyle='--')
    ax2.get_yaxis().set_minor_formatter(matplotlib.ticker.ScalarFormatter())
    ax2.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    lgd = plt.legend(['val mse', 'train mse'], markerscale=marker,
                     prop={'size': textsize, 'weight': 'normal'})
    ax = plt.gca()
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]
                 + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(textsize)
        item.set_weight('normal')
    # BUGFIX: keyword was misspelled `box_inches`, so the tight-bounding-box
    # option never took effect.
    plt.savefig(ckpt_dir + '/mse.png',
                bbox_extra_artists=(lgd,), bbox_inches='tight')

    # test saved (best) model
    best_model = getModel(net_type, inputs, outputs).to(device)
    best_model.load_state_dict(torch.load(ckpt_name))
    test_loss, test_mse = test_model(
        best_model, criterion, test_loader, num_ens=test_ens)
    print('Test Loss: {:.4f} \tTest MSE: {:.4f} '.format(test_loss, test_mse))
    test_uncertainty(best_model, testset[:500], data='uci_har')