def gen_displayable_images():
    """Hand the three per-digit images for each digit 0-9 to ``image_beautifier``.

    For each digit, ``image_beautifier`` receives:

    * a list of three input paths, in this order:
      ``./images_DBN/digitwise/<digit>_original_image.jpg``,
      ``./images_DBN/digitwise/<digit>_hidden_image.jpg``,
      ``./images_DBN/digitwise/<digit>_reconstructed_image.jpg``;
    * the output path ``./images_DBN/<digit>.jpg``.

    What ``image_beautifier`` does with them is defined elsewhere
    (presumably it composes them into one figure -- confirm).
    """
    kinds = ('original', 'hidden', 'reconstructed')
    for digit in range(10):
        stem = './images_DBN/digitwise/' + str(digit) + '_'
        paths = [stem + kind + '_image.jpg' for kind in kinds]
        image_beautifier(paths, './images_DBN/' + str(digit) + '.jpg')


if __name__ == '__main__':
    mnist = MNIST()
    train_x, train_y, test_x, test_y = mnist.load_dataset()

    # The input size is the training data's feature dimension. The last layer
    # has 10 units, which matches the (5, 2) reshape of the hidden image below
    # (presumably the hidden image is that top layer's output -- confirm).
    layers = [512, 128, 64, 10]
    dbn = DBN(train_x.shape[1], layers)
    dbn.layer_parameters = torch.load('mnist_trained_dbn.pt')

    for n in range(10):
        # Take the first test sample labelled n and add a leading batch axis.
        first_idx = np.where(test_y == n)[0][0]
        x = test_x[first_idx].unsqueeze(0)

        gen_image, hidden_image = dbn.reconstructor(x)

        # Undo the dataset normalization on the input, then on the
        # reconstruction (same order as before), keeping the first batch row.
        image = np.reshape(mnist.inv_transform_normalizer(x.numpy())[0], (28, 28))
        gen_image = mnist.inv_transform_normalizer(gen_image.numpy())[0]
        # Scale by 255 for display (presumably activations lie in [0, 1] --
        # confirm), then lay the 10 values out as a 5x2 grid.
        hidden_image = np.reshape((hidden_image.numpy() * 255)[0], (5, 2))