def test_kl_divergence_non_bayesian_module(self):
    linear = nn.Linear(10, 10)
    to_feed = torch.ones((1, 10))
    predicted = linear(to_feed)

    kl_complexity_cost = kl_divergence_from_nn(linear)
    self.assertEqual((torch.tensor(0) == kl_complexity_cost).all(), torch.tensor(True))
def test_kl_divergence_bayesian_linear_module(self):
    blinear = BayesianLinear(10, 10)
    to_feed = torch.ones((1, 10))
    predicted = blinear(to_feed)

    complexity_cost = blinear.log_variational_posterior - blinear.log_prior
    kl_complexity_cost = kl_divergence_from_nn(blinear)

    self.assertEqual((complexity_cost == kl_complexity_cost).all(), torch.tensor(True))
def nn_kl_divergence(self):
    """Returns the sum of the KL divergences of all BayesianModules in the model,
    measured from the current variational posterior distribution over their weights
    to the simpler, scale-mixture prior distribution over the weights.

    Parameters:
        N/a

    Returns:
        torch.tensor with 0 dim.
    """
    return kl_divergence_from_nn(self)
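# Minimal usage sketch of nn_kl_divergence (not part of the code above). Hedged
# assumptions: the blitz-style import paths and the toy class name ToyRegressor are
# illustrative only; adjust the imports to wherever BayesianLinear,
# variational_estimator and kl_divergence_from_nn live in this project.
import torch
from torch import nn
from blitz.modules import BayesianLinear          # assumed import path
from blitz.utils import variational_estimator     # assumed import path
from blitz.losses import kl_divergence_from_nn    # assumed import path

@variational_estimator
class ToyRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        self.blinear = BayesianLinear(10, 1)

    def forward(self, x):
        return self.blinear(x)

model = ToyRegressor()
_ = model(torch.ones(1, 10))            # forward pass samples weights and records log-probs
kl = model.nn_kl_divergence()           # 0-dim tensor, same value as kl_divergence_from_nn(model)
assert torch.isclose(kl, kl_divergence_from_nn(model))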
def test_kl_divergence_bayesian_conv2d_module(self):
    bconv = BayesianConv2d(in_channels=3, out_channels=3, kernel_size=(3, 3))
    to_feed = torch.ones((1, 3, 25, 25))
    predicted = bconv(to_feed)

    complexity_cost = bconv.log_variational_posterior - bconv.log_prior
    kl_complexity_cost = kl_divergence_from_nn(bconv)

    self.assertEqual((complexity_cost == kl_complexity_cost).all(), torch.tensor(True))
def test_kl_divergence(self):
    # create a decorated model and check that its nn_kl_divergence()
    # matches kl_divergence_from_nn on the same forward pass
    to_feed = torch.ones((1, 10))

    @variational_estimator
    class VariationalEstimator(nn.Module):
        def __init__(self):
            super().__init__()
            self.blinear = BayesianLinear(10, 10)

        def forward(self, x):
            return self.blinear(x)

    model = VariationalEstimator()
    predicted = model(to_feed)

    complexity_cost = model.nn_kl_divergence()
    kl_complexity_cost = kl_divergence_from_nn(model)

    self.assertEqual((complexity_cost == kl_complexity_cost).all(), torch.tensor(True))
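# The same KL term drives a Bayes-by-backprop training step: average the
# (fit loss + weighted KL) over several sampled weight configurations.
# A minimal sketch, reusing the ToyRegressor model from the example above;
# the MSE criterion, batch tensors and hyperparameter values are illustrative,
# not the project's actual training code.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(32, 10), torch.randn(32, 1)

sample_nbr = 3                   # Monte Carlo weight samples per step (illustrative)
complexity_cost_weight = 1 / 32  # KL weight, e.g. 1 / number of batches (illustrative)

optimizer.zero_grad()
loss = 0
for _ in range(sample_nbr):
    preds = model(x)             # a new weight sample is drawn on each forward pass
    loss = loss + criterion(preds, y)
    loss = loss + model.nn_kl_divergence() * complexity_cost_weight
loss = loss / sample_nbr
loss.backward()
optimizer.step()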
def train_epoch_classification(model, optimizer, device, data_loader, epoch, params, cyclic_lr_schedule=None):
    model.train()
    epoch_loss = 0
    nb_data = 0
    num_iters = len(data_loader.dataset)
    total_scores = []
    total_targets = []

    for iter, (batch_graphs, batch_targets, batch_snorm_n, batch_snorm_e, batch_smiles) in enumerate(data_loader):
        # SWA cyclic lr schedule
        if cyclic_lr_schedule is not None:
            lr = cyclic_lr_schedule(iter / num_iters)
            swa_utils.adjust_learning_rate(optimizer, lr)

        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_targets = batch_targets.to(device).float()
        batch_snorm_n = batch_snorm_n.to(device)  # num x 1

        optimizer.zero_grad()

        # Bayes-by-backprop: average (task loss + weighted KL) over several weight samples
        sample_nbr = params['bbp_sample_nbr']
        complexity_cost_weight = params['bbp_complexity']
        loss = 0
        for _ in range(sample_nbr):
            batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e)
            loss = loss + model.loss(batch_scores, batch_targets)
            kl_loss = kl_divergence_from_nn(model)
            loss = loss + kl_loss * complexity_cost_weight
        loss = loss / sample_nbr
        loss.backward()

        # SGD with a high lr shows gradient explosion --> temporarily using gradient clipping
        if params['grad_clip'] != 0.:
            clipping_value = params['grad_clip']
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)

        optimizer.step()
        epoch_loss += loss.detach().item()
        nb_data += batch_targets.size(0)
        total_scores.append(batch_scores)
        total_targets.append(batch_targets)

    epoch_loss /= (iter + 1)
    total_scores = torch.cat(total_scores, dim=0)
    total_targets = torch.cat(total_targets, dim=0)
    epoch_train_perf = binary_class_perfs(total_scores.detach(), total_targets.detach())

    return epoch_loss, epoch_train_perf, optimizer, total_scores, total_targets
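# A sketch of the inputs this function expects but does not define.
# Hedged assumptions: the SWA-style linear-decay schedule and the parameter
# values below are illustrative, not the project's actual settings; only the
# dict keys are taken from the code above.
def cyclic_lr_schedule_example(t, lr_init=1e-2, lr_min=1e-3):
    # t is the position within the current cycle, in [0, 1];
    # decay the learning rate linearly from lr_init to lr_min over the cycle.
    return (1.0 - t) * lr_init + t * lr_min

params_example = {
    'bbp_sample_nbr': 3,     # Monte Carlo weight samples per batch
    'bbp_complexity': 1e-3,  # weight on the KL complexity cost
    'grad_clip': 5.0,        # 0. disables gradient clipping
}

# Hypothetical call (model, optimizer, device, train_loader and epoch defined elsewhere):
# epoch_loss, perfs, optimizer, scores, targets = train_epoch_classification(
#     model, optimizer, device, train_loader, epoch, params_example,
#     cyclic_lr_schedule=cyclic_lr_schedule_example)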