def nn(args, dm):
    device = args.device
    n_neurons = 50
    X, y = ds_from_dl(dm.train_dataloader(), device)
    Xval, yval = ds_from_dl(dm.val_dataloader(), device)  # assumes dm exposes a validation loader

    mean = torch.nn.Sequential(
        torch.nn.Linear(X.shape[1], n_neurons),
        torch.nn.ReLU(),
        torch.nn.Linear(n_neurons, y.shape[1]),
    )
    mean.to(device)
    var = torch.nn.Sequential(
        torch.nn.Linear(X.shape[1], n_neurons),
        torch.nn.ReLU(),
        torch.nn.Linear(n_neurons, y.shape[1]),
        torch.nn.Softplus(),
    )
    var.to(device)

    optimizer = torch.optim.Adam(chain(mean.parameters(), var.parameters()), lr=args.lr)

    it = 0
    progressBar = tqdm(desc="Training nn", total=args.iters, unit="iter")
    batches = batchify(X, y, batch_size=args.batch_size)
    while it < args.iters:
        # Use a fixed small variance for the first half of training, then switch
        # to the learned variance head.
        switch = 1.0 if it > args.iters / 2 else 0.0
        optimizer.zero_grad()
        data, label = next(batches)
        m, v = mean(data), var(data)
        v = switch * v + (1 - switch) * torch.tensor([0.02**2], device=device)
        loss = normal_log_prob(label, m, v).sum()
        (-loss).backward()
        optimizer.step()
        it += 1
        progressBar.update()
        progressBar.set_postfix({"loss": loss.item()})
    progressBar.close()

    data = Xval
    label = yval
    m, v = mean(data), var(data)
    # Consistency with our way - only evaluate on std data.
    # m = m * y_std + y_mean
    # v = v * y_std ** 2
    log_px = normal_log_prob(label, m, v)
    rmse = ((label - m) ** 2).mean().sqrt()
    return log_px.mean().item(), rmse.item()


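# `ds_from_dl`, `batchify` and `normal_log_prob` are helpers defined elsewhere in
# this repo. As a point of reference (an assumption, not the canonical
# implementation), `normal_log_prob` is taken to be the element-wise Gaussian
# log-density of the labels under the predicted mean and variance:
#
#   def normal_log_prob(y, mean, var):
#       return -0.5 * (math.log(2 * math.pi) + var.log() + (y - mean) ** 2 / var)

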
def mcdnn(args, dm):
    device = args.device
    n_neurons = 50
    X, y = ds_from_dl(dm.train_dataloader(), device)
    Xval, yval = ds_from_dl(dm.val_dataloader(), device)

    mean = torch.nn.Sequential(
        torch.nn.Linear(X.shape[1], n_neurons),
        torch.nn.Dropout(p=0.05),
        torch.nn.ReLU(),
        torch.nn.Linear(n_neurons, y.shape[1]),
        torch.nn.Dropout(p=0.05),
    )
    mean.to(device)

    optimizer = torch.optim.Adam(mean.parameters(), lr=args.lr)

    it = 0
    progressBar = tqdm(desc="Training mcdnn", total=args.iters, unit="iter")
    batches = batchify(X, y, batch_size=args.batch_size)
    while it < args.iters:
        optimizer.zero_grad()
        data, label = next(batches)
        m = mean(data)
        loss = (m - label).abs().pow(2.0).mean()
        loss.backward()
        optimizer.step()
        it += 1
        progressBar.update()
        progressBar.set_postfix({"loss": loss.item()})
    progressBar.close()

    data = Xval
    label = yval
    # Monte Carlo dropout: `mean` is never put into eval mode, so dropout stays
    # active during the stochastic forward passes below.
    samples = torch.zeros(args.mcmc, Xval.shape[0], y.shape[1]).to(device)
    with torch.no_grad():
        for i in range(args.mcmc):
            samples[i, :] = mean(data)
    m, v = samples.mean(dim=0), samples.var(dim=0)
    log_probs = normal_log_prob(label, m, v)
    rmse = ((label - m) ** 2).mean().sqrt()
    return log_probs.mean().item(), rmse.item()


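# Note on the Monte Carlo estimate above: dropout has to remain stochastic at
# prediction time, which holds here only because the model stays in train mode.
# If the model were switched to `.eval()` (e.g. to freeze batch-norm style
# layers), the dropout layers would need to be re-activated by hand; a minimal
# sketch, not part of the pipeline above:
#
#   mean.eval()
#   for module in mean.modules():
#       if isinstance(module, torch.nn.Dropout):
#           module.train()

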
def john(args, dm):
    device = args.device
    n_neurons = 50
    X, y = ds_from_dl(dm.train_dataloader(), device)
    Xval, yval = ds_from_dl(dm.val_dataloader(), device)
    args.n_clusters = min(args.n_clusters, X.shape[0])

    # Locality-sampler settings for the mean and variance updates.
    mean_psu = 1
    mean_ssu = 40
    mean_M = 50
    var_psu = 2
    var_ssu = 10
    var_M = 10
    num_draws_train = 20

    kmeans = KMeans(n_clusters=args.n_clusters)
    kmeans.fit(np.concatenate([X.cpu()], axis=0))
    c = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
    c = c.to(device)

    class translatedSigmoid(torch.nn.Module):
        def __init__(self):
            super(translatedSigmoid, self).__init__()
            self.beta = torch.nn.Parameter(torch.tensor([1.5]))

        def forward(self, x):
            beta = torch.nn.functional.softplus(self.beta)
            # 6.9078 ~= log(1e3), so the sigmoid is ~1e-3 at distance zero.
            alpha = -beta * (6.9077542789816375)
            return torch.sigmoid((x + alpha) / beta)

    class GPNNModel(torch.nn.Module):
        def __init__(self):
            super(GPNNModel, self).__init__()
            self.mean = torch.nn.Sequential(
                torch.nn.Linear(X.shape[1], n_neurons),
                torch.nn.ReLU(),
                torch.nn.Linear(n_neurons, y.shape[1]),
            )
            self.alph = torch.nn.Sequential(
                torch.nn.Linear(X.shape[1], n_neurons),
                torch.nn.ReLU(),
                torch.nn.Linear(n_neurons, y.shape[1]),
                torch.nn.Softplus(),
            )
            self.bet = torch.nn.Sequential(
                torch.nn.Linear(X.shape[1], n_neurons),
                torch.nn.ReLU(),
                torch.nn.Linear(n_neurons, y.shape[1]),
                torch.nn.Softplus(),
            )
            self.trans = translatedSigmoid()

        def forward(self, x, switch):
            # Distance to the nearest k-means center drives the interpolation
            # between the learned variance and the prior variance.
            d = dist(x, c)
            d_min = d.min(dim=1, keepdim=True)[0]
            s = self.trans(d_min)
            mean = self.mean(x)
            if switch:
                a = self.alph(x)
                b = self.bet(x)
                gamma_dist = D.Gamma(a + 1e-8, 1.0 / (b + 1e-8))
                if self.training:
                    samples_var = gamma_dist.rsample(torch.Size([num_draws_train]))
                    x_var = 1.0 / (samples_var + 1e-8)
                else:
                    samples_var = gamma_dist.rsample(torch.Size([1000]))
                    x_var = 1.0 / (samples_var + 1e-8)
                var = (1 - s) * x_var + s * torch.ones_like(x_var).type_as(x_var)
            else:
                var = 0.05 * torch.ones_like(mean)
            return mean, var

    model = GPNNModel()
    model.to(device)
    optimizer = torch.optim.Adam(model.mean.parameters(), lr=1e-2)
    optimizer2 = torch.optim.Adam(
        chain(model.alph.parameters(), model.bet.parameters(), model.trans.parameters()),
        lr=1e-4,
    )

    mean_Q, mean_w = gen_Qw(X.cpu(), mean_psu, mean_ssu, mean_M)
    mean_Q = torch.tensor(mean_Q).to(torch.float32).to(device)
    mean_w = torch.tensor(mean_w).to(torch.float32).to(device)

    if X.shape[0] > 100000 and X.shape[1] > 10:
        pca = PCA(n_components=0.5)
        temp = pca.fit_transform(X.cpu())
        var_Q, var_w = gen_Qw(temp, var_psu, var_ssu, var_M)
    else:
        var_Q, var_w = gen_Qw(X.cpu(), var_psu, var_ssu, var_M)

    # mean_pseupoch = get_pseupoch(mean_w,0.5)
    # var_pseupoch = get_pseupoch(var_w,0.5)
    opt_switch = 1
    var_Q = torch.tensor(var_Q).to(torch.float32).to(device)
    var_w = torch.tensor(var_w).to(torch.float32).to(device)

    model.train()
    batches = batchify(X, y, batch_size=args.batch_size)
    it = 0
    while it < args.iters:
        switch = 1.0 if it > args.iters / 2.0 else 0.0
        if it % 11:
            opt_switch = opt_switch + 1  # change between var and mean optimizer
        data, label = next(batches)
        data = data.to(device)
        label = label.to(device)
        if not switch:
            optimizer.zero_grad()
            m, v = model(data, switch)
            loss = (-t_likelihood(label.reshape(-1, y.shape[1]), m,
                                  v.reshape(-1, y.shape[1])) / X.shape[0])
            loss.backward()
            optimizer.step()
        else:
            if opt_switch % 2 == 0:
                # for b in range(mean_pseupoch):
                optimizer.zero_grad()
                batch = locality_sampler2(mean_psu, mean_ssu, mean_Q.cpu(), mean_w.cpu())
                m, v = model(X[batch], switch)
                loss = (-t_likelihood(y[batch].reshape(-1, y.shape[1]), m, v,
                                      mean_w[batch]) / X.shape[0])
                loss.backward()
                optimizer.step()
            else:
                # for b in range(var_pseupoch):
                optimizer2.zero_grad()
                batch = locality_sampler2(var_psu, var_ssu, var_Q.cpu(), var_w.cpu())
                m, v = model(X[batch], switch)
                loss = (-t_likelihood(y[batch].reshape(-1, y.shape[1]), m, v,
                                      var_w[batch]) / X.shape[0])
                loss.backward()
                optimizer2.step()
        if it % 500 == 0:
            m, v = model(data, switch)
            loss = -(-v.log() / 2 - ((m - label)**2) / (2 * v)).mean()
            print("Iter {0}/{1}, Loss {2}".format(it, args.iters, loss.item()))
        it += 1

    model.eval()
    data = Xval
    label = yval
    with torch.no_grad():
        m, v = model(data, switch)
    # m = m * y_std + y_mean
    # v = v * y_std ** 2
    # log_px = normal_log_prob(label, m, v).mean(dim=0)  # check for correctness
    log_px = t_likelihood(label.reshape(-1, y.shape[1]), m, v) / Xval.shape[0]  # check
    rmse = ((label - m)**2).mean().sqrt()
    return log_px.mean().item(), rmse.item()


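# `dist`, `gen_Qw`, `locality_sampler2` and `t_likelihood` are helpers defined
# elsewhere in this repo. For readability, `dist(x, c)` is assumed to return the
# pairwise squared Euclidean distances between the inputs and the k-means
# centers; a sketch under that assumption:
#
#   def dist(x, c):
#       # x: [N, D], c: [n_clusters, D] -> [N, n_clusters]
#       return (x.unsqueeze(1) - c.unsqueeze(0)).pow(2).sum(dim=-1)

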
def bnn(args, dm):
    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow_probability import distributions as tfd

    device = args.device
    n_neurons = 50
    X, y = ds_from_dl(dm.train_dataloader(), device)
    Xval, yval = ds_from_dl(dm.val_dataloader(), device)

    tf.compat.v1.reset_default_graph()
    tf.compat.v1.disable_eager_execution()
    # y, y_mean, y_std = normalize_y(y)

    def VariationalNormal(name, shape, constraint=None):
        means = tf.compat.v1.get_variable(name + "_mean",
                                          initializer=tf.ones(shape),
                                          constraint=constraint)
        stds = tf.compat.v1.get_variable(name + "_std",
                                         initializer=-1.0 * tf.ones(shape))
        return tfd.Normal(loc=means, scale=tf.nn.softplus(stds))

    x_p = tf.compat.v1.placeholder(tf.float32, shape=(None, X.shape[1]))
    y_p = tf.compat.v1.placeholder(tf.float32, shape=(None, y.shape[1]))

    with tf.compat.v1.name_scope("model", values=[x_p]):
        layer1 = tfp.layers.DenseFlipout(
            units=n_neurons,
            activation="relu",
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            bias_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
        )
        layer2 = tfp.layers.DenseFlipout(
            units=1,
            activation="linear",
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            bias_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
        )
        predictions = layer2(layer1(x_p))

        noise = VariationalNormal("noise", [y.shape[1]],
                                  constraint=tf.keras.constraints.NonNeg())
        pred_distribution = tfd.Normal(loc=predictions, scale=noise.sample())

        neg_log_prob = -tf.reduce_mean(input_tensor=pred_distribution.log_prob(y_p))
        kl_div = sum(layer1.losses + layer2.losses) / X.shape[0]
        elbo_loss = neg_log_prob + kl_div

    with tf.compat.v1.name_scope("train"):
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.lr)
        train_op = optimizer.minimize(elbo_loss)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        it = 0
        progressBar = tqdm(desc="Training BNN", total=args.iters, unit="iter")
        batches = batchify(X, y, batch_size=args.batch_size)
        while it < args.iters:
            data, label = next(batches)
            _, loss = sess.run(
                [train_op, elbo_loss],
                feed_dict={
                    x_p: data.cpu(),
                    y_p: label.reshape(-1, y.shape[1]).cpu(),
                },
            )
            progressBar.update()
            progressBar.set_postfix({"loss": loss})
            it += 1
        progressBar.close()

        # Draw posterior samples of the weights and of the observation noise.
        W0_samples = layer1.kernel_posterior.sample(1000)
        b0_samples = layer1.bias_posterior.sample(1000)
        W1_samples = layer2.kernel_posterior.sample(1000)
        b1_samples = layer2.bias_posterior.sample(1000)
        noise_samples = noise.sample(1000)

        W0, b0, W1, b1, n = sess.run(
            [W0_samples, b0_samples, W1_samples, b1_samples, noise_samples])

    def sample_net(x, W0, b0, W1, b1, n):
        # One forward pass per weight sample, broadcast over the sample axis.
        h = np.maximum(np.matmul(x[np.newaxis], W0) + b0[:, np.newaxis, :], 0.0)
        return (np.matmul(h, W1) + b1[:, np.newaxis, :]
                + n[:, np.newaxis, :] * np.random.randn())

    samples = sample_net(Xval.cpu(), W0, b0, W1, b1, n)
    m = samples.mean(axis=0)
    v = samples.var(axis=0)
    # m = m * y_std + y_mean
    # v = v * y_std ** 2
    log_probs = normal_log_prob(yval.cpu(), m, v)
    rmse = math.sqrt(((m - yval.cpu())**2).mean())
    return log_probs.mean(), rmse


def ensnn(args, dm):
    device = args.device
    n_neurons = 50
    X, y = ds_from_dl(dm.train_dataloader(), device)
    Xval, yval = ds_from_dl(dm.val_dataloader(), device)

    ms, vs = [], []
    for _ in range(args.n_models):  # each ensemble member gets its own random initialization
        mean = torch.nn.Sequential(
            torch.nn.Linear(X.shape[1], n_neurons),
            torch.nn.ReLU(),
            torch.nn.Linear(n_neurons, y.shape[1]),
        )
        mean.to(device)
        var = torch.nn.Sequential(
            torch.nn.Linear(X.shape[1], n_neurons),
            torch.nn.ReLU(),
            torch.nn.Linear(n_neurons, y.shape[1]),
            torch.nn.Softplus(),
        )
        var.to(device)

        optimizer = torch.optim.Adam(chain(mean.parameters(), var.parameters()), lr=args.lr)

        it = 0
        progressBar = tqdm(desc="Training ensnn", total=args.iters, unit="iter")
        batches = batchify(X, y, batch_size=args.batch_size)
        while it < args.iters:
            switch = 0.0  # 1.0 if it > args.iters/2 else 0.0
            optimizer.zero_grad()
            data, label = next(batches)
            m, v = mean(data), var(data)
            v = switch * v + (1 - switch) * torch.tensor([0.02**2], device=device)
            loss = normal_log_prob(label, m, v).sum()
            (-loss).backward()
            optimizer.step()
            it += 1
            progressBar.update()
            progressBar.set_postfix({"loss": loss.item()})
        progressBar.close()

        data = Xval
        label = yval
        m, v = mean(data), var(data)
        # m = m * y_std + y_mean
        # v = v * y_std ** 2
        ms.append(m)
        vs.append(v)

    ms = torch.stack(ms)
    vs = torch.stack(vs)

    # Moments of the equally-weighted Gaussian mixture over the ensemble members.
    m = ms.mean(dim=0)
    v = (vs + ms**2).mean(dim=0) - m**2
    log_px = normal_log_prob(label, m, v)
    rmse = ((label - m)**2).mean().sqrt()
    return log_px.mean().item(), rmse.item()
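
# Example usage (a minimal sketch; `UCIDataModule` is a hypothetical datamodule
# exposing train_dataloader()/val_dataloader(), and the hyperparameter values
# below are placeholders, not the settings used in any experiment):
#
#   from argparse import Namespace
#
#   args = Namespace(device="cuda", lr=1e-3, iters=10000, batch_size=256,
#                    mcmc=100, n_clusters=500, n_models=5)
#   dm = UCIDataModule()
#   for method in (nn, mcdnn, john, bnn, ensnn):
#       log_px, rmse = method(args, dm)
#       print(method.__name__, log_px, rmse)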