    def validation_step(self, batch, batch_idx):
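        """Performs a validation step.

        Args:
            batch (dict): Output of the dataloader.
            batch_idx (int): Index no. of this batch.

        Returns:
            dict: Validation loss and accuracy for this step.
        """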
        data, target = batch['waveform'], batch['label']

        # Forward pass; only the final output head is used for validation
        _, output2 = self(data)
        
        # Calculate the loss; must be cross-entropy or a derivative thereof
        val_loss = self.loss(output2, target)
        
        # Calculate KL divergence between the approximate posterior and the prior over all Bayesian layers
        kl_clean = kldiv(self.model)

        # Scale the KL divergence so it does not overwhelm the likelihood term
        kl = self.model.weight_kl(kl_clean, self.val_dataset_size)

        # Apply the KL weighting scheme, which allows the KL term to be balanced non-uniformly across batches
        M = self.val_dataset_size / self.batch_size  # number of mini-batches per validation epoch
        beta = get_beta(batch_idx, M, beta_type=self.kl_weighting_scheme)
        kl_weighted = beta * kl

        # Calculate accuracy
        acc = FM.accuracy(output2.squeeze(), target)

        # Convert tensors to Python floats for the returned metrics dict
        metrics = {'val_loss': val_loss.item(), 'val_acc': acc.item()}
        self.log('val_acc', acc.item())
        self.log('val_loss', val_loss.item())
        self.log('val_kl_weighted', kl_weighted.item())

        return metrics

    def training_step(self, batch, batch_idx):
        """Performs a training step.

        Args:
            batch (dict): Output of the dataloader.
            batch_idx (int): Index no. of this batch.

        Returns:
            dict: Dictionary containing the total loss for this step under the key 'loss'.
        """
        data, target = batch['waveform'], batch['label']
        output1, output2 = self(data)

        # Calculate the focal loss for the mid and final outputs
        train_loss1 = self.loss(output1, target)
        train_loss2 = self.loss(output2, target)

        # Calculate the KL divergence between the approximate posterior and the prior over all Bayesian layers
        kl_clean = kldiv(self.model, self.kl_init)

        # Scale the KL divergence so it does not overwhelm the likelihood term
        kl = self.model.weight_kl(kl_clean, self.train_dataset_size)

        # Apply the KL weighting scheme, which allows the KL term to be balanced non-uniformly across batches
        M = self.train_dataset_size / self.batch_size  # number of mini-batches per training epoch
        beta = get_beta(batch_idx,
                        M,
                        beta_type=self.kl_weighting_scheme,
                        epoch=self.current_epoch + 1,
                        num_epochs=self.max_epochs)
        kl_weighted = beta * kl

        # Variational inference objective (negative ELBO) = weighted KL divergence + negative log likelihood
        ELBO = kl_weighted + train_loss2

        # Total loss: variational objective plus the mid-output (auxiliary) loss, down-weighted by 0.3
        total_train_loss = (0.3 * train_loss1) + ELBO

        self.log('train_loss', total_train_loss)
        self.log('train_ELBO', ELBO)
        self.log('train_kl_weighted', kl_weighted)

        return {'loss': total_train_loss}