def test_plot_activation_dist(self):
    """Plot per-module forward activation distributions of ResNet-18.

    For both the pretrained (``pretrained=True``) and the randomly
    initialised model, a forward hook is attached to every module returned
    by ``named_modules()`` (including the root module, named ``''``), one
    batch from ``self.ds`` is pushed through the network, and each module's
    flattened output is passed to ``plot_weight_dist`` together with its
    mean and std. One PNG per model is written under ``plots/``.
    """
    import torch
    from torchvision.models import resnet18
    from src.visualization import plot_weight_dist, sample_dataset

    activations = {}

    def get_activation(name):
        def hook(model, input, output):
            # clone(): ResNet's ReLU layers run in place and BasicBlock does
            # ``out += identity``; a bare detach() shares storage with the
            # live tensor and would later be overwritten, so e.g. the bn1/bn2
            # records would really hold post-ReLU / post-residual values.
            activations[name] = output.detach().clone()

        return hook

    for trained in [True, False]:
        # Start each model with an empty record.
        activations.clear()
        model = resnet18(pretrained=trained)
        model.eval()
        handles = [
            module.register_forward_hook(get_activation(name))
            for name, module in model.named_modules()
        ]
        try:
            with torch.no_grad():
                images, _ = sample_dataset(self.ds, self.BATCH_SIZE)
                model(images)
        finally:
            # Remove the hooks so the model stops writing into the dict.
            for handle in handles:
                handle.remove()

        l_name, l_x = [], []
        for name, act in activations.items():
            w = act.flatten().cpu().numpy()
            l_name.append(
                f"{name}\nmean={w.mean():.3f}\nstd={w.std():.3f}")
            l_x.append(w)
        plot_weight_dist(l_name, l_x,
                         f'plots/test_plot_activation_dist_{trained}.png')
 def test_scatterplot_images(self):
     """Lay out 100 sampled images in one column per label and plot them.

     The RNG is seeded with ``seed_rng(0)`` before sampling. Each distinct
     label gets its own column at x = 48 * column_index + 32, and the k-th
     image carrying that label is placed at y = 48 * k + 32. The resulting
     (x, y) coordinates and PIL images are passed to
     ``scatterplot_images``.
     """
     import numpy as np
     from torchvision.transforms import ToPILImage
     from src.visualization import scatterplot_images, sample_dataset
     from src.utils import seed_rng
     seed_rng(0)
     images, labels = sample_dataset(self.ds, 100)
     labels = labels.numpy()
     trans = ToPILImage()
     images = [trans(img) for img in images]
     # enumerate() yields (index, value): key by the label *value* and derive
     # x from its column *index*, so labels that are not exactly 0..K-1
     # still get consecutive columns.
     lbl2x0 = {
         lbl: 48 * ix + 32
         for ix, lbl in enumerate(np.unique(labels))
     }
     # Per-label count of images placed so far (drives the y offset).
     cnt_lbl = {lbl: 0 for lbl in lbl2x0}
     x, y = [], []
     for lbl in labels:
         x.append(lbl2x0[lbl])
         y.append(cnt_lbl[lbl] * 48 + 32)
         cnt_lbl[lbl] += 1
     data = np.stack([x, y], axis=1)
     scatterplot_images(data, images, 'plots/test_scatterplot_images.png')
    def test_plot_gradient_dist(self):
        """Plot per-module backward-gradient distributions of ResNet-18.

        For both the pretrained and the randomly initialised model, a
        backward hook is attached to every module from ``named_modules()``,
        one batch from ``self.ds`` is run forward in train mode, and the
        cross-entropy loss is back-propagated. For each hooked module, the
        first element of ``grad_output`` is flattened and passed to
        ``plot_weight_dist`` together with its mean, std and norm. One PNG
        per model is written under ``plots/``.
        """
        import torch
        from torchvision.models import resnet18
        from src.visualization import plot_weight_dist, sample_dataset

        gradients = {}
        grad_norm = {}

        def get_gradients(name):
            def hook(model, grad_input, grad_output):
                # Plain dict assignment: if a module (e.g. a shared ``relu``)
                # fires more than once, only the last call is kept.
                gradients[name] = grad_output
                grad_norm[name] = grad_output[0].norm()

            return hook

        loss_fn = torch.nn.CrossEntropyLoss()
        for trained in [True, False]:
            # Start each model with empty records.
            gradients.clear()
            grad_norm.clear()
            model = resnet18(pretrained=trained)
            model.train()
            # NOTE(review): register_backward_hook is deprecated in favour of
            # register_full_backward_hook; the full hook is believed to raise
            # on ResNet's in-place ReLU, so the legacy hook is kept here --
            # confirm before switching.
            handles = [
                module.register_backward_hook(get_gradients(name))
                for name, module in model.named_modules()
            ]
            try:
                images, labels = sample_dataset(self.ds, self.BATCH_SIZE)
                err = loss_fn(model(images), labels)
                err.backward()
            finally:
                # Remove the hooks so later backward passes on this model do
                # not keep writing into the dicts.
                for handle in handles:
                    handle.remove()

            l_name, l_x = [], []
            for name, grads in gradients.items():
                w = grads[0].detach().flatten().cpu().numpy()
                norm = grad_norm[name].item()
                l_name.append(
                    f"{name}\nmean={w.mean():.3f}\nstd={w.std():.3f}\nnorm:{norm:.3f}"
                )
                l_x.append(w)
            plot_weight_dist(l_name, l_x,
                             f'plots/test_plot_gradient_dist_{trained}.png')
 def test_draw_label_on_batch(self):
     """Draw each sample's label onto its image, then visualize the batch.

     The annotated batch is handed to ``visualize_batch`` together with the
     output path ``plots/test_draw_label_on_batch.png``.
     """
     from src.visualization import visualize_batch, sample_dataset
     from src.visualization import draw_label_on_batch
     batch, targets = sample_dataset(self.ds, self.BATCH_SIZE)
     annotated = draw_label_on_batch(batch, targets)
     visualize_batch(annotated, 'plots/test_draw_label_on_batch.png')
 def test_visualize_batch(self):
     """Sample one batch from ``self.ds`` and pass it to ``visualize_batch``."""
     from src.visualization import visualize_batch, sample_dataset
     batch, _ = sample_dataset(self.ds, self.BATCH_SIZE)
     visualize_batch(batch, 'plots/test_visualize_batch.png')
    def test_sample_dataset(self):
        """Check ``sample_dataset`` yields (N, 3, 32, 32) images and (N,) labels."""
        from src.visualization import sample_dataset
        batch, targets = sample_dataset(self.ds, self.BATCH_SIZE)

        expected_images = (self.BATCH_SIZE, 3, 32, 32)
        expected_labels = (self.BATCH_SIZE, )
        self.assertEqual(batch.shape, expected_images)
        self.assertEqual(targets.shape, expected_labels)