# -----------------------------------------------------------------------------
if not pre_trained:
    # A fixed slice of the *training* data. Despite the name, these rows are
    # part of `data`, i.e. instances the RBM is trained on; they serve as the
    # "known" side of the free-energy gap below.
    validation = data[:10000]
    # Built once: with shuffle=True the loader draws a fresh permutation every
    # time it is iterated, so there is no need to rebuild it each epoch.
    train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                               shuffle=True)
    for _ in range(epochs):
        rbm.train(train_loader)
        # Fit diagnostic: the free-energy gap between known (training) and
        # unknown (`test`) instances. It tracks the log-likelihood difference
        # between the two sets, but the partition function cancels out, so it
        # is cheap to compute. It should stay around 0. High-probability
        # instances have very negative free energy, so a gap that drifts
        # strongly negative means training instances are becoming much more
        # likely than unseen ones, i.e. the model may be overfitting.
        # Evaluation only (we just read .item()), so skip autograd tracking.
        with torch.no_grad():
            gap = (rbm.free_energy(validation)
                   - rbm.free_energy(test)).mean(0)
        print('Gap = {}'.format(gap.item()))
    # Persist the trained weights so later runs can set pre_trained and skip
    # training.
    torch.save(rbm.state_dict(), model_dir)

# -----------------------------------------------------------------------------
# Plotting
# -----------------------------------------------------------------------------
print('Reconstructing images')
plt.figure(figsize=(20, 10))
# 25 chains started from an all-zero visible layer (784 = 28 * 28 per row;
# presumably flattened 28x28 images — TODO confirm).
zero = torch.zeros(25, 784).to(device)
# NOTE(review): reshape((5 * 28, 5 * 28)) on a (25, 784) array does not tile
# 25 images into a 5x5 grid (each image ends up smeared across ~5.6 rows);
# a proper grid is .reshape(5, 5, 28, 28).transpose(0, 2, 1, 3)
# .reshape(140, 140). Harmless for this first frame since it is all zeros.
images = [zero.cpu().numpy().reshape((5 * 28, 5 * 28))]
# NOTE(review): the meaning of this flag is defined by the sampler class,
# which is not visible here — confirm what it switches.
sampler.internal_sampling = True
for i in range(k_reconstruct):
    # One alternating step: hidden layer from visible, then visible from
    # hidden, reusing `zero` as the running chain state.
    zero = sampler.get_h_from_v(zero, rbm.W, rbm.hbias)
    zero = sampler.get_v_from_h(zero, rbm.W, rbm.vbias)