def generate_image(frame, true_dist):
    """Save a grid of (condition, sample) pairs and log SSIM/MSE values.

    Runs the generator on the fixed inputs, writes the tiled images to
    ``samples_<frame>.jpg`` and plots/prints SSIM and MSE for the first
    (up to) three samples of the batch.

    Args:
        frame: Iteration number; used in the file name and the log line.
        true_dist: Unused; kept so existing callers keep working.
    """
    samples = session.run(
        fixed_noise_samples,
        feed_dict={real_data: fixed_real_data,
                   condition_data: fixed_labels})  # [-1, 1]
    samples_255 = ((samples + 1.) * (255. / 2)).astype('int32')  # [0, 255]
    samples_01 = ((samples + 1.) / 2.).astype('float32')  # [0, 1]

    # Interleave rows as cond_0, sample_0, cond_1, sample_1, ... in one step
    # instead of BATCH_SIZE np.insert calls that each copy the whole array.
    # The cast mirrors np.insert, which converts values to the array dtype.
    cond_255 = np.asarray(fixed_real_data_255)[:BATCH_SIZE].astype(
        samples_255.dtype)
    pairs = np.stack(
        [cond_255.reshape(BATCH_SIZE, -1),
         samples_255.reshape(BATCH_SIZE, -1)],
        axis=1)
    imsaver.save_images(
        pairs.reshape((2 * BATCH_SIZE, 1, IM_DIM, IM_DIM)),
        'samples_{}.jpg'.format(frame))
    print("Iteration %d : \n" % frame)

    # Compare generated images to the real ones.
    # NOTE(review): every call adds these ops to the default graph, so the
    # graph grows with each invocation.
    # NOTE(review): pred is rescaled to [0, 1]; confirm fixed_real_data is in
    # the same range, otherwise the metrics mix two value ranges.
    real = tf.reshape(fixed_real_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples_01, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssimval = tf.image.ssim(real, pred, max_val=1.0)  # one value per image
    mseval_per_entry = tf.keras.metrics.mse(real, pred)  # mean over channels
    mseval = tf.reduce_mean(mseval_per_entry, [1, 2])  # one value per image
    tf.summary.tensor_summary("SSIM values", ssimval)
    tf.summary.tensor_summary("MSE values", mseval)

    # Fetch both metrics in a single run rather than evaluating the graph
    # once per tensor.
    ssimval_list, mseval_list = session.run([ssimval, mseval])

    # Report at most three samples; guards against batches smaller than 3.
    for i in range(min(3, len(ssimval_list))):
        plotter.plot('SSIM for sample %d' % (i + 1), ssimval_list[i])
        plotter.plot('MSE for sample %d' % (i + 1), mseval_list[i])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (i, mseval_list[i], ssimval_list[i]))
def generate_image(frame, true_dist):
    """Write a batch of generated images and log SSIM/MSE for three of them.

    Args:
        frame: Iteration number used in the output file name and log line.
        true_dist: Not used by this function.
    """
    generated = session.run(fixed_noise_samples,
                            feed_dict={real_data: fixed_data})  # [-1, 1]

    pixels_255 = ((generated + 1.) * (255. / 2)).astype('int32')  # [0, 255]
    pixels_01 = ((generated + 1.) / 2.).astype('float32')  # [0, 1]

    out_name = 'samples_{}.png'.format(frame)
    imsaver.save_images(
        pixels_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)), out_name)
    print("Iteration %d : \n" % frame)

    # Build the comparison ops between the fixed data and the samples.
    image_shape = [BATCH_SIZE, IM_DIM, IM_DIM, 1]
    reference = tf.reshape(fixed_data, image_shape)
    prediction = tf.reshape(pixels_01, image_shape)
    ssim_op = tf.image.ssim(reference, prediction, max_val=1.0)
    mse_op = tf.reduce_mean(
        tf.keras.metrics.mse(reference, prediction), [1, 2])

    ssim_scores = ssim_op.eval()  # numpy array, one entry per image
    mse_scores = mse_op.eval()  # numpy array, one entry per image

    # Only the first three samples are plotted and printed.
    for idx in range(3):
        label = idx + 1
        plotter.plot('SSIM for sample %d' % label, ssim_scores[idx])
        plotter.plot('MSE for sample %d' % label, mse_scores[idx])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (idx, mse_scores[idx], ssim_scores[idx]))
def generate_image(frame, true_dist):
    """Save (condition, sample) RGB pairs and log grayscale SSIM/MSE values.

    Runs the generator on the fixed inputs, writes the tiled images to
    ``samples_<frame>.jpg`` and plots/prints SSIM and MSE (computed on the
    grayscale conversion) for the first (up to) three samples.

    Args:
        frame: Iteration number; used in the file name and the log line.
        true_dist: Unused; kept so existing callers keep working.
    """
    samples = session.run(
        fixed_noise_samples,
        feed_dict={real_data_int: fixed_real_data_int,
                   cond_data_int: fixed_cond_data_int})  # [-1, 1]
    samples_255 = ((samples + 1.) * (255. / 2)).astype('int32')  # [0, 255]
    samples_01 = ((samples + 1.) / 2.).astype('float32')  # [0, 1]

    # Interleave rows as cond_0, sample_0, cond_1, sample_1, ... in one step
    # instead of BATCH_SIZE np.insert calls that each copy the whole array.
    # The cast mirrors np.insert, which converts values to the array dtype.
    cond_255 = np.asarray(fixed_cond_data_int)[:BATCH_SIZE].astype(
        samples_255.dtype)
    pairs = np.stack(
        [cond_255.reshape(BATCH_SIZE, -1),
         samples_255.reshape(BATCH_SIZE, -1)],
        axis=1)
    imsaver.save_images(
        pairs.reshape((2 * BATCH_SIZE, 3, IM_DIM, IM_DIM)),
        'samples_{}.jpg'.format(frame))
    print("Iteration %d : \n" % frame)

    # Compare generated images to the real ones on grayscale.
    # NOTE(review): images are saved as (N, 3, H, W) but reshaped here as
    # (N, H, W, 3) from the same flat data - confirm which layout is right.
    # NOTE(review): every call adds these ops to the default graph, so the
    # graph grows with each invocation.
    real = tf.reshape(fixed_real_data_norm01, [BATCH_SIZE, IM_DIM, IM_DIM, 3])
    real_gray = tf.image.rgb_to_grayscale(real)
    pred = tf.reshape(samples_01, [BATCH_SIZE, IM_DIM, IM_DIM, 3])
    pred_gray = tf.image.rgb_to_grayscale(pred)
    ssimval = tf.image.ssim(real_gray, pred_gray, max_val=1.0)
    mseval_per_entry = tf.keras.metrics.mse(real_gray, pred_gray)
    mseval = tf.reduce_mean(mseval_per_entry, [1, 2])  # one value per image

    # Fetch both metrics in a single run rather than evaluating the graph
    # once per tensor.
    ssimval_list, mseval_list = session.run([ssimval, mseval])

    # Report at most three samples; guards against batches smaller than 3.
    for i in range(min(3, len(ssimval_list))):
        plotter.plot('SSIM for sample %d' % (i + 1), ssimval_list[i])
        plotter.plot('MSE for sample %d' % (i + 1), mseval_list[i])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (i, mseval_list[i], ssimval_list[i]))
def generate_image(frame, final):
    """Save sample tiles with their conditioning frames and log batch metrics.

    Each generated sample is written next to its time-conditioning frame,
    its conditioning frame and the real frame. Batch averages of SSIM, MSE
    and PSNR are plotted; on the final call the per-sample values are
    printed as well.

    Args:
        frame: Iteration number; used in the file name and the log lines.
        final: When truthy, also print the per-sample metric arrays.
    """
    samples = session.run(
        fixed_noise_samples,
        feed_dict={condition_data: fixed_cond_data,
                   time_data: fixed_time_data})  # [0, 1]
    samples_255 = (samples * 255.99).astype('uint8')  # [0, 1] -> [0, 255]

    def _rows(images):
        # First BATCH_SIZE images, one flat row each, cast the way np.insert
        # would cast inserted values (to the dtype of samples_255).
        return np.asarray(images)[:BATCH_SIZE].reshape(
            BATCH_SIZE, -1).astype(samples_255.dtype)

    # Row order per sample: cond_time, cond, sample, real. Built in a single
    # stack instead of 3 * BATCH_SIZE np.insert calls that each copy the
    # whole array.
    tiles = np.stack(
        [_rows(fixed_cond_2_data_255),
         _rows(fixed_cond_data_255),
         _rows(samples_255),
         _rows(fixed_real_data_255)],
        axis=1)
    imsaver.save_images(
        tiles.reshape((4 * BATCH_SIZE, IM_DIM, IM_DIM)),
        'samples_{}.jpg'.format(frame),
        alternate_viz=True, conds=True, gt=True, time=True)
    print("Iteration %d :" % frame)

    # NOTE(review): the reference here is fixed_cond_data, not a
    # "real"/ground-truth tensor - confirm this is the intended comparison.
    # NOTE(review): every call adds these ops to the default graph.
    real = tf.reshape(fixed_cond_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssim_vals = tf.image.ssim(real, pred, max_val=1.0)
    mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred), [1, 2])
    psnr_vals = tf.image.psnr(real, pred, max_val=1.0)

    # One graph execution for all three metrics; averages are taken on the
    # fetched arrays instead of re-running the graph per average.
    ssim_vals_list, mse_vals_list, psnr_vals_list = session.run(
        [ssim_vals, mse_vals, psnr_vals])
    ssim_avg = np.mean(ssim_vals_list)
    mse_avg = np.mean(mse_vals_list)
    psnr_avg = np.mean(psnr_vals_list)

    plotter.plot('SSIM avg', ssim_avg)
    plotter.plot('MSE avg', mse_avg)
    plotter.plot('PSNR avg', psnr_avg)

    if final:
        # Use the frame argument rather than a global loop variable.
        print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' %
              (frame, ssim_avg, mse_avg))
        print(ssim_vals_list)
        print(mse_vals_list)
        print(psnr_vals_list)
def generate_image(frame, final):
    """Generate a batch for the current MODE, save it and log batch metrics.

    Depending on the module-level MODE the samples are saved alone
    ('plain'), next to the conditioning image plus a separate noise-sample
    grid ('vae'), or between conditioning and real image ('cond'/'enc').
    Batch averages of SSIM, MSE and PSNR are plotted afterwards; on the
    final call the per-sample values are printed as well.

    Args:
        frame: Iteration number; used in file names and log lines.
        final: When truthy, also print the per-sample metric arrays.

    Raises:
        ValueError: If MODE is not one of 'cond', 'enc', 'plain', 'vae'.
    """
    inference_start_time = time.time()  # inference time analysis
    if MODE in ('cond', 'enc'):
        samples = session.run(
            fixed_noise_samples,
            feed_dict={condition_data: fixed_cond_data})  # [0, 1]
    elif MODE == 'plain':
        samples = session.run(fixed_noise_samples)  # [0, 1]
    elif MODE == 'vae':
        samples = session.run(
            fixed_noise_samples,
            feed_dict={real_data: fixed_cond_data})  # [0, 1]
        samples_noise = session.run(
            _noise_samples, feed_dict={real_data: fixed_cond_data})  # [0, 1]
        noise_samples_255 = (samples_noise * 255.).astype('uint8')
    else:
        # Previously an unknown MODE surfaced later as an unbound `samples`.
        raise ValueError("unsupported MODE: %r" % (MODE,))
    inference_time = time.time() - inference_start_time
    print("The architecture took ", inference_time,
          "sec for the generation of ", BATCH_SIZE, "images")

    samples_255 = (samples * 255.).astype('uint8')  # [0, 1] -> [0, 255]

    def _rows(images):
        # First BATCH_SIZE images, one flat row each, cast the way np.insert
        # would cast inserted values (to the dtype of samples_255).
        return np.asarray(images)[:BATCH_SIZE].reshape(
            BATCH_SIZE, -1).astype(samples_255.dtype)

    out_name = 'samples_{}.jpg'.format(frame)
    if MODE == 'plain':
        imsaver.save_images(
            samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
            out_name, alternate_viz=True)
    elif MODE == 'vae':
        # Row order per sample: conditioning image, sample.
        tiles = np.stack(
            [_rows(fixed_cond_data_255), _rows(samples_255)], axis=1)
        imsaver.save_images(
            tiles.reshape((2 * BATCH_SIZE, IM_DIM, IM_DIM)),
            out_name, alternate_viz=True, conds=True)
        imsaver.save_images(
            noise_samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
            'noise_samples_{}.jpg'.format(frame), alternate_viz=True)
    else:  # 'enc' or 'cond': add the ground truth as well
        # Row order per sample: conditioning image, sample, real image.
        tiles = np.stack(
            [_rows(fixed_cond_data_255),
             _rows(samples_255),
             _rows(fixed_real_data_255)],
            axis=1)
        imsaver.save_images(
            tiles.reshape((3 * BATCH_SIZE, IM_DIM, IM_DIM)),
            out_name, alternate_viz=True, conds=True, gt=True)

    print("Iteration %d :" % frame)

    # NOTE(review): the reference here is fixed_cond_data for every MODE -
    # confirm this is the intended comparison target.
    # NOTE(review): every call adds these ops to the default graph.
    real = tf.reshape(fixed_cond_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssim_vals = tf.image.ssim(real, pred, max_val=1.0)
    mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred), [1, 2])
    psnr_vals = tf.image.psnr(real, pred, max_val=1.0)

    # One graph execution for all three metrics; averages are taken on the
    # fetched arrays instead of re-running the graph per average.
    ssim_vals_list, mse_vals_list, psnr_vals_list = session.run(
        [ssim_vals, mse_vals, psnr_vals])
    ssim_avg = np.mean(ssim_vals_list)
    mse_avg = np.mean(mse_vals_list)
    psnr_avg = np.mean(psnr_vals_list)

    plotter.plot('SSIM avg', ssim_avg)
    plotter.plot('MSE avg', mse_avg)
    plotter.plot('PSNR avg', psnr_avg)

    if final:
        # Use the frame argument rather than a global loop variable.
        print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' %
              (frame, ssim_avg, mse_avg))
        print(ssim_vals_list)
        print(mse_vals_list)
        print(psnr_vals_list)
# Fragment of a training-loop body (the loop header and the end of the last
# statement lie outside this view).
# - Draws a batch from `gen` and slices it into: frame index 5 as the real
#   frame, frame index 4 as the condition, and the TIMESTEPS frames before
#   index 5 as the time input.
# - Runs one discriminator step together with the merged summary op, writes
#   the summary for `iteration`, and plots the discriminator cost and the
#   elapsed time since `start_time`.
# - Every 100 iterations (iteration % 100 == 99) draws a batch from
#   `dev_generator`, slices it the same way and starts a validation
#   `session.run` call, which is cut off at the end of this line.
_data = next(gen) # [0,1] # (50, 6, 4096) _real_data = (_data[:, 5, :]).reshape( (BATCH_SIZE, output_dim)) # current frame for now _cond_data = (_data[:, 4, :]).reshape( (BATCH_SIZE, output_dim)) # one last frame for now _time_data = (_data[:, (5 - TIMESTEPS):5, :]).reshape( (BATCH_SIZE, TIMESTEPS, output_dim)) # past frames _disc_cost, _, summary_str = session.run( [disc_cost, disc_train_op, merged_summary_op_d], feed_dict={ real_data: _real_data, condition_data: _cond_data, time_data: _time_data }) summary_writer.add_summary(summary_str, iteration) plotter.plot('train disc cost', _disc_cost) iteration_time = time.time() - start_time plotter.plot('time', iteration_time) # Validation: Calculate validation loss and generate samples every 100 iters if (iteration % 100 == 99): # or iteration < 10): dev_disc_costs = [] dev_vae_losses = [] _data = next(dev_generator) # [0,1] # (50, 6, 4096) _real_data = (_data[:, 5, :]).reshape( (BATCH_SIZE, output_dim)) # current frame for now _cond_data = (_data[:, 4, :]).reshape( (BATCH_SIZE, output_dim)) # one last frame for now _time_data = (_data[:, (5 - TIMESTEPS):5, :]).reshape( (BATCH_SIZE, TIMESTEPS, output_dim)) # past frames _dev_disc_cost, _dev_gen_cost = session.run(
# Initialise all variables, then alternate generator and discriminator
# updates from START_ITER (0, or the last checkpoint) up to ITERS.
session.run(init_op)

for iteration in range(START_ITER, ITERS):
    start_time = time.time()

    # Train generator (skipped on the very first iteration).
    if iteration > 0:
        _data = next(gen)  # shape: (batchsize, 6144)
        _cond_data = _data  # digit as cond
        _ = session.run(gen_train_op, feed_dict={cond_data: _cond_data})

    # Train discriminator; the same batch serves as real data and condition.
    for i in range(DISC_ITERS):
        _data = next(gen)  # shape: (batchsize, 6144)
        _cond_data = _data  # digit as cond
        _real_data = _data  # current frame for disc
        _disc_cost, _ = session.run(
            [disc_cost, disc_train_op],
            feed_dict={real_data: _real_data, cond_data: _cond_data})

    plotter.plot('train disc cost', _disc_cost)
    plotter.plot('time', time.time() - start_time)

    # Every 100 iterations: dev loss, sample images and a checkpoint.
    if iteration % 100 == 99:
        dev_disc_costs = []
        _data = next(dev_generator)  # shape: (batchsize, 6144)
        _cond_data = _data  # digit as cond
        _real_data = _data  # current frame for disc
        _dev_disc_cost = session.run(
            disc_cost,
            feed_dict={real_data: _real_data, cond_data: _cond_data})
        dev_disc_costs.append(_dev_disc_cost)
        plotter.plot('dev disc cost', np.mean(dev_disc_costs))

        generate_image(iteration, _data)

        # Save the variables to disk.
        save_path = saver.save(session, restore_path)
        print("Model saved in path: %s" % save_path)
def train(self, config):
    """Train the GAN, resuming from a loaded checkpoint when one exists.

    Builds the Adam optimizers for generator and critic, initialises the
    variables, loads saved parameters via ``self.load()`` and then runs
    ``config.num_iter`` iterations. At selected iterations it evaluates the
    critic on the dev set, saves the model, computes statistics of
    generated samples and updates the fitting-error figure.

    Args:
        config: Object providing the hyper-parameters read below
            (learning_rate, beta1, beta2, num_iter, critic_iters,
            batch_size, dataset, num_neurons, num_bins, sample_dir,
            data_instance).
    """
    # Define the optimizers.
    gen_adam = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2)
    self.g_optim = gen_adam.minimize(
        self.gen_cost,
        var_list=params_with_name('Generator'),
        colocate_gradients_with_ops=True)
    critic_adam = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2)
    self.d_optim = critic_adam.minimize(
        self.disc_cost,
        var_list=params_with_name('Discriminator.'),
        colocate_gradients_with_ops=True)

    tf.global_variables_initializer().run()

    # Try to load trained parameters; ckpt_name is added to the iteration
    # counter below as an offset.
    print('-------------')
    existing_gan, ckpt_name = self.load()

    # Count the number of trainable parameters.
    total_parameters = 0
    for variable in tf.trainable_variables():
        n_entries = 1
        for dim in variable.get_shape():  # entries are tf.Dimension
            n_entries *= dim.value
        total_parameters += n_entries
    print('-------------')
    print('number of variables: ' + str(total_parameters))
    print('-------------')

    # Training state.
    batch_idx = 0
    epoch = 0

    # Figure collecting the fitting errors over training.
    fig, axes = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
    matplotlib.rcParams.update({'font.size': 8})
    plt.subplots_adjust(left=left, bottom=bottom, right=right, top=top,
                        wspace=wspace, hspace=hspace)

    for iteration in range(config.num_iter):
        start_time = time.time()
        step = iteration + ckpt_name

        # Train generator (only after the critic has been trained once).
        if step > 0:
            _ = self.sess.run(self.g_optim)

        # Train critic.
        for _critic_step in range(config.critic_iters):
            lo = batch_idx * config.batch_size
            hi = (batch_idx + 1) * config.batch_size
            _data = self.training_samples[:, lo:hi].T
            _disc_cost, _ = self.sess.run(
                [self.disc_cost, self.d_optim],
                feed_dict={self.inputs: _data})
            # At the end of the sample set, start over and count an epoch.
            last_batch = int(
                self.training_samples.shape[1] / self.batch_size) - 1
            if batch_idx == last_batch:
                batch_idx = 0
                epoch += 1
            else:
                batch_idx += 1

        elapsed = time.time() - start_time
        # Plot the critic's loss and the iteration time.
        plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
        plot.plot(self.sample_dir, 'time', elapsed)

        if (step == 500 or iteration % 20000 == 19999
                or step >= config.num_iter - 10):
            print('epoch ' + str(epoch))
            if config.dataset == 'uniform' or config.dataset == 'packets':
                # Evaluate whether the discriminator has overfit.
                dev_disc_costs = []
                n_dev_batches = int(
                    self.dev_samples.shape[1] / self.batch_size)
                for ind_dev in range(n_dev_batches):
                    dev_lo = ind_dev * config.batch_size
                    dev_hi = (ind_dev + 1) * config.batch_size
                    images = self.dev_samples[:, dev_lo:dev_hi].T
                    _dev_disc_cost = self.sess.run(
                        self.disc_cost, feed_dict={self.inputs: images})
                    dev_disc_costs.append(_dev_disc_cost)
                # Plot the dev loss.
                plot.plot(self.sample_dir, 'dev disc cost',
                          -np.mean(dev_disc_costs))

            # Save the network parameters.
            self.save(step)

            # Statistics of simulated samples versus the original ones.
            fake_samples = self.sess.run([self.ex_samples])[0]
            (acf_error, mean_error, corr_error,
             time_course_error, _) = analysis.get_stats(
                X=fake_samples.T,
                num_neurons=config.num_neurons,
                num_bins=config.num_bins,
                folder=config.sample_dir,
                name='fake' + str(step),
                critic_cost=-_disc_cost,
                instance=config.data_instance)

            # Plot the fitting errors, one panel per statistic.
            x_limits = [
                0 - config.num_iter / 4,
                config.num_iter + config.num_iter / 4
            ]
            panels = (
                (axes[0][0], mean_error, 'spk-count mean error'),
                (axes[0][1], time_course_error, 'time course error'),
                (axes[1][0], acf_error, 'AC error'),
                (axes[1][1], corr_error, 'corr error'),
            )
            for panel, error, title in panels:
                panel.plot(step, error, '+b')
                panel.set_title(title)
                panel.set_xlabel('iterations')
                panel.set_ylabel('L1 error')
                panel.set_xlim(x_limits)

            fig.savefig(self.sample_dir + 'fitting_errors.svg',
                        dpi=600, bbox_inches='tight')
            plt.close(fig)
            plot.flush(self.sample_dir)

        plot.tick()
# Body of one training iteration: a generator/encoder step followed by a
# periodic validation pass. Relies on `iteration` from the enclosing loop.
start_time = time.time()

# Train generator (and encoder); skipped on the very first iteration.
if iteration > 0:
    _data = next(gen)  # [0,1]
    _real_data = _data[:, 1, :].reshape(
        (BATCH_SIZE, output_dim))  # frame index 1: current frame for now
    _cond_data = _data[:, 0, :].reshape(
        (BATCH_SIZE, output_dim))  # frame index 0: one last frame for now
    train_feed = {real_data: _real_data, condition_data: _cond_data}
    _, summary_str = session.run(
        [train_op, merged_summary_op_vae], feed_dict=train_feed)
    summary_writer.add_summary(summary_str, iteration)

iteration_time = time.time() - start_time
plotter.plot('time', iteration_time)

# Validation: calculate the validation loss every 100 iterations.
if iteration % 100 == 0:  # or iteration < 10
    dev_disc_costs = []
    dev_vae_losses = []
    _data = next(dev_generator)  # [0,1]
    _real_data = _data[:, 1, :].reshape(
        (BATCH_SIZE, output_dim))  # frame index 1: current frame for now
    _cond_data = _data[:, 0, :].reshape(
        (BATCH_SIZE, output_dim))  # frame index 0: one last frame for now
    dev_feed = {real_data: _real_data, condition_data: _cond_data}
    _dev_vae_loss = session.run(vae_loss, feed_dict=dev_feed)
# Body of one training iteration: generator steps, critic steps and a
# periodic inception-score evaluation. Relies on `iteration` and
# `start_time` from the enclosing loop.

# Train generator (skipped on the very first iteration).
if iteration > 0:
    for i in range(GEN_ITERS):  # range: runs on Python 2 and Python 3
        _ = session.run(gen_optim)

# Train critic.
for i in range(CRITIC_ITERS):
    if TRAIN_DETECTOR:
        # Train both critics on a mix of real and adversarial data.
        data, labels = data_gen.mix_real_adv(gen_adv, gen2)
        d1_cost_, d2_cost_, _, _ = session.run(
            [d1_cost, d2_cost, d1_optim, d2_optim],
            feed_dict={real_data_int: data, y: labels})
        plot.plot('train d2-bce cost', d2_cost_)
        plot.plot('time', time.time() - start_time)
    else:
        # next() instead of the Python-2-only generator method .next().
        data, labels = next(gen1)
        d1_cost_, _ = session.run(
            [d1_cost, d1_optim],
            feed_dict={real_data_int: data, y: labels})
        plot.plot('train d1-wgan cost', d1_cost_)
        plot.plot('time', time.time() - start_time)

# Calculate inception score every 1K iters.
if iteration % 1000 == 999:
    inception_score = test_inception()
def generate_image(frame, final):
    """Generate a batch for the current MODE, save it and report metrics.

    For MODE 'enc', 'vae' and 'cond' each sample is saved next to its
    reference image and batch averages of SSIM, MSE and PSNR are plotted;
    for the other modes only the samples are saved. MODE 'vae' additionally
    saves a grid of pure noise samples. On the final call the classifier
    accuracy on the generated samples is printed.

    Args:
        frame: Iteration number; used in file names and log lines.
        final: When truthy, print per-sample metrics and the accuracy.
    """
    inference_start_time = time.time()
    if MODE == 'cond_ordered':
        samples = session.run(
            fixed_noise_samples,
            feed_dict={condition_data: fixed_labels_array})  # [0, 1]
    elif MODE == 'cond':
        samples = session.run(
            fixed_noise_samples,
            feed_dict={condition_data: sorted_labels})  # [0, 1]
    elif MODE == 'plain':
        samples = session.run(fixed_noise_samples)  # [0, 1]
    else:
        samples = session.run(
            fixed_noise_samples,
            feed_dict={real_data: sorted_data})  # [0, 1]
    inference_time = time.time() - inference_start_time
    print("The architecture took ", inference_time,
          "sec for the generation of ", BATCH_SIZE, "images")

    samples_255 = (samples * 255.).astype('uint8')  # [0, 255]

    if MODE in ('enc', 'vae', 'cond'):
        # Row order per sample: reference image, sample. Built in a single
        # stack instead of BATCH_SIZE np.insert calls that each copy the
        # whole array; the cast mirrors np.insert's dtype conversion.
        reference_255 = np.asarray(fixed_data_255)[:BATCH_SIZE].astype(
            samples_255.dtype)
        tiles = np.stack(
            [reference_255.reshape(BATCH_SIZE, -1),
             samples_255.reshape(BATCH_SIZE, -1)],
            axis=1)
        imsaver.save_images(
            tiles.reshape((2 * BATCH_SIZE, IM_DIM, IM_DIM)),
            'samples_{}.jpg'.format(frame),
            alternate_viz=True, conds=True)
        print("Iteration %d :" % frame)

        # Compare generated images to the reference data.
        # NOTE(review): every call adds these ops to the default graph.
        real = tf.reshape(sorted_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
        pred = tf.reshape(samples, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
        ssim_vals = tf.image.ssim(real, pred, max_val=1.0)
        mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred), [1, 2])
        psnr_vals = tf.image.psnr(real, pred, max_val=1.0)

        # One graph execution for all three metrics; averages are taken on
        # the fetched arrays instead of re-running the graph per average.
        ssim_vals_list, mse_vals_list, psnr_vals_list = session.run(
            [ssim_vals, mse_vals, psnr_vals])
        ssim_avg = np.mean(ssim_vals_list)
        mse_avg = np.mean(mse_vals_list)
        psnr_avg = np.mean(psnr_vals_list)

        plotter.plot('SSIM avg', ssim_avg)
        plotter.plot('MSE avg', mse_avg)
        plotter.plot('PSNR avg', psnr_avg)

        if final:
            # Use the frame argument rather than a global loop variable.
            print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' %
                  (frame, ssim_avg, mse_avg))
            print(ssim_vals_list)
            print(mse_vals_list)
            print(psnr_vals_list)
    else:
        imsaver.save_images(
            samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
            'samples_{}.jpg'.format(frame))

    if MODE == 'vae':
        noise_samples = session.run(more_noise_samples)  # [0, 1]
        noise_samples_255 = (noise_samples * 255.).astype('uint8')
        imsaver.save_images(
            noise_samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
            'noise_samples_{}.jpg'.format(frame))

    if final:
        # Accuracy of the samples under the classifier: turn the integer
        # labels that belong to the current MODE into one-hot rows.
        if MODE == 'cond_ordered':
            labels = fixed_labels_array
        else:
            labels = sorted_labels
        _fixed_labels = np.zeros((labels.size, N_LABELS))
        _fixed_labels[np.arange(labels.size), labels] = 1
        accu = session.run(accuracy,
                           feed_dict={x: samples, y: _fixed_labels})
        # Use the frame argument rather than a global loop variable.
        print('Accuracy at step %d: %s' % (frame, accu))
# Fragment of a PyTorch training-loop body (the enclosing loops start before
# and continue after this view).
# - Turns off requires_grad on `p` (the loop over `p` is outside the view),
#   then performs one generator update: sample a (BATCH_SIZE, 128) normal
#   noise batch, pass it through netG and netD, back-propagate the mean
#   critic output with `mone`, and step optimizerG. G_cost is the negated
#   mean critic output.
# - Plots critic cost, elapsed time, generator cost and Wasserstein distance.
# - The inception-score branch is disabled by the literal `False` in its
#   condition.
# - Every 100 iterations (iteration % 100 == 99) starts a pass over
#   `dev_gen`; the loop body is cut off at the end of this line.
p.requires_grad = False # to avoid computation netG.zero_grad() noise = torch.randn(BATCH_SIZE, 128) if use_cuda: noise = noise.cuda(gpu) noisev = autograd.Variable(noise) fake = netG(noisev) G = netD(fake) G = G.mean() G.backward(mone) G_cost = -G optimizerG.step() # Write logs and save samples plot.plot('./tmp/cifar10/train disc cost', D_cost.cpu().data.numpy()) plot.plot('./tmp/cifar10/time', time.time() - start_time) plot.plot('./tmp/cifar10/train gen cost', G_cost.cpu().data.numpy()) plot.plot('./tmp/cifar10/wasserstein distance', Wasserstein_D.cpu().data.numpy()) # Calculate inception score every 1K iters if False and iteration % 1000 == 999: inception_score = get_inception_score(netG) plot.plot('./tmp/cifar10/inception score', inception_score[0]) # Calculate dev loss and generate samples every 100 iters if iteration % 100 == 99: dev_disc_costs = [] for images, _ in dev_gen: images = images.reshape(BATCH_SIZE, 3, 32, 32).permute(0, 2, 3, 1)
# Fragment of a WGAN training script (the final `if` has no body in this
# view, so the statement continues outside it).
# - Initialises all variables and creates the infinite training generator.
# - `start_time` is taken once before the loop, so the plotted 'time' value
#   is the time elapsed since training started, not per-iteration time.
# - Each iteration: one generator step (skipped at iteration 0), then
#   CRITIC_ITERS critic steps, then plots of the last critic cost and time.
# - During the first 5 iterations and every 200th one afterwards
#   (iteration % 200 == 199): averages the critic cost over `dev_gen()`,
#   generates sample images and saves the model.
session.run(tf.global_variables_initializer()) gen = inf_train_gen() start_time = time.time() for iteration in range(ITERS): # Train generator if iteration > 0: _ = session.run(gen_train_op) # Train critic disc_iters = CRITIC_ITERS for i in range(disc_iters): _data = next(gen) _disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={real_data: _data}) plot.plot(FOLDER, 'train disc cost', _disc_cost) plot.plot(FOLDER, 'time', time.time() - start_time) if (iteration < 5) or iteration % 200 == 199: t = time.time() dev_disc_costs = [] for (images, ) in dev_gen(): _dev_disc_cost = session.run(disc_cost, feed_dict={real_data: images}) dev_disc_costs.append(_dev_disc_cost) plot.plot(FOLDER, 'dev disc cost', np.mean(dev_disc_costs)) generate_image(iteration) save(iteration) if (iteration < 5) or (iteration % 200 == 199):
def train(self, config):
    """Train the GAN on simulated ('uniform') or retinal ('retina') data.

    Builds the Adam optimizers for generator and critic, initialises the
    variables, loads saved parameters via ``self.load()``, obtains the real
    samples for the configured dataset and then runs ``config.num_iter``
    iterations. At selected iterations it evaluates the critic on the dev
    set ('uniform' only), saves the model, computes statistics of generated
    samples and updates the fitting-error figure.

    Args:
        config: Object providing the hyper-parameters read below
            (learning_rate, beta1, beta2, dataset, firing_rate,
            correlation, group_size, num_samples, ref_period,
            data_instance, num_iter, critic_iters, batch_size,
            num_neurons, num_bins, sample_dir).
    """
    # Define the optimizers.
    self.g_optim = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2).minimize(
            self.gen_cost,
            var_list=params_with_name('Generator'),
            colocate_gradients_with_ops=True)
    self.d_optim = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2).minimize(
            self.disc_cost,
            var_list=params_with_name('Discriminator.'),
            colocate_gradients_with_ops=True)

    # Initialise variables. Fall back to the older initializer only when
    # the newer one is missing; a bare `except` here would also swallow
    # real initialisation errors and KeyboardInterrupt.
    try:
        tf.global_variables_initializer().run()
    except AttributeError:
        tf.initialize_all_variables().run()

    # Try to load trained parameters.
    self.load()

    # Get the real samples.
    if config.dataset == 'uniform':
        num_groups = int(self.num_neurons / config.group_size)
        # One firing rate and one correlation per group, jittered by up to
        # +/- 50% around the configured value.
        firing_rates_mat = config.firing_rate + 2 * (
            np.random.random(num_groups) - 0.5) * config.firing_rate / 2
        correlations_mat = config.correlation + 2 * (
            np.random.random(num_groups) - 0.5) * config.correlation / 2
        # Every neuron of a group shares the group's activity peak.
        activity_peaks = np.repeat(np.arange(num_groups), config.group_size)
        activity_peaks = (activity_peaks * config.group_size *
                          self.num_bins / self.num_neurons)
        activity_peaks = activity_peaks.reshape(self.num_neurons, 1)
        self.real_samples = sim_pop_activity.get_samples(
            num_samples=config.num_samples,
            num_bins=self.num_bins,
            num_neurons=self.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat,
            activity_peaks=activity_peaks)
        # Dev samples: a quarter as many, from the same parameters.
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=self.num_bins,
            num_neurons=self.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat,
            activity_peaks=activity_peaks)
        # Save the original statistics.
        analysis.get_stats(
            X=self.real_samples,
            num_neurons=self.num_neurons,
            num_bins=self.num_bins,
            folder=self.sample_dir,
            name='real',
            firing_rate_mat=firing_rates_mat,
            correlation_mat=correlations_mat,
            activity_peaks=activity_peaks)
    elif config.dataset == 'retina':
        self.real_samples = retinal_data.get_samples(
            num_bins=self.num_bins,
            num_neurons=self.num_neurons,
            instance=config.data_instance)
        # Save the original statistics.
        analysis.get_stats(
            X=self.real_samples,
            num_neurons=self.num_neurons,
            num_bins=self.num_bins,
            folder=self.sample_dir,
            name='real',
            instance=config.data_instance)

    # Count the number of trainable parameters.
    total_parameters = 0
    for variable in tf.trainable_variables():
        variable_parameters = 1
        for dim in variable.get_shape():  # entries are tf.Dimension
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('number of varaibles: ' + str(total_parameters))

    # Training state.
    counter_batch = 0
    epoch = 0

    # Figure collecting the fitting errors over training.
    f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
    matplotlib.rcParams.update({'font.size': 8})
    plt.subplots_adjust(left=left, bottom=bottom, right=right, top=top,
                        wspace=wspace, hspace=hspace)

    for iteration in range(config.num_iter):
        start_time = time.time()

        # Train generator (only after the critic has been trained once).
        if iteration > 0:
            _ = self.sess.run(self.g_optim)

        # Train critic.
        # NOTE(review): batches are sliced with config.batch_size but the
        # wrap-around check uses self.batch_size - confirm they are equal.
        for i in range(config.critic_iters):
            start = counter_batch * config.batch_size
            stop = (counter_batch + 1) * config.batch_size
            _data = self.real_samples[:, start:stop].T
            _disc_cost, _ = self.sess.run(
                [self.disc_cost, self.d_optim],
                feed_dict={self.inputs: _data})
            # At the end of the sample set, start over and count an epoch.
            if counter_batch == int(
                    self.real_samples.shape[1] / self.batch_size) - 1:
                counter_batch = 0
                epoch += 1
            else:
                counter_batch += 1

        aux = time.time() - start_time
        # Plot the critic's loss and the iteration time.
        plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
        plot.plot(self.sample_dir, 'time', aux)

        if (iteration == 500 or iteration % 20000 == 19999
                or iteration > config.num_iter - 10):
            print('epoch ' + str(epoch))
            if config.dataset == 'uniform':
                # Evaluate whether the discriminator has overfit.
                dev_disc_costs = []
                for ind_dev in range(
                        int(dev_samples.shape[1] / self.batch_size)):
                    dev_start = ind_dev * config.batch_size
                    dev_stop = (ind_dev + 1) * config.batch_size
                    images = dev_samples[:, dev_start:dev_stop].T
                    _dev_disc_cost = self.sess.run(
                        self.disc_cost, feed_dict={self.inputs: images})
                    dev_disc_costs.append(_dev_disc_cost)
                # Plot the dev loss.
                plot.plot(self.sample_dir, 'dev disc cost',
                          -np.mean(dev_disc_costs))

            # Save the network parameters.
            self.save(iteration)

            # Statistics of simulated samples versus the original ones.
            fake_samples = self.get_samples(num_samples=2**13)
            fake_samples = fake_samples.eval(session=self.sess)
            fake_samples = self.binarize(samples=fake_samples)
            (acf_error, mean_error, corr_error,
             time_course_error, _) = analysis.get_stats(
                X=fake_samples.T,
                num_neurons=config.num_neurons,
                num_bins=config.num_bins,
                folder=config.sample_dir,
                name='fake' + str(iteration),
                critic_cost=-_disc_cost,
                instance=config.data_instance)

            # Plot the fitting errors, one panel per statistic.
            x_limits = [
                0 - config.num_iter / 4,
                config.num_iter + config.num_iter / 4
            ]
            panels = (
                (sbplt[0][0], mean_error, 'spk-count mean error'),
                (sbplt[0][1], time_course_error, 'time course error'),
                (sbplt[1][0], acf_error, 'AC error'),
                (sbplt[1][1], corr_error, 'corr error'),
            )
            for panel, error, title in panels:
                panel.plot(iteration, error, '+b')
                panel.set_title(title)
                panel.set_xlabel('iterations')
                panel.set_ylabel('L1 error')
                panel.set_xlim(x_limits)

            f.savefig(self.sample_dir + 'fitting_errors.svg',
                      dpi=600, bbox_inches='tight')
            plt.close(f)
            plot.flush(self.sample_dir)

        plot.tick()