def load_train_images():
    # load MNIST once via mnist_tools, then cache the arrays as .npy files under
    # minist_path so subsequent runs can skip the slow download/parse step
    if not os.path.exists(minist_path + 'training_images.npy'):
        images, labels = mnist_tools.load_train_images()
        np.save(minist_path + 'training_images.npy', images)
        np.save(minist_path + 'training_labels.npy', labels)
    images = np.load(minist_path + 'training_images.npy')
    labels = np.load(minist_path + 'training_labels.npy')
    return images, labels
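# main() below calls a sample_from_data() helper that is not defined in this
# section. The following is a minimal sketch under the assumption that it only
# needs to draw a random minibatch of training images; if the repository ships
# its own implementation elsewhere, that one should be used instead.
def sample_from_data(images, batchsize):
    # choose `batchsize` distinct indices and stack the corresponding images
    indices = np.random.choice(len(images), size=batchsize, replace=False)
    return np.asarray(images, dtype=np.float32)[indices]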
def main():
    # load MNIST images
    images, labels = load_train_images()

    # config
    config_energy_model = to_object(params_energy_model["config"])
    config_generative_model = to_object(params_generative_model["config"])

    # settings
    max_epoch = 1000
    n_trains_per_epoch = 1000
    batchsize_positive = 128
    batchsize_negative = 128
    plot_interval = 30

    # seed
    np.random.seed(args.seed)
    if args.gpu_enabled:
        cuda.cupy.random.seed(args.seed)

    # initialize the weight normalization layers with one large forward pass
    if config_energy_model.use_weightnorm:
        print "initializing weight normalization layers of the energy model ..."
        x_positive = sample_from_data(images, len(images) // 10)
        ddgm.compute_energy(x_positive)
    if config_generative_model.use_weightnorm:
        print "initializing weight normalization layers of the generative model ..."
        x_negative = ddgm.generate_x(len(images) // 10)

    # training
    progress = Progress()
    for epoch in xrange(1, max_epoch + 1):    # epochs are 1-indexed; run exactly max_epoch epochs
        progress.start_epoch(epoch, max_epoch)
        sum_energy_positive = 0
        sum_energy_negative = 0
        sum_loss = 0
        sum_kld = 0

        for t in xrange(n_trains_per_epoch):
            # sample from the data distribution (positive phase) and from the generator (negative phase)
            x_positive = sample_from_data(images, batchsize_positive)
            x_negative = ddgm.generate_x(batchsize_negative)

            # train energy model: lower the energy of data, raise the energy of generated samples
            energy_positive = ddgm.compute_energy_sum(x_positive)
            energy_negative = ddgm.compute_energy_sum(x_negative)
            loss = energy_positive - energy_negative
            ddgm.backprop_energy_model(loss)

            # train generative model by minimizing the KL divergence to the energy model
            # TODO: KLD must be greater than or equal to 0
            x_negative = ddgm.generate_x(batchsize_negative)
            kld = ddgm.compute_kld_between_generator_and_energy_model(x_negative)
            ddgm.backprop_generative_model(kld)

            sum_energy_positive += float(energy_positive.data)
            sum_energy_negative += float(energy_negative.data)
            sum_loss += float(loss.data)
            sum_kld += float(kld.data)
            if t % 10 == 0:
                progress.show(t, n_trains_per_epoch, {})

        # report per-epoch averages, save the model, and periodically plot samples
        progress.show(n_trains_per_epoch, n_trains_per_epoch, {
            "x+": sum_energy_positive / n_trains_per_epoch,
            "x-": sum_energy_negative / n_trains_per_epoch,
            "loss": sum_loss / n_trains_per_epoch,
            "kld": sum_kld / n_trains_per_epoch
        })
        ddgm.save(args.model_dir)

        if epoch % plot_interval == 0 or epoch == 1:
            plot(filename="epoch_{}_time_{}min".format(epoch, progress.get_total_time()))
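# Assuming this script is meant to be run directly and main() is not invoked
# anywhere else in this section, add the usual entry-point guard.
if __name__ == "__main__":
    main()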