Example #1
0
def generate_image(frame, true_dist):   # generates a batch of samples next to each other in one image!
    """Save a sample grid for ``frame`` and log SSIM/MSE of the first samples.

    Runs ``fixed_noise_samples`` with the fixed real data and labels, writes
    ``samples_<frame>.jpg`` in which each generated sample is preceded by the
    corresponding row of ``fixed_real_data_255``, then compares generated and
    real images with SSIM and MSE on the [0, 1] scale.

    Args:
        frame: iteration number; used in the file name and the printed header.
        true_dist: unused; kept so existing call sites keep working.
    """
    samples = session.run(fixed_noise_samples,
                          feed_dict={real_data: fixed_real_data,
                                     condition_data: fixed_labels})  # [-1,1]
    samples_255 = ((samples + 1.) * (255. / 2)).astype('int32')  # [0,255]
    samples_01 = ((samples + 1.) / 2.).astype('float32')  # [0,1]

    # Interleave rows as real_0, gen_0, real_1, gen_1, ... -- the layout the
    # former per-row np.insert loop produced, but built in one pass instead of
    # re-allocating the whole array BATCH_SIZE times.  Casting to
    # samples_255.dtype mirrors np.insert, which casts inserted values to the
    # target array's dtype.
    row_shape = samples_255.shape[1:]
    real_rows = np.asarray(fixed_real_data_255[:BATCH_SIZE],
                           dtype=samples_255.dtype).reshape((BATCH_SIZE,) + row_shape)
    grid = np.stack([real_rows, samples_255], axis=1).reshape((-1,) + row_shape)
    imsaver.save_images(grid.reshape((2 * BATCH_SIZE, 1, IM_DIM, IM_DIM)),
                        'samples_{}.jpg'.format(frame))
    print("Iteration %d : \n" % frame)

    # compare generated to real one
    # NOTE(review): the TF ops below are added to the default graph on every
    # call, so the graph grows each time this runs; consider building them once.
    real = tf.reshape(fixed_real_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples_01, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssimval = tf.image.ssim(real, pred, max_val=1.0)  # one SSIM value per image
    mseval_per_entry = tf.keras.metrics.mse(real, pred)  # mse on grayscale, on [0,1]
    mseval = tf.reduce_mean(mseval_per_entry, [1, 2])  # one MSE value per image
    # NOTE(review): these summary ops are created on every call; nothing in
    # this function runs or writes them.
    tf.summary.tensor_summary("SSIM values", ssimval)
    tf.summary.tensor_summary("MSE values", mseval)
    ssimval_list = ssimval.eval()  # to numpy array
    mseval_list = mseval.eval()
    # Report at most the first three samples (no IndexError for tiny batches).
    for i in range(min(3, len(ssimval_list))):
        plotter.plot('SSIM for sample %d' % (i + 1), ssimval_list[i])
        plotter.plot('MSE for sample %d' % (i + 1), mseval_list[i])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (i, mseval_list[i], ssimval_list[i]))
Example #2
0
def generate_image(
    frame, true_dist
):  # generates a batch of samples next to each other in one image!
    """Save generated samples for ``frame`` and log SSIM/MSE of the first ones.

    Runs ``fixed_noise_samples`` fed with ``fixed_data``, writes
    ``samples_<frame>.png`` and compares the generated images with
    ``fixed_data`` using SSIM and MSE (single-channel images).

    Args:
        frame: iteration number; used in the file name and the printed header.
        true_dist: unused; kept so existing call sites keep working.
    """
    # test: generate some samples
    samples = session.run(fixed_noise_samples,
                          feed_dict={real_data: fixed_data})  # [-1,1]
    samples_255 = ((samples + 1.) * (255. / 2)).astype('int32')  # [0,255]
    samples_01 = ((samples + 1.) / 2.).astype('float32')  # [0,1]
    imsaver.save_images(samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
                        'samples_{}.png'.format(frame))
    print("Iteration %d : \n" % frame)

    # compare generated to real ones
    # NOTE(review): max_val=1.0 assumes fixed_data is already scaled to
    # [0, 1] like samples_01 -- TODO confirm.
    real = tf.reshape(fixed_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples_01, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssimval = tf.image.ssim(real, pred, max_val=1.0)  # one SSIM value per image
    mseval_per_entry = tf.keras.metrics.mse(real, pred)  # mse on grayscale, on [0,1]
    mseval = tf.reduce_mean(mseval_per_entry, [1, 2])  # one MSE value per image
    ssimval_list = ssimval.eval()  # to numpy array
    mseval_list = mseval.eval()
    # Report at most the first three samples (no IndexError for tiny batches).
    for i in range(min(3, len(ssimval_list))):
        plotter.plot('SSIM for sample %d' % (i + 1), ssimval_list[i])
        plotter.plot('MSE for sample %d' % (i + 1), mseval_list[i])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (i, mseval_list[i], ssimval_list[i]))
Example #3
0
def generate_image(
    frame, true_dist
):  # generates a batch of samples next to each other in one image!
    """Save an RGB sample grid for ``frame`` and log SSIM/MSE of the first samples.

    Each generated sample is shown after its conditioning frame
    (``fixed_cond_data_int``).  SSIM and MSE are computed on grayscale
    versions of the real and generated images, on the [0, 1] scale.

    Args:
        frame: iteration number; used in the file name and the printed header.
        true_dist: unused; kept so existing call sites keep working.
    """
    samples = session.run(fixed_noise_samples,
                          feed_dict={
                              real_data_int: fixed_real_data_int,
                              cond_data_int: fixed_cond_data_int
                          })  # [-1,1]
    samples_255 = ((samples + 1.) * (255. / 2)).astype('int32')  # [0,255]
    samples_01 = ((samples + 1.) / 2.).astype('float32')  # [0,1]

    # Interleave rows as cond_0, gen_0, cond_1, gen_1, ... (last frame next to
    # generated sample).  Same layout as the former per-row np.insert loop,
    # built in one pass; values are cast to samples_255.dtype as np.insert does.
    row_shape = samples_255.shape[1:]
    cond_rows = np.asarray(fixed_cond_data_int[:BATCH_SIZE],
                           dtype=samples_255.dtype).reshape((BATCH_SIZE,) + row_shape)
    grid = np.stack([cond_rows, samples_255], axis=1).reshape((-1,) + row_shape)
    # The flat image vectors are laid out channel-first: they are viewed as
    # (N, 3, IM_DIM, IM_DIM) for saving.
    imsaver.save_images(grid.reshape((2 * BATCH_SIZE, 3, IM_DIM, IM_DIM)),
                        'samples_{}.jpg'.format(frame))
    print("Iteration %d : \n" % frame)

    # compare generated to real one
    # The flat vectors are channel-first (see the reshape above), so view them
    # as NCHW and transpose to the NHWC layout tf.image expects.  Reshaping
    # straight to [B, IM, IM, 3] would scramble pixels across channels.
    # NOTE(review): assumes fixed_real_data_norm01 uses the same flattened
    # channel-first layout as the generator output -- TODO confirm.
    real = tf.transpose(
        tf.reshape(fixed_real_data_norm01, [BATCH_SIZE, 3, IM_DIM, IM_DIM]),
        [0, 2, 3, 1])
    real_gray = tf.image.rgb_to_grayscale(real)  # keeps dtype, [0,1]
    pred = tf.transpose(
        tf.reshape(samples_01, [BATCH_SIZE, 3, IM_DIM, IM_DIM]), [0, 2, 3, 1])
    pred_gray = tf.image.rgb_to_grayscale(pred)
    ssimval = tf.image.ssim(real_gray, pred_gray,
                            max_val=1.0)  # one SSIM value per image
    mseval_per_entry = tf.keras.metrics.mse(
        real_gray, pred_gray)  # mse on grayscale, on [0,1]
    mseval = tf.reduce_mean(mseval_per_entry, [1, 2])  # one MSE value per image
    ssimval_list = ssimval.eval()  # to numpy array
    mseval_list = mseval.eval()
    # Report at most the first three samples (no IndexError for tiny batches).
    for i in range(min(3, len(ssimval_list))):
        plotter.plot('SSIM for sample %d' % (i + 1), ssimval_list[i])
        plotter.plot('MSE for sample %d' % (i + 1), mseval_list[i])
        print("sample %d \t MSE: %.5f \t SSIM: %.5f \r\n" %
              (i, mseval_list[i], ssimval_list[i]))
def generate_image(
        frame, final
):  # generates a batch of samples next to each other in one image!
    """Save a conditioned sample grid for ``frame`` and log batch metrics.

    Each group of four grid rows shows, in order: ``fixed_cond_2_data_255``
    (time condition), ``fixed_cond_data_255`` (condition), the generated
    sample and ``fixed_real_data_255`` (ground truth).  Batch-average SSIM,
    MSE and PSNR are sent to ``plotter``.  When ``final`` is true the averages
    and the per-image metric arrays are printed as well.

    Args:
        frame: iteration number; used in the file name and printed messages.
        final: if true, also print the averages and per-image metric arrays.
    """
    samples = session.run(fixed_noise_samples,
                          feed_dict={
                              condition_data: fixed_cond_data,
                              time_data: fixed_time_data
                          })  # [0,1]
    samples_255 = (samples * 255.99).astype('uint8')  # [0,1] -> [0,255]

    # add ground truth: interleave rows as time-cond_i, cond_i, gen_i, real_i.
    # Same layout as the former per-row np.insert loop, built in one pass
    # instead of re-allocating the array 3 * BATCH_SIZE times.  Values are cast
    # to the samples' dtype exactly as np.insert would.
    row_shape = samples_255.shape[1:]

    def _rows(data):
        # First BATCH_SIZE rows of `data`, shaped and typed like samples_255.
        return np.asarray(data[:BATCH_SIZE], dtype=samples_255.dtype).reshape(
            (BATCH_SIZE,) + row_shape)

    grid = np.stack([
        _rows(fixed_cond_2_data_255),  # cond_time left of sample
        _rows(fixed_cond_data_255),  # cond left of sample
        samples_255,
        _rows(fixed_real_data_255),  # real right of sample
    ], axis=1).reshape((-1,) + row_shape)
    imsaver.save_images(grid.reshape((4 * BATCH_SIZE, IM_DIM, IM_DIM)),
                        'samples_{}.jpg'.format(frame),
                        alternate_viz=True,
                        conds=True,
                        gt=True,
                        time=True)

    print("Iteration %d :" % frame)
    # compare generated to real one
    # NOTE(review): the reference here is fixed_cond_data (the condition), not
    # the ground truth shown right of each sample -- confirm this is intended.
    real = tf.reshape(fixed_cond_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    pred = tf.reshape(samples, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
    ssim_vals = tf.image.ssim(real, pred, max_val=1.0)  # one value per image
    mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred),
                              [1, 2])  # mse on grayscale, on [0,1]
    psnr_vals = tf.image.psnr(real, pred, max_val=1.0)  # one value per image
    fetches = [
        tf.reduce_mean(ssim_vals),
        tf.reduce_mean(mse_vals),
        tf.reduce_mean(psnr_vals),
    ]
    if final:
        fetches += [ssim_vals, mse_vals, psnr_vals]
    # One run (in the same session that produced the samples) instead of one
    # .eval() per tensor, so the metric sub-graphs are not recomputed.
    results = session.run(fetches)
    ssim_avg, mse_avg, psnr_avg = results[:3]
    # show average of ssim, mse and psnr vals over whole batch
    plotter.plot('SSIM avg', ssim_avg)
    plotter.plot('MSE avg', mse_avg)
    plotter.plot('PSNR avg', psnr_avg)
    if final:
        # Use the frame argument rather than a global `iteration`.
        print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' %
              (frame, ssim_avg, mse_avg))
        ssim_vals_list, mse_vals_list, psnr_vals_list = results[3:]
        print(ssim_vals_list)  # save it in nohup.out
        print(mse_vals_list)
        print(psnr_vals_list)
def generate_image(frame, final): # generates a batch of samples next to each other in one image!
    """Generate samples for ``frame`` (feed depends on ``MODE``), save image grids and log metrics.

    NOTE(review): this definition appears corrupted.  From the
    ``_data = next(gen)`` line inside ``if(final):`` onwards, the body reads
    like a fragment of a training loop (discriminator update, validation every
    100 iterations) rather than image generation.  The last ``session.run(``
    call is cut off, so this function does not parse as-is.  Recover the
    original tail before relying on it.

    Args:
        frame: iteration number; used in file names and the "Iteration" print.
        final: only read in the 'cond'/'enc' branch; if true, also print the
            per-image SSIM/MSE/PSNR arrays.
    """
    inference_start_time = time.time()  # inference time analysis
    # NOTE(review): a MODE other than 'cond', 'enc', 'plain' or 'vae' leaves
    # `samples` unbound and fails at the samples_255 line below.
    if(MODE == 'cond' or MODE == 'enc'):
        samples = session.run(fixed_noise_samples, feed_dict={condition_data: fixed_cond_data}) # [0,1]
    elif(MODE == 'plain'):
        samples = session.run(fixed_noise_samples) # [0,1]
    elif(MODE == 'vae'):
        samples = session.run(fixed_noise_samples, feed_dict={real_data: fixed_cond_data}) # [0,1]
        samples_noise = session.run(_noise_samples, feed_dict={real_data: fixed_cond_data}) # [0,1]
        noise_samples_255 = ((samples_noise)*255.).astype('uint8') # [0, 1] -> [0,255]
    inference_end_time = time.time()  # inference time analysis
    inference_time = (inference_end_time - inference_start_time)
    print("The architecture took ", inference_time, "sec for the generation of ", BATCH_SIZE, "images")

    samples_255 = ((samples)*255.).astype('uint8') # [0, 1] -> [0,255]
    # print('samples 255') # print(samples_255.min()) # print(samples_255.max())

    # Save the image grid; its layout depends on MODE.
    if(MODE == 'plain'):
        imsaver.save_images(samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)), 'samples_{}.jpg'.format(frame), alternate_viz=True)
    elif(MODE == 'vae'):
        # Insert each condition row before its sample: cond_i, gen_i, ...
        for i in range(0, BATCH_SIZE):
            samples_255 = np.insert(samples_255, i*2, fixed_cond_data_255[i,:], axis=0) # real (cond digit) next to sample
        imsaver.save_images(samples_255.reshape((2*BATCH_SIZE, IM_DIM, IM_DIM)), 'samples_{}.jpg'.format(frame), alternate_viz=True, conds=True)
        imsaver.save_images(noise_samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)), 'noise_samples_{}.jpg'.format(frame), alternate_viz=True)
    else: # if(MODE == 'enc' or MODE == 'cond') add ground truth
        # Build groups of three rows: cond_i, gen_i, real_i.
        for i in range(0, BATCH_SIZE):
            samples_255 = np.insert(samples_255, i*3, fixed_cond_data_255[i,:], axis=0) # cond left of sample
            samples_255 = np.insert(samples_255, i*3+2, fixed_real_data_255[i,:], axis=0) # real right of sample
        imsaver.save_images(samples_255.reshape((3*BATCH_SIZE, IM_DIM, IM_DIM)), 'samples_{}.jpg'.format(frame), alternate_viz=True, conds=True, gt=True)

        print("Iteration %d :" % frame)
        # compare generated to real one
        # NOTE(review): the reference is fixed_cond_data (the condition), not
        # the ground truth placed right of each sample -- confirm intent.
        real = tf.reshape(fixed_cond_data, [BATCH_SIZE,IM_DIM,IM_DIM,1])
        pred = tf.reshape(samples, [BATCH_SIZE,IM_DIM,IM_DIM,1])
        ssim_vals = tf.image.ssim(real, pred, max_val=1.0) # on batch
        mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred), [1,2]) # mse on grayscale, on [0,1]
        psnr_vals = tf.image.psnr(real, pred, max_val=1.0) # on batch, out tf.float32
        ssim_avg = (tf.reduce_mean(ssim_vals)).eval()
        mse_avg = (tf.reduce_mean(mse_vals)).eval()
        psnr_avg = (tf.reduce_mean(psnr_vals)).eval()
        plotter.plot('SSIM avg', ssim_avg) # show average of ssim and mse vals over whole batch
        plotter.plot('MSE avg', mse_avg)
        plotter.plot('PSNR avg', psnr_avg)
        if(final):
            # NOTE(review): reads a global `iteration`; `frame` is the value
            # passed in for this call.
            print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' % (iteration, ssim_avg, mse_avg))
            ssim_vals_list = ssim_vals.eval()
            mse_vals_list = mse_vals.eval()
            psnr_vals_list = psnr_vals.eval()
            print(ssim_vals_list) # save it in nohup.out
            print(mse_vals_list)
            print(psnr_vals_list)
            # NOTE(review): everything from here on looks spliced in from a
            # training loop (it reads gen, disc_train_op, start_time, ...) and
            # does not belong to image generation.
            _data = next(gen)  # [0,1] # (50, 6, 4096)
            _real_data = (_data[:, 5, :]).reshape(
                (BATCH_SIZE, output_dim))  # current frame for now
            _cond_data = (_data[:, 4, :]).reshape(
                (BATCH_SIZE, output_dim))  # one last frame for now
            _time_data = (_data[:, (5 - TIMESTEPS):5, :]).reshape(
                (BATCH_SIZE, TIMESTEPS, output_dim))  # past frames
            _disc_cost, _, summary_str = session.run(
                [disc_cost, disc_train_op, merged_summary_op_d],
                feed_dict={
                    real_data: _real_data,
                    condition_data: _cond_data,
                    time_data: _time_data
                })
        summary_writer.add_summary(summary_str, iteration)
        plotter.plot('train disc cost', _disc_cost)
        iteration_time = time.time() - start_time
        plotter.plot('time', iteration_time)

        # Validation: Calculate validation loss and generate samples every 100 iters
        if (iteration % 100 == 99):  # or iteration < 10):
            dev_disc_costs = []
            dev_vae_losses = []
            _data = next(dev_generator)  # [0,1] # (50, 6, 4096)
            _real_data = (_data[:, 5, :]).reshape(
                (BATCH_SIZE, output_dim))  # current frame for now
            _cond_data = (_data[:, 4, :]).reshape(
                (BATCH_SIZE, output_dim))  # one last frame for now
            _time_data = (_data[:, (5 - TIMESTEPS):5, :]).reshape(
                (BATCH_SIZE, TIMESTEPS, output_dim))  # past frames
            # NOTE(review): this call is truncated in the source.
            _dev_disc_cost, _dev_gen_cost = session.run(
Example #7
0
         session.run(init_op)	

    for iteration in range(START_ITER, ITERS):  # START_ITER: 0 or from last checkpoint
        start_time = time.time()
        # Train generator
        if iteration > 0:
            _data = next(gen)  # shape: (batchsize, 6144)
            _cond_data = _data # digit as cond
            _ = session.run(gen_train_op, feed_dict={cond_data: _cond_data})
        # Train duscriminator
        for i in range(DISC_ITERS):
            _data = next(gen)  # shape: (batchsize, 6144)
            _cond_data = _data # digit as cond
            _real_data = _data # current frame for disc
            _disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={real_data: _real_data, cond_data: _cond_data})
        plotter.plot('train disc cost', _disc_cost)
        plotter.plot('time', time.time() - start_time)

        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            _data = next(dev_generator)  # shape: (batchsize, 6144)
            _cond_data = _data # digit as cond
            _real_data = _data # current frame for disc
            _dev_disc_cost = session.run(disc_cost, feed_dict={real_data: _real_data, cond_data: _cond_data})
            dev_disc_costs.append(_dev_disc_cost)
            plotter.plot('dev disc cost', np.mean(dev_disc_costs))
            generate_image(iteration, _data)
            save_path = saver.save(session, restore_path) # Save the variables to disk.
            print("Model saved in path: %s" % save_path)
            # chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)
Example #8
0
    def train(self, config):
        """Train DCGAN.

        Builds Adam optimizers for the generator and critic, initializes the
        variables, restores parameters via ``self.load()`` and then runs
        ``config.num_iter`` iterations.

        Each iteration performs one generator update, skipped while
        ``iteration + ckpt_name`` is 0.  It then performs
        ``config.critic_iters`` critic updates on consecutive column slices of
        ``self.training_samples``.

        At selected iterations it:

        * evaluates the critic on ``self.dev_samples`` ('uniform'/'packets'
          datasets only);
        * saves a checkpoint;
        * computes statistics of generated samples with ``analysis.get_stats``;
        * updates ``fitting_errors.svg`` in ``self.sample_dir``.

        Args:
            config: run configuration; fields read here are learning_rate,
                beta1, beta2, num_iter, critic_iters, batch_size, dataset,
                num_neurons, num_bins, sample_dir and data_instance.
        """
        # define optimizers (same Adam hyper-parameters for both networks)
        adam_kwargs = dict(learning_rate=config.learning_rate,
                           beta1=config.beta1,
                           beta2=config.beta2)
        self.g_optim = tf.train.AdamOptimizer(**adam_kwargs).minimize(
            self.gen_cost,
            var_list=params_with_name('Generator'),
            colocate_gradients_with_ops=True)
        self.d_optim = tf.train.AdamOptimizer(**adam_kwargs).minimize(
            self.disc_cost,
            var_list=params_with_name('Discriminator.'),
            colocate_gradients_with_ops=True)

        tf.global_variables_initializer().run()

        # try to load trained parameters.  ckpt_name is added to `iteration`
        # below, so it is presumably the iteration count of the restored
        # checkpoint (0 when starting fresh) -- TODO confirm against self.load().
        print('-------------')
        existing_gan, ckpt_name = self.load()

        # count number of trainable parameters
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():  # tf.Dimension objects
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('-------------')
        print('number of variables: ' + str(total_parameters))
        print('-------------')

        # One batch size for slicing AND for counting batches.  Previously the
        # wrap-around / dev batch counts used self.batch_size while slicing
        # used config.batch_size, which disagree if the two ever differ.
        batch_size = config.batch_size
        num_train_batches = self.training_samples.shape[1] // batch_size

        # start training
        counter_batch = 0
        epoch = 0
        # fitting errors figure: one panel per statistic
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        # left/bottom/right/top/wspace/hspace are not defined in this method;
        # presumably module-level layout constants -- verify.
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        for iteration in range(config.num_iter):
            start_time = time.time()
            global_iter = iteration + ckpt_name
            # Train generator (only after the critic has been trained, at least once)
            if global_iter > 0:
                self.sess.run(self.g_optim)

            # Train critic on consecutive column slices of the training set
            for _critic_step in range(config.critic_iters):
                start = counter_batch * batch_size
                _data = self.training_samples[:, start:start + batch_size].T
                _disc_cost, _ = self.sess.run([self.disc_cost, self.d_optim],
                                              feed_dict={self.inputs: _data})
                # at the end of the real samples set, start over and count an epoch
                if counter_batch == num_train_batches - 1:
                    counter_batch = 0
                    epoch += 1
                else:
                    counter_batch += 1
            iteration_time = time.time() - start_time
            # plot the (negated) critic loss and the iteration time
            plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
            plot.plot(self.sample_dir, 'time', iteration_time)

            if (global_iter == 500 or iteration % 20000 == 19999
                    or global_iter >= config.num_iter - 10):
                print('epoch ' + str(epoch))
                if config.dataset in ('uniform', 'packets'):
                    # this is to evaluate whether the discriminator has overfit
                    dev_disc_costs = []
                    for ind_dev in range(self.dev_samples.shape[1] // batch_size):
                        start = ind_dev * batch_size
                        images = self.dev_samples[:, start:start + batch_size].T
                        dev_disc_costs.append(self.sess.run(
                            self.disc_cost, feed_dict={self.inputs: images}))
                    # plot the dev loss
                    plot.plot(self.sample_dir, 'dev disc cost',
                              -np.mean(dev_disc_costs))

                # save the network parameters
                self.save(global_iter)

                # get simulated samples, calculate their statistics and compare
                # them with the original ones
                fake_samples = self.sess.run(self.ex_samples)
                acf_error, mean_error, corr_error, time_course_error, _ = analysis.get_stats(
                    X=fake_samples.T,
                    num_neurons=config.num_neurons,
                    num_bins=config.num_bins,
                    folder=config.sample_dir,
                    name='fake' + str(global_iter),
                    critic_cost=-_disc_cost,
                    instance=config.data_instance)

                # plot the fitting errors, one panel per statistic
                x_margin = config.num_iter / 4
                panels = (
                    (sbplt[0][0], mean_error, 'spk-count mean error'),
                    (sbplt[0][1], time_course_error, 'time course error'),
                    (sbplt[1][0], acf_error, 'AC error'),
                    (sbplt[1][1], corr_error, 'corr error'),
                )
                for ax, error, title in panels:
                    ax.plot(global_iter, error, '+b')
                    ax.set_title(title)
                    ax.set_xlabel('iterations')
                    ax.set_ylabel('L1 error')
                    ax.set_xlim([0 - x_margin, config.num_iter + x_margin])
                f.savefig(self.sample_dir + 'fitting_errors.svg',
                          dpi=600,
                          bbox_inches='tight')
                # NOTE(review): f is closed here but plotted into and saved
                # again at later evaluation points.
                plt.close(f)
                plot.flush(self.sample_dir)

            plot.tick()
Example #9
0
        start_time = time.time()
        # Train generator (and Encoder)
        if (iteration > 0):
            _data = next(gen)  # [0,1]
            _real_data = (_data[:, 1, :]).reshape(
                (BATCH_SIZE, output_dim))  # current frame for now
            _cond_data = (_data[:, 0, :]).reshape(
                (BATCH_SIZE, output_dim))  # one last frame for now
            _, summary_str = session.run([train_op, merged_summary_op_vae],
                                         feed_dict={
                                             real_data: _real_data,
                                             condition_data: _cond_data
                                         })
            summary_writer.add_summary(summary_str, iteration)
        iteration_time = time.time() - start_time
        plotter.plot('time', iteration_time)

        # Validation: Calculate validation loss and generate samples every 1000 iters
        if (iteration % 100 == 0):  # or iteration < 10):
            dev_disc_costs = []
            dev_vae_losses = []
            _data = next(dev_generator)  # [0,1]
            _real_data = (_data[:, 1, :]).reshape(
                (BATCH_SIZE, output_dim))  # current frame for now
            _cond_data = (_data[:, 0, :]).reshape(
                (BATCH_SIZE, output_dim))  # one last frame for now
            _dev_vae_loss = session.run(vae_loss,
                                        feed_dict={
                                            real_data: _real_data,
                                            condition_data: _cond_data
                                        })
Example #10
0
        # Train generator
        if iteration > 0:
            for i in xrange(GEN_ITERS):
                _ = session.run(gen_optim)

        # Train critic
        for i in xrange(CRITIC_ITERS):
            if TRAIN_DETECTOR:
                data, labels = data_gen.mix_real_adv(gen_adv, gen2)
                d1_cost_, d2_cost_, _, _ = session.run(
                    [d1_cost, d2_cost, d1_optim, d2_optim],
                    feed_dict={
                        real_data_int: data,
                        y: labels
                    })
                plot.plot('train d2-bce cost', d2_cost_)
                plot.plot('time', time.time() - start_time)
            else:
                data, labels = gen1.next()
                d1_cost_, _ = session.run([d1_cost, d1_optim],
                                          feed_dict={
                                              real_data_int: data,
                                              y: labels
                                          })

            plot.plot('train d1-wgan cost', d1_cost_)
            plot.plot('time', time.time() - start_time)

        # Calculate inception score every 1K iters
        if iteration % 1000 == 999:
            inception_score = test_inception()
Example #11
0
def generate_image(
        frame, final
):  # generates a batch of samples next to each other in one image!
    """Generate samples for ``frame``, save them and log quality metrics.

    The feed depends on ``MODE``:

    * ``'cond_ordered'`` and ``'cond'`` condition on labels;
    * ``'plain'`` needs no feed;
    * any other mode feeds ``sorted_data`` as ``real_data``.

    For ``'enc'``, ``'vae'`` and ``'cond'``, each sample is shown after its
    row of ``fixed_data_255``, and batch-average SSIM/MSE/PSNR against
    ``sorted_data`` are plotted.  In ``'vae'`` mode an extra
    ``noise_samples_<frame>.jpg`` is saved.  When ``final`` is true, the
    per-image metrics (where computed) and the classifier accuracy of the
    samples are printed.

    Args:
        frame: iteration number; used in file names and printed messages.
        final: if true, print per-image metrics and classifier accuracy.
    """
    inference_start_time = time.time()  # inference time analysis
    if MODE == 'cond_ordered':
        feed = {condition_data: fixed_labels_array}
    elif MODE == 'cond':
        feed = {condition_data: sorted_labels}
    elif MODE == 'plain':
        feed = None
    else:
        feed = {real_data: sorted_data}
    samples = session.run(fixed_noise_samples, feed_dict=feed)  # [0,1]
    inference_time = time.time() - inference_start_time
    print("The architecture took ", inference_time,
          "sec for the generation of ", BATCH_SIZE, "images")

    samples_255 = (samples * 255.).astype('uint8')  # [0,255]

    if MODE in ('enc', 'vae', 'cond'):
        # Interleave rows as real_i, gen_i (real / cond digit next to sample).
        # Same layout as the former per-row np.insert loop, built in one pass;
        # values are cast to samples_255.dtype as np.insert does.
        row_shape = samples_255.shape[1:]
        real_rows = np.asarray(fixed_data_255[:BATCH_SIZE],
                               dtype=samples_255.dtype).reshape((BATCH_SIZE,) + row_shape)
        grid = np.stack([real_rows, samples_255], axis=1).reshape((-1,) + row_shape)
        imsaver.save_images(grid.reshape((2 * BATCH_SIZE, IM_DIM, IM_DIM)),
                            'samples_{}.jpg'.format(frame),
                            alternate_viz=True,
                            conds=True)
        print("Iteration %d :" % frame)
        # compare generated to real one
        real = tf.reshape(sorted_data, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
        pred = tf.reshape(samples, [BATCH_SIZE, IM_DIM, IM_DIM, 1])
        ssim_vals = tf.image.ssim(real, pred, max_val=1.0)  # one value per image
        mse_vals = tf.reduce_mean(tf.keras.metrics.mse(real, pred),
                                  [1, 2])  # mse on grayscale, on [0,1]
        psnr_vals = tf.image.psnr(real, pred, max_val=1.0)  # one value per image
        fetches = [
            tf.reduce_mean(ssim_vals),
            tf.reduce_mean(mse_vals),
            tf.reduce_mean(psnr_vals),
        ]
        if final:
            fetches += [ssim_vals, mse_vals, psnr_vals]
        # One run instead of one .eval() per tensor, so the metric sub-graphs
        # are not recomputed for the per-image arrays.
        results = session.run(fetches)
        ssim_avg, mse_avg, psnr_avg = results[:3]
        # show average of ssim, mse and psnr vals over whole batch
        plotter.plot('SSIM avg', ssim_avg)
        plotter.plot('MSE avg', mse_avg)
        plotter.plot('PSNR avg', psnr_avg)
        if final:
            # Use the frame argument rather than a global `iteration`.
            print('final iteration %d SSIM avg: %.2f MSE avg: %.2f' %
                  (frame, ssim_avg, mse_avg))
            ssim_vals_list, mse_vals_list, psnr_vals_list = results[3:]
            print(ssim_vals_list)  # save it in nohup.out
            print(mse_vals_list)
            print(psnr_vals_list)
    else:
        imsaver.save_images(samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
                            'samples_{}.jpg'.format(frame))
    if MODE == 'vae':
        noise_samples = session.run(more_noise_samples)  # [0,1]
        noise_samples_255 = (noise_samples * 255.).astype('uint8')  # [0,255]
        imsaver.save_images(
            noise_samples_255.reshape((BATCH_SIZE, IM_DIM, IM_DIM)),
            'noise_samples_{}.jpg'.format(frame))
    if final:  # calculate accuracy of samples (mnist classificator)
        labels = fixed_labels_array if MODE == 'cond_ordered' else sorted_labels
        one_hot_labels = np.zeros((labels.size, N_LABELS))
        one_hot_labels[np.arange(labels.size), labels] = 1  # np to one-hot
        accu = session.run(accuracy, feed_dict={x: samples, y: one_hot_labels})
        print('Accuracy at step %d: %s' % (frame, accu))
Example #12
0
        p.requires_grad = False  # to avoid computation
    netG.zero_grad()

    noise = torch.randn(BATCH_SIZE, 128)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise)
    fake = netG(noisev)
    G = netD(fake)
    G = G.mean()
    G.backward(mone)
    G_cost = -G
    optimizerG.step()

    # Write logs and save samples
    plot.plot('./tmp/cifar10/train disc cost', D_cost.cpu().data.numpy())
    plot.plot('./tmp/cifar10/time', time.time() - start_time)
    plot.plot('./tmp/cifar10/train gen cost', G_cost.cpu().data.numpy())
    plot.plot('./tmp/cifar10/wasserstein distance',
              Wasserstein_D.cpu().data.numpy())

    # Calculate inception score every 1K iters
    if False and iteration % 1000 == 999:
        inception_score = get_inception_score(netG)
        plot.plot('./tmp/cifar10/inception score', inception_score[0])

    # Calculate dev loss and generate samples every 100 iters
    if iteration % 100 == 99:
        dev_disc_costs = []
        for images, _ in dev_gen:
            images = images.reshape(BATCH_SIZE, 3, 32, 32).permute(0, 2, 3, 1)
Example #13
0
    session.run(tf.global_variables_initializer())
    gen = inf_train_gen()
    start_time = time.time()
    for iteration in range(ITERS):
        # Train generator
        if iteration > 0:
            _ = session.run(gen_train_op)

        # Train critic
        disc_iters = CRITIC_ITERS
        for i in range(disc_iters):
            _data = next(gen)
            _disc_cost, _ = session.run([disc_cost, disc_train_op],
                                        feed_dict={real_data: _data})

        plot.plot(FOLDER, 'train disc cost', _disc_cost)
        plot.plot(FOLDER, 'time', time.time() - start_time)

        if (iteration < 5) or iteration % 200 == 199:
            t = time.time()
            dev_disc_costs = []
            for (images, ) in dev_gen():
                _dev_disc_cost = session.run(disc_cost,
                                             feed_dict={real_data: images})
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(FOLDER, 'dev disc cost', np.mean(dev_disc_costs))

            generate_image(iteration)
            save(iteration)

        if (iteration < 5) or (iteration % 200 == 199):
Example #14
0
    def train(self, config):
        """Train the GAN by alternating generator and critic updates.

        Builds Adam optimizers for the generator and the critic, initializes
        the TF variables, tries to restore previously trained parameters and
        obtains the real training samples for ``config.dataset``. It then runs
        ``config.num_iter`` iterations. Each iteration performs:

        * one generator step, skipped on the very first iteration so the
          critic is trained at least once first;
        * ``config.critic_iters`` critic steps on consecutive minibatches of
          the real samples. ``self.real_samples`` holds one sample per column.

        At iteration 500, at every iteration with ``iteration % 20000 == 19999``
        and during the last 9 iterations it additionally:

        * evaluates the critic on held-out dev samples (``'uniform'`` dataset
          only);
        * saves the network;
        * draws ``2**13`` binarized fake samples and compares their statistics
          with the real ones;
        * updates ``fitting_errors.svg`` in ``self.sample_dir``.

        Args:
            config: run configuration. Attributes read here: learning_rate,
                beta1, beta2, dataset, num_iter, critic_iters, batch_size,
                num_neurons, num_bins, sample_dir and data_instance. For the
                'uniform' dataset it also reads firing_rate, correlation,
                group_size, ref_period and num_samples.
        """
        # Define the optimizers (generator and critic updated separately).
        self.g_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.gen_cost,
                var_list=params_with_name('Generator'),
                colocate_gradients_with_ops=True)
        self.d_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.disc_cost,
                var_list=params_with_name('Discriminator.'),
                colocate_gradients_with_ops=True)

        # Initialize variables. Old TF releases only provide the deprecated
        # initialize_all_variables(); fall back to it only when the newer
        # function is missing instead of swallowing every error.
        try:
            init_op = tf.global_variables_initializer()
        except AttributeError:
            init_op = tf.initialize_all_variables()
        init_op.run()

        # Try to load previously trained parameters.
        self.load()

        # Get the real samples.
        if config.dataset == 'uniform':
            num_groups = int(self.num_neurons / config.group_size)
            # Per-group firing rates and correlations, drawn uniformly within
            # +/-50% of the configured values.
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups) - 0.5) * config.correlation / 2
            # All neurons of group k share the peak bin
            # k * group_size * num_bins / num_neurons, so the groups' peaks
            # are evenly spaced over the bins.
            activity_peaks = np.repeat(np.arange(num_groups), config.group_size)
            activity_peaks = (activity_peaks * config.group_size *
                              self.num_bins / self.num_neurons)
            activity_peaks = activity_peaks.reshape(self.num_neurons, 1)
            # Simulation parameters shared by the training and dev sets.
            sim_kwargs = dict(num_bins=self.num_bins,
                              num_neurons=self.num_neurons,
                              correlations_mat=correlations_mat,
                              group_size=config.group_size,
                              refr_per=config.ref_period,
                              firing_rates_mat=firing_rates_mat,
                              activity_peaks=activity_peaks)
            self.real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples, **sim_kwargs)
            # Held-out samples, used below to check whether the critic overfits.
            dev_samples = sim_pop_activity.get_samples(
                num_samples=int(config.num_samples / 4), **sim_kwargs)
            # Save the original statistics.
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               firing_rate_mat=firing_rates_mat,
                               correlation_mat=correlations_mat,
                               activity_peaks=activity_peaks)
        elif config.dataset == 'retina':
            self.real_samples = retinal_data.get_samples(
                num_bins=self.num_bins,
                num_neurons=self.num_neurons,
                instance=config.data_instance)
            # Save the original statistics.
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               instance=config.data_instance)

        # Count the trainable parameters (each shape entry is a tf.Dimension).
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('number of variables: ' + str(total_parameters))

        # Start training.
        counter_batch = 0  # index of the next real-data minibatch
        epoch = 0  # completed passes over self.real_samples
        # Figure collecting the fitting errors across checkpoints. The layout
        # values (left, bottom, ...) are names defined outside this method.
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        # Shared x-range of the error panels: the run length plus 25% margins.
        err_xlim = [0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4]
        for iteration in range(config.num_iter):
            start_time = time.time()
            # Train the generator (only after the critic has been trained at
            # least once).
            if iteration > 0:
                _ = self.sess.run(self.g_optim)

            # Train the critic on consecutive minibatches of real samples.
            for _ in range(config.critic_iters):
                _data = self.real_samples[:, counter_batch * config.batch_size:
                                          (counter_batch + 1) * config.batch_size].T
                _disc_cost, _ = self.sess.run([self.disc_cost, self.d_optim],
                                              feed_dict={self.inputs: _data})
                # At the end of the real-sample set, start over and count an
                # epoch.
                # NOTE(review): the slice above uses config.batch_size, while
                # this bound uses self.batch_size (kept as in the original);
                # presumably they are equal. Confirm.
                if counter_batch == int(
                        self.real_samples.shape[1] / self.batch_size) - 1:
                    counter_batch = 0
                    epoch += 1
                else:
                    counter_batch += 1
            iter_time = time.time() - start_time
            # Plot the (negated) critic loss and the iteration time.
            plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
            plot.plot(self.sample_dir, 'time', iter_time)

            if (iteration == 500 or iteration % 20000 == 19999
                    or iteration > config.num_iter - 10):
                print('epoch ' + str(epoch))
                if config.dataset == 'uniform':
                    # Critic cost on held-out data, to detect overfitting.
                    dev_disc_costs = []
                    for ind_dev in range(
                            int(dev_samples.shape[1] / self.batch_size)):
                        images = dev_samples[:, ind_dev * config.batch_size:
                                             (ind_dev + 1) * config.batch_size].T
                        dev_disc_costs.append(self.sess.run(
                            self.disc_cost, feed_dict={self.inputs: images}))
                    plot.plot(self.sample_dir, 'dev disc cost',
                              -np.mean(dev_disc_costs))

                # Save the network parameters.
                self.save(iteration)

                # Generate fake samples, binarize them and compare their
                # statistics with those of the real data.
                fake_samples = self.get_samples(num_samples=2**13)
                fake_samples = fake_samples.eval(session=self.sess)
                fake_samples = self.binarize(samples=fake_samples)
                acf_error, mean_error, corr_error, time_course_error, _ = \
                    analysis.get_stats(X=fake_samples.T,
                                       num_neurons=config.num_neurons,
                                       num_bins=config.num_bins,
                                       folder=config.sample_dir,
                                       name='fake' + str(iteration),
                                       critic_cost=-_disc_cost,
                                       instance=config.data_instance)
                # Add this checkpoint's fitting errors to the figure.
                panels = ((sbplt[0][0], mean_error, 'spk-count mean error'),
                          (sbplt[0][1], time_course_error, 'time course error'),
                          (sbplt[1][0], acf_error, 'AC error'),
                          (sbplt[1][1], corr_error, 'corr error'))
                for ax, error, title in panels:
                    ax.plot(iteration, error, '+b')
                    ax.set_title(title)
                    ax.set_xlabel('iterations')
                    ax.set_ylabel('L1 error')
                    ax.set_xlim(err_xlim)
                f.savefig(self.sample_dir + 'fitting_errors.svg',
                          dpi=600,
                          bbox_inches='tight')
                # NOTE(review): the figure is closed here but drawn on and saved
                # again at the next checkpoint (as in the original). Confirm the
                # backend still renders a closed figure.
                plt.close(f)
                plot.flush(self.sample_dir)

            plot.tick()