"Load in data"
# The .npy file holds a pickled object (presumably a dict with 'train' and
# 'test' entries) wrapped in a 0-d object array; `[()]` unwraps it.
# NumPy >= 1.16.3 refuses to load object arrays unless allow_pickle=True.
# NOTE(review): unpickling can execute arbitrary code — only load a trusted,
# locally produced data file here.
mnist_data = np.load('../simple_regularization/data/mnist.npy', allow_pickle=True)[()]
trainData, trainLabels = mnist_data['train']  # training set
testData, testLabels = mnist_data['test']  # test set
# Pool training and test samples (and their labels) by concatenating along
# dim 0 (presumably the sample axis).
totData = torch.cat((trainData, testData), 0)
totLabels = torch.cat((trainLabels, testLabels))

# In[]
"Compute the eigenspectra of single layer mlp"


def _covariance_spectrum(h):
    """Return the eigenvalues of the sample covariance of ``h``, largest first.

    ``h`` is a 2-D torch tensor whose rows are treated as samples: columns are
    centred by their mean over rows and the covariance is normalised by 1/N,
    where N is the number of rows. The result is a new, contiguous 1-D array.
    """
    x = h.detach().cpu().numpy()
    x = x - x.mean(axis=0)
    cov = x.T @ x / x.shape[0]
    # Symmetrise to cancel floating-point asymmetry before the symmetric solver.
    cov = (cov + cov.T) / 2
    eigvals = np.linalg.eigh(cov)[0]
    # eigh returns eigenvalues in ascending order; flip to descending and copy
    # so the stored array owns its data instead of being a reversed view.
    return eigvals[::-1].copy()


# Result nesting mirrors singleMlpModels:
#   singleMlpSpectra[activation][arch][realization][layer] -> 1-D eigenvalues.
singleMlpSpectra = []
# Inference only: no_grad avoids building an autograd graph over the whole
# dataset for every model (results are unchanged, memory use drops).
with torch.no_grad():
    for activationModels in singleMlpModels:
        activationSpectra = []
        for archModels in activationModels:
            realizationSpectra = []
            for model in archModels:
                hidden, _ = model.bothOutputs(totData)
                layerSpectra = [_covariance_spectrum(h) for h in hidden]
                # Stored in the reverse of the order `hidden` yields the layers.
                realizationSpectra.append(layerSpectra[::-1])
            activationSpectra.append(realizationSpectra)
        singleMlpSpectra.append(activationSpectra)

# In[]
"Compute the eigenspectra of double layer mlp"
doubleMlpSpectra = []