def validation_step(self, batch, batch_idx):
    """Performs a validation step.

    Args:
        batch (dict): Output of the dataloader, with 'waveform' and 'label' keys.
        batch_idx (int): Index no. of this batch.

    Returns:
        dict: Scalar metrics 'val_loss' and 'val_acc' for this step.
    """
    data, target = batch['waveform'], batch['label']

    # Forward pass; only the final head's output is used during validation.
    _, output2 = self(data)

    # Validation loss, must be CrossEntropy or a derivative.
    val_loss = self.loss(output2, target)

    # KL divergence between the approximate posterior and the prior over all
    # Bayesian layers, then weighted so it does not overflow the loss term.
    # NOTE(review): unlike training_step, kldiv is called here without
    # self.kl_init — confirm this asymmetry is intentional.
    raw_kl = kldiv(self.model)
    scaled_kl = self.model.weight_kl(raw_kl, self.val_dataset_size)

    # Apply the KL weighting scheme, which allows balancing the KL term
    # non-uniformly across the batches of an epoch.
    batches_per_epoch = self.val_dataset_size / self.batch_size
    beta = get_beta(batch_idx, batches_per_epoch,
                    beta_type=self.kl_weighting_scheme)
    kl_weighted = beta * scaled_kl

    # Accuracy of the final output against the targets.
    acc = FM.accuracy(output2.squeeze(), target)

    # Loss/accuracy are tensors; log and return plain Python scalars.
    self.log('val_acc', acc.item())
    self.log('val_loss', val_loss.item())
    self.log('val_kl_weighted', kl_weighted.item())
    return {'val_loss': val_loss.item(), 'val_acc': acc.item()}
def training_step(self, batch, batch_idx):
    """Performs a training step.

    Args:
        batch (dict): Output of the dataloader, with 'waveform' and 'label' keys.
        batch_idx (int): Index no. of this batch.

    Returns:
        dict: Total loss for this step under the key 'loss'.
    """
    data, target = batch['waveform'], batch['label']

    # Forward pass produces a mid-network output and the final output.
    mid_output, final_output = self(data)

    # Focal loss for the mid and final outputs.
    mid_loss = self.loss(mid_output, target)
    final_loss = self.loss(final_output, target)

    # KL divergence over all Bayesian layers, then weighted so it does not
    # overflow the loss term.
    raw_kl = kldiv(self.model, self.kl_init)
    scaled_kl = self.model.weight_kl(raw_kl, self.train_dataset_size)

    # Apply the KL weighting scheme, which allows balancing the KL term
    # non-uniformly across the batches of an epoch.
    batches_per_epoch = self.train_dataset_size / self.batch_size
    beta = get_beta(batch_idx, batches_per_epoch,
                    beta_type=self.kl_weighting_scheme,
                    epoch=self.current_epoch + 1,
                    num_epochs=self.max_epochs)
    kl_weighted = beta * scaled_kl

    # Variational inference objective = -KL divergence + negative log
    # likelihood; the mid-network loss is added with a fixed 0.3 weight.
    ELBO = kl_weighted + final_loss
    total_train_loss = (0.3 * mid_loss) + ELBO

    self.log('train_loss', total_train_loss)
    self.log('train_ELBO', ELBO)
    self.log('train_kl_weighted', kl_weighted)
    return {'loss': total_train_loss}