def gen_displayable_images():
    """Build one combined image per digit 0-9 from its per-digit component images.

    For each digit ``n`` in ``range(10)`` this collects three input paths::

        ./images_RBM/digitwise/<n>_original_image.jpg
        ./images_RBM/digitwise/<n>_hidden_image.jpg
        ./images_RBM/digitwise/<n>_reconstructed_image.jpg

    It passes them, together with the output path ``./images_RBM/<n>.jpg``, to
    ``image_beautifier``. That helper is defined elsewhere. Presumably it
    lays the three images out into the single output file; verify against
    its definition.
    """
    suffix = '_image.jpg'
    for n in range(10):
        prefix = './images_RBM/digitwise/'+str(n)+'_'
        names = ['original', 'hidden', 'reconstructed']
        # Expand each short name into its full input file path for digit n.
        names = [prefix+name+suffix for name in names]
        image_beautifier(names, './images_RBM/'+str(n)+'.jpg')


if __name__ == '__main__':
    mnist = MNIST()
    train_x, train_y, test_x, test_y = mnist.load_dataset()
    # Number of visible units is taken from the second dimension of train_x.
    # Assumes train_x is (num_samples, num_features), i.e. flattened images.
    # TODO confirm against MNIST.load_dataset.
    vn = train_x.shape[1]
    hn = 2500  # number of hidden units
    rbm = RBM(vn, hn)
    # Load previously trained parameters from disk.
    rbm.load_rbm('mnist_trained_rbm.pt')
    for n in range(10):
        # Pick the first test example whose label equals digit n.
        x = test_x[np.where(test_y==n)[0][0]]
        # Add a leading batch dimension of size 1.
        # Assumes x is a torch tensor, since unsqueeze is a tensor method.
        x = x.unsqueeze(0)
        hidden_image = []
        gen_image = []
        for k in range(rbm.k):
            # Only the second element of each sampler's return value is used.
            # Presumably (probabilities, samples). TODO confirm in RBM.
            # NOTE(review): every iteration restarts from the input x.
            # It does not feed vk back in, so this gathers rbm.k independent
            # one-step reconstructions rather than a k-step Gibbs chain.
            # Confirm this is intended.
            _, hk = rbm.sample_h(x)
            _, vk = rbm.sample_v(hk)
            gen_image.append(vk.numpy())
            hidden_image.append(hk.numpy())
        # Average the collected samples over the sampling axis (axis 0).
        hidden_image = np.array(hidden_image)
        hidden_image = np.mean(hidden_image, axis=0)
        gen_image = np.array(gen_image)
        gen_image = np.mean(gen_image, axis=0)