print(accuracy)

fig, ax = plt.subplots(1, figsize=(8, 8))
utils.make_confusion_matrix(labels, predictions, categories, ax)
"""## Visualize learned filters

The output of a convolutional filter is still spatial, so it can be interpreted as an image. Here, we inspect some of the filter responses of our network to visualize which features these filters pick up on.
"""

# apply the first convolutional filter of our model
# to the first image from the training dataset
model.to(torch.device('cpu'))
model.eval()
im = train_dataset[0][0][None]  # add a batch dimension: (C, H, W) -> (1, C, H, W)

conv1_response = model.conv1(im).detach().numpy()[0]  # drop the batch dimension -> (out_channels, H, W)
print(conv1_response.shape)

# visualize the filters in the first layer
n_filters = 8
fig, axes = plt.subplots(1, 1 + n_filters, figsize=(16, 4))
im = im[0].numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
axes[0].imshow(im)
for chan_id in range(n_filters):
    axes[chan_id + 1].imshow(conv1_response[chan_id], cmap='gray')

plt.show()
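
# The responses above show what the filters do to one example image. As a
# complementary view, we can also look at the learned kernels themselves.
# This is a minimal sketch, assuming model.conv1 is an nn.Conv2d layer
# with at least n_filters output channels.
weights = model.conv1.weight.detach().numpy()  # shape: (out_channels, in_channels, kH, kW)
fig, axes = plt.subplots(1, n_filters, figsize=(16, 4))
for chan_id in range(n_filters):
    # average over input channels so each kernel is shown as a single grayscale image
    axes[chan_id].imshow(weights[chan_id].mean(axis=0), cmap='gray')
    axes[chan_id].axis('off')

plt.show()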
"""## Tasks and Questions

Tasks:
- Construct a CNN that can be applied to input images of arbitrary size. Hint: Have a look at [nn.AdaptiveAvgPool2d](https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html).