Example #1
    def test_kl_divergence_non_bayesian_module(self):
        linear = nn.Linear(10, 10)
        to_feed = torch.ones((1, 10))
        predicted = linear(to_feed)

        kl_complexity_cost = kl_divergence_from_nn(linear)
        self.assertEqual((torch.tensor(0) == kl_complexity_cost).all(),
                         torch.tensor(True))
        pass
Example #2
    def test_kl_divergence_bayesian_linear_module(self):
        blinear = BayesianLinear(10, 10)
        to_feed = torch.ones((1, 10))
        predicted = blinear(to_feed)

        complexity_cost = blinear.log_variational_posterior - blinear.log_prior
        kl_complexity_cost = kl_divergence_from_nn(blinear)

        self.assertEqual((complexity_cost == kl_complexity_cost).all(),
                         torch.tensor(True))
        pass
Example #3
    def nn_kl_divergence(self):
        """Returns the sum of the KL divergence of each of the BayesianModules of the model, which are from
            their posterior current distribution of weights relative to a scale-mixtured prior (and simpler) distribution of weights

            Parameters:
                N/a

            Returns torch.tensor with 0 dim.      
        
        """
        return kl_divergence_from_nn(self)
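The sum above can be reproduced by hand from the log_variational_posterior and log_prior values that each Bayesian layer caches during its forward pass, exactly as Examples #2 and #4 do for a single module. The helper below is an illustrative sketch, not blitz's actual implementation of kl_divergence_from_nn; the duck-typed attribute check stands in for the library's own Bayesian-module test, and summed_kl_cost is a hypothetical name.

import torch
import torch.nn as nn
from blitz.modules import BayesianLinear


def summed_kl_cost(model: nn.Module) -> torch.Tensor:
    # Sum log q(w) - log p(w) over every submodule that caches both terms.
    kl = torch.zeros(())
    for module in model.modules():
        lvp = getattr(module, "log_variational_posterior", None)
        lp = getattr(module, "log_prior", None)
        if isinstance(lvp, torch.Tensor) and isinstance(lp, torch.Tensor):
            kl = kl + lvp - lp
    return kl


blinear = BayesianLinear(10, 10)
_ = blinear(torch.ones((1, 10)))  # the forward pass populates the cached log-probs
print(summed_kl_cost(blinear))    # expected to match kl_divergence_from_nn(blinear)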
Example #4
    def test_kl_divergence_bayesian_conv2d_module(self):
        bconv = BayesianConv2d(in_channels=3,
                               out_channels=3,
                               kernel_size=(3, 3))

        to_feed = torch.ones((1, 3, 25, 25))
        predicted = bconv(to_feed)

        complexity_cost = bconv.log_variational_posterior - bconv.log_prior
        kl_complexity_cost = kl_divergence_from_nn(bconv)

        self.assertEqual((complexity_cost == kl_complexity_cost).all(),
                         torch.tensor(True))
        pass
Example #5
    def test_kl_divergence(self):
        # create a model decorated with @variational_estimator
        # check that its nn_kl_divergence method matches kl_divergence_from_nn

        to_feed = torch.ones((1, 10))

        @variational_estimator
        class VariationalEstimator(nn.Module):
            def __init__(self):
                super().__init__()
                self.blinear = BayesianLinear(10, 10)

            def forward(self, x):
                return self.blinear(x)

        model = VariationalEstimator()
        predicted = model(to_feed)

        complexity_cost = model.nn_kl_divergence()
        kl_complexity_cost = kl_divergence_from_nn(model)

        self.assertEqual((complexity_cost == kl_complexity_cost).all(), torch.tensor(True))
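Once a model is decorated with @variational_estimator, the complexity cost returned by nn_kl_divergence / kl_divergence_from_nn is typically added to an ordinary data-fitting loss and averaged over a few sampled forward passes (Bayes by Backprop). The step below is a minimal, self-contained sketch of that pattern, assuming the usual blitz import paths; the SmallEstimator class, the MSE criterion, the learning rate, and the sample_nbr and complexity_cost_weight values are illustrative placeholders. The epoch training function that follows applies the same structure to graph batches.

import torch
import torch.nn as nn
from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator
from blitz.losses import kl_divergence_from_nn


@variational_estimator
class SmallEstimator(nn.Module):
    def __init__(self):
        super().__init__()
        self.blinear = BayesianLinear(10, 10)

    def forward(self, x):
        return self.blinear(x)


model = SmallEstimator()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.ones((1, 10)), torch.zeros((1, 10))
sample_nbr = 3                 # weight samples per step (illustrative)
complexity_cost_weight = 1e-3  # scales the KL term (illustrative)

optimizer.zero_grad()
loss = 0.
for _ in range(sample_nbr):
    loss = loss + criterion(model(x), y)  # each forward pass resamples the weights
    loss = loss + kl_divergence_from_nn(model) * complexity_cost_weight
loss = loss / sample_nbr
loss.backward()
optimizer.step()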
Example #6
def train_epoch_classification(model,
                               optimizer,
                               device,
                               data_loader,
                               epoch,
                               params,
                               cyclic_lr_schedule=None):
    model.train()
    epoch_loss = 0
    nb_data = 0
    num_iters = len(data_loader.dataset)

    total_scores = []
    total_targets = []

    for iter, (batch_graphs, batch_targets, batch_snorm_n, batch_snorm_e,
               batch_smiles) in enumerate(data_loader):
        #   SWA cyclic lr schedule
        if cyclic_lr_schedule is not None:
            lr = cyclic_lr_schedule(iter / num_iters)
            swa_utils.adjust_learning_rate(optimizer, lr)

        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_targets = batch_targets.to(device)
        batch_targets = batch_targets.float()
        batch_snorm_n = batch_snorm_n.to(device)  # num x 1
        optimizer.zero_grad()

        loss = 0
        sample_nbr = params['bbp_sample_nbr']
        complexity_cost_weight = params['bbp_complexity']
        #   Bayes by Backprop: accumulate the data loss + weighted KL complexity cost
        #   over sample_nbr sampled forward passes, then average
        for _ in range(sample_nbr):
            batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                         batch_snorm_n, batch_snorm_e)
            loss += model.loss(batch_scores, batch_targets)
            kl_loss = kl_divergence_from_nn(model)
            loss += kl_loss * complexity_cost_weight
        loss = loss / sample_nbr

        loss.backward()

        #   SGD with a high lr shows gradient explosion --> temporarily using gradient clipping
        if params['grad_clip'] != 0.:
            clipping_value = params['grad_clip']
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)

        optimizer.step()
        epoch_loss += loss.detach().item()
        nb_data += batch_targets.size(0)

        total_scores.append(batch_scores)
        total_targets.append(batch_targets)

    epoch_loss /= (iter + 1)

    total_scores = torch.cat(total_scores, dim=0)
    total_targets = torch.cat(total_targets, dim=0)

    epoch_train_perf = binary_class_perfs(total_scores.detach(),
                                          total_targets.detach())

    return epoch_loss, epoch_train_perf, optimizer, total_scores, total_targets
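For reference, a hypothetical configuration sketch for calling train_epoch_classification: only the three key names below are taken from the function body, while the values and the surrounding objects (model, optimizer, device, data_loader) are placeholders from the rest of the training script. The data loader is expected to yield (batch_graphs, batch_targets, batch_snorm_n, batch_snorm_e, batch_smiles) tuples, as the loop above unpacks.

params = {
    'bbp_sample_nbr': 3,     # Monte Carlo weight samples per batch
    'bbp_complexity': 1e-3,  # weight on the KL complexity cost
    'grad_clip': 1.0,        # set to 0. to disable gradient clipping
}

# epoch_loss, perf, optimizer, scores, targets = train_epoch_classification(
#     model, optimizer, device, data_loader, epoch=0,
#     params=params, cyclic_lr_schedule=None)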