# Example #1
# 0
    def train(self, config):
        """Train the GAN by alternating generator and critic updates.

        Builds Adam optimizers for the generator and the critic, initializes
        all variables, tries to restore a checkpoint via ``self.load()`` and
        runs ``config.num_iter`` iterations. Each iteration does one generator
        step (skipped at global iteration 0, so the critic is trained at least
        once first), then ``config.critic_iters`` critic steps on consecutive
        mini-batches of ``self.training_samples``. Columns are samples. The
        batch index wraps around at the end of the set, and an epoch is
        counted each time it does.

        Evaluation happens at global iteration 500, whenever
        ``iteration % 20000 == 19999``, and whenever the global iteration is
        ``>= config.num_iter - 10``. An evaluation:
          * logs the critic loss on ``self.dev_samples`` ('uniform'/'packets'),
          * checkpoints the model,
          * compares statistics of generated samples with the real ones,
          * saves the fitting-error figure to ``self.sample_dir +
            'fitting_errors.svg'``.

        Args:
            config: run configuration (learning_rate, beta1, beta2, num_iter,
                critic_iters, dataset, num_neurons, num_bins, sample_dir,
                data_instance, ...).

        Raises:
            ValueError: if ``config.critic_iters < 1`` (no critic loss would
                ever be computed) or the training set holds fewer samples than
                one batch.
        """
        # One batch size is used for slicing, epoch bookkeeping and the dev
        # loop, so they cannot disagree. self.batch_size is presumably the
        # value the model graph was built with -- TODO confirm.
        batch_size = self.batch_size
        num_train_batches = int(self.training_samples.shape[1] / batch_size)
        if num_train_batches < 1:
            raise ValueError(
                'training set has %d samples, fewer than one batch of %d' %
                (self.training_samples.shape[1], batch_size))
        if config.critic_iters < 1:
            raise ValueError('config.critic_iters must be >= 1, got %r' %
                             (config.critic_iters,))

        # Adam optimizers; each one only updates its own network's variables,
        # selected by name through params_with_name.
        self.g_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.gen_cost,
                var_list=params_with_name('Generator'),
                colocate_gradients_with_ops=True)
        self.d_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.disc_cost,
                var_list=params_with_name('Discriminator.'),
                colocate_gradients_with_ops=True)

        tf.global_variables_initializer().run()

        # Try to restore trained parameters. ckpt_name is added to the loop
        # index below, so it acts as the global-iteration offset of a resumed
        # run (presumably 0 when nothing was loaded -- TODO confirm).
        print('-------------')
        existing_gan, ckpt_name = self.load()

        # Count the trainable parameters (product of each variable's dims).
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():  # tf.Dimension objects
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('-------------')
        print('number of variables: ' + str(total_parameters))
        print('-------------')

        counter_batch = 0
        epoch = 0
        # 2x2 figure tracking the fitting errors at every evaluation point.
        # left/bottom/right/top/wspace/hspace are defined outside this method.
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        x_limits = [
            0 - config.num_iter / 4, config.num_iter + config.num_iter / 4
        ]

        for iteration in range(config.num_iter):
            global_iteration = iteration + ckpt_name
            start_time = time.time()
            # Train the generator (only after the critic has been trained at
            # least once).
            if global_iteration > 0:
                self.sess.run(self.g_optim)

            # Train the critic on consecutive mini-batches of real samples.
            for _critic_step in range(config.critic_iters):
                first = counter_batch * batch_size
                _data = self.training_samples[:, first:first + batch_size].T
                _disc_cost, _ = self.sess.run([self.disc_cost, self.d_optim],
                                              feed_dict={self.inputs: _data})
                # At the end of the real-sample set, start over and count an
                # epoch.
                if counter_batch == num_train_batches - 1:
                    counter_batch = 0
                    epoch += 1
                else:
                    counter_batch += 1
            iteration_time = time.time() - start_time
            # Log the negated critic loss of the last critic step and the
            # iteration time.
            plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
            plot.plot(self.sample_dir, 'time', iteration_time)

            # NOTE(review): the periodic condition uses the local `iteration`
            # while the other two use the global one; confirm this is intended
            # when resuming from a checkpoint.
            if (global_iteration == 500 or iteration % 20000 == 19999
                    or global_iteration >= config.num_iter - 10):
                print('epoch ' + str(epoch))
                if config.dataset in ('uniform', 'packets'):
                    # Critic loss on held-out data, to check whether the
                    # critic has overfit.
                    dev_disc_costs = []
                    num_dev_batches = int(self.dev_samples.shape[1] /
                                          batch_size)
                    for ind_dev in range(num_dev_batches):
                        first = ind_dev * batch_size
                        images = self.dev_samples[:,
                                                  first:first + batch_size].T
                        dev_disc_costs.append(
                            self.sess.run(self.disc_cost,
                                          feed_dict={self.inputs: images}))
                    plot.plot(self.sample_dir, 'dev disc cost',
                              -np.mean(dev_disc_costs))

                # Checkpoint the network parameters.
                self.save(global_iteration)

                # Generate samples, compute their statistics and compare them
                # with the real ones.
                fake_samples = self.sess.run([self.ex_samples])[0]
                acf_error, mean_error, corr_error, time_course_error, _ = \
                    analysis.get_stats(X=fake_samples.T,
                                       num_neurons=config.num_neurons,
                                       num_bins=config.num_bins,
                                       folder=config.sample_dir,
                                       name='fake' + str(global_iteration),
                                       critic_cost=-_disc_cost,
                                       instance=config.data_instance)

                # One panel per error measure, all with the same axes setup.
                panels = (
                    (sbplt[0][0], mean_error, 'spk-count mean error'),
                    (sbplt[0][1], time_course_error, 'time course error'),
                    (sbplt[1][0], acf_error, 'AC error'),
                    (sbplt[1][1], corr_error, 'corr error'),
                )
                for axis, error, title in panels:
                    axis.plot(global_iteration, error, '+b')
                    axis.set_title(title)
                    axis.set_xlabel('iterations')
                    axis.set_ylabel('L1 error')
                    axis.set_xlim(x_limits)
                f.savefig(self.sample_dir + 'fitting_errors.svg',
                          dpi=600,
                          bbox_inches='tight')
                # f and sbplt are still drawn into and re-saved at later
                # evaluation points after this call.
                plt.close(f)
                plot.flush(self.sample_dir)

            plot.tick()
# Example #2
# 0
def _build_sample_dir(flags):
    """Return the results folder name encoding a run's hyper-parameters.

    Layout: ``samples <architecture>/dataset_<dataset><data part>``
    ``_critic_iters_<n>_lambda_<x><architecture part>_iteration_<it>/``.
    The exact text must stay stable so that folders of earlier runs
    (checkpoints, stats_real.npz) are found again.

    Returns:
        The folder name, or None for a dataset/architecture combination with
        no naming scheme (the caller then keeps ``flags.sample_dir``).
    """
    architecture = flags.architecture
    if architecture == 'fc':
        architecture_part = '_num_units_' + str(flags.num_units)
    elif architecture == 'conv':
        architecture_part = ('_num_layers_' + str(flags.num_layers) +
                             '_num_features_' + str(flags.num_features) +
                             '_kernel_' + str(flags.kernel_width))
    else:
        return None

    samples_part = '_num_samples_' + str(flags.num_samples)
    size_part = ('_num_neurons_' + str(flags.num_neurons) + '_num_bins_' +
                 str(flags.num_bins))

    if flags.dataset == 'uniform':
        data_part = (samples_part + size_part + '_ref_period_' +
                     str(flags.ref_period) + '_firing_rate_' +
                     str(flags.firing_rate) + '_correlation_' +
                     str(flags.correlation) + '_group_size_' +
                     str(flags.group_size))
    elif flags.dataset == 'packets' and flags.number_of_modes in (1, 2):
        data_part = (samples_part + size_part + '_packet_prob_' +
                     str(flags.packet_prob) + '_firing_rate_' +
                     str(flags.firing_rate) + '_group_size_' +
                     str(flags.group_size))
        if flags.number_of_modes == 2:
            data_part += ('_noise_in_packet_' + str(flags.noise_in_packet) +
                          '_number_of_modes_' + str(flags.number_of_modes))
    elif flags.dataset == 'retina':
        # The fc folder name omits num_samples; kept for compatibility with
        # existing results folders.
        if architecture == 'fc':
            data_part = size_part
        else:
            data_part = samples_part + size_part
    else:
        return None

    return ('samples ' + architecture + '/' + 'dataset_' + flags.dataset +
            data_part + '_critic_iters_' + str(flags.critic_iters) +
            '_lambda_' + str(flags.lambd) + architecture_part +
            '_iteration_' + flags.iteration + '/')


def main(_):
    """Run one spike-train GAN experiment configured by the global FLAGS.

    Derives the results folder from the hyper-parameters, builds WGAN_conv,
    then either trains it (FLAGS.is_train) or restores it. Afterwards it
    analyses the generated samples:
      * filter plots for the conv architecture,
      * summary statistics,
      * high-order statistics for small 'uniform' problems,
      * for 'retina', a comparison with the k-pairwise and DDG models.

    The positional argument is ignored (presumably argv from tf.app.run --
    confirm).
    """
    # Print the parameters.
    pp.pprint(tf.app.flags.FLAGS.flag_values_dict())

    # Results folder. Unknown dataset/architecture combinations keep the
    # default FLAGS.sample_dir.
    sample_dir = _build_sample_dir(FLAGS)
    if sample_dir is not None:
        FLAGS.sample_dir = sample_dir

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Reuse the real samples of an earlier run in the same folder, if present.
    if FLAGS.recovery_dir == "" and os.path.exists(FLAGS.sample_dir +
                                                   '/stats_real.npz'):
        FLAGS.recovery_dir = FLAGS.sample_dir

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN_conv(sess,
                         architecture=FLAGS.architecture,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers,
                         num_units=FLAGS.num_units,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            training_samples, dev_samples = data_provider.generate_spike_trains(
                FLAGS, FLAGS.recovery_dir)
            wgan.training_samples = training_samples
            wgan.dev_samples = dev_samples
            print('data loaded')
            wgan.train(FLAGS)
        else:
            # NOTE(review): in train(), self.load() returns a 2-tuple, which
            # is always truthy. If load(training_stage) also returns a tuple,
            # this check can never fire -- confirm load's return value.
            if not wgan.load(FLAGS.training_stage):
                raise Exception("[!] Train a model first, then run test mode")

        # Load the training dataset (and its statistics).
        original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')

        # Neuron order for the filter plots. For simulated data, argsort of
        # the shuffle permutation is its inverse, which undoes the shuffling.
        if FLAGS.dataset == 'retina':
            index = np.arange(FLAGS.num_neurons)
        else:
            index = np.argsort(original_dataset['shuffled_index'])

        if FLAGS.architecture == 'conv':
            print('get filters -----------------------------------')
            filters = wgan.get_filters(num_samples=64)
            visualize_filters_and_units.plot_filters(filters, sess, FLAGS,
                                                     index)

        # Get generated samples and compute their statistics.
        print('compute stats -----------------------------------')
        if 'samples' not in original_dataset:
            real_samples = retinal_data.get_samples(
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
        else:
            real_samples = original_dataset['samples']
        sim_pop_activity.plot_samples(real_samples, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'real')
        fake_samples = wgan.get_samples(num_samples=FLAGS.num_samples)
        fake_samples = fake_samples.eval(session=sess)
        sim_pop_activity.plot_samples(fake_samples.T, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'fake')
        analysis.get_stats(X=fake_samples.T,
                           num_neurons=FLAGS.num_neurons,
                           num_bins=FLAGS.num_bins,
                           folder=FLAGS.sample_dir,
                           name='fake',
                           instance=FLAGS.data_instance)

        # Evaluate high-order features (only when dimensionality is low).
        if FLAGS.dataset == 'uniform' and FLAGS.num_bins * FLAGS.num_neurons < 40:
            print(
                'compute high order statistics-----------------------------------'
            )
            num_trials = 2**8
            num_samples_per_trial = 2**13
            fake_samples_mat = np.zeros((num_trials * num_samples_per_trial,
                                         FLAGS.num_neurons * FLAGS.num_bins))
            for ind_tr in range(num_trials):
                # Assumes each ex_samples batch has num_samples_per_trial
                # rows -- TODO confirm.
                first = ind_tr * num_samples_per_trial
                trial_samples = wgan.sess.run([wgan.ex_samples])[0]
                fake_samples_mat[first:first +
                                 num_samples_per_trial, :] = trial_samples

            analysis.evaluate_approx_distribution(
                X=fake_samples_mat.T,
                folder=FLAGS.sample_dir,
                num_samples_theoretical_distr=2**21,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                group_size=FLAGS.group_size,
                refr_per=FLAGS.ref_period)

        # Comparison with the k-pairwise and DDG models (retinal data only).
        if FLAGS.dataset == 'retina':
            print(
                'nearest sample analysis -----------------------------------')
            num_samples = 100  # for the 'nearest sample' analysis (Fig. S5)
            print('real_samples')
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=real_samples,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='real',
                                    num_samples=num_samples)
            print('fake_samples')
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=fake_samples.T,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='spikeGAN',
                                    num_samples=num_samples)
            # Each baseline model: load its samples, save their statistics,
            # then run the nearest-sample analysis against the real data.
            baselines = (
                ('k_pairwise_samples', 'k_pairwise',
                 retinal_data.load_samples_from_k_pairwise_model),
                ('DDG_samples', 'DDG',
                 retinal_data.load_samples_from_DDG_model),
            )
            for label, model_name, load_model_samples in baselines:
                model_samples = load_model_samples(
                    num_samples=FLAGS.num_samples,
                    num_bins=FLAGS.num_bins,
                    num_neurons=FLAGS.num_neurons,
                    instance=FLAGS.data_instance,
                    folder=os.getcwd() + '/data/retinal data/')
                print(label)
                analysis.get_stats(X=model_samples,
                                   num_neurons=FLAGS.num_neurons,
                                   num_bins=FLAGS.num_bins,
                                   folder=FLAGS.sample_dir,
                                   name=model_name,
                                   instance=FLAGS.data_instance)
                analysis.nearest_sample(X_real=real_samples,
                                        fake_samples=model_samples,
                                        num_neurons=FLAGS.num_neurons,
                                        num_bins=FLAGS.num_bins,
                                        folder=FLAGS.sample_dir,
                                        name=model_name,
                                        num_samples=num_samples)
# Example #3
# 0
def generate_spike_trains(config, recovery_dir):
    """Return the training and dev spike-train sets described by ``config``.

    Args:
        config: run configuration. ``config.dataset`` selects the source:
            'uniform' and 'packets' are simulated with sim_pop_activity, and
            'retina' is loaded with retinal_data.
        recovery_dir: folder containing a previously saved ``stats_real.npz``.
            For 'uniform' and 'packets', a non-empty value reloads the
            training samples and their generative parameters from it instead
            of simulating new ones. It is not used for 'retina'.

    Returns:
        ``(real_samples, dev_samples)``. The statistics of ``real_samples``
        are passed to ``analysis.get_stats(..., name='real')`` with
        ``folder=config.sample_dir``. For the simulated datasets, the dev set
        is ``int(config.num_samples / 4)`` fresh samples drawn with the same
        parameters. For 'retina', ``dev_samples`` is an empty list.

    Raises:
        ValueError: if ``config.dataset`` is not 'uniform', 'packets' or
            'retina'.
    """
    if config.dataset == 'uniform':
        if recovery_dir != "":
            saved = np.load(recovery_dir + '/stats_real.npz')
            real_samples = saved['samples']
            firing_rates_mat = saved['firing_rate_mat']
            correlations_mat = saved['correlation_mat']
            activity_peaks = saved['activity_peaks']
            shuffled_index = saved['shuffled_index']
        else:
            # Shuffle the neurons to test whether the network still learns
            # the structure.
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            num_groups = int(config.num_neurons / config.group_size)
            # One firing rate and one correlation per group, drawn uniformly
            # within +/-50% of the configured values.
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups) - 0.5) * config.correlation / 2
            # Sequence response: every neuron of group g peaks at
            # g * group_size * num_bins / num_neurons, so successive groups
            # peak one after another.
            activity_peaks = np.repeat(np.arange(num_groups),
                                       config.group_size)
            activity_peaks = (activity_peaks * config.group_size *
                              config.num_bins / config.num_neurons)
            activity_peaks = activity_peaks.reshape(config.num_neurons, 1)
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                num_neurons=config.num_neurons,
                correlations_mat=correlations_mat,
                group_size=config.group_size,
                shuffled_index=shuffled_index,
                refr_per=config.ref_period,
                firing_rates_mat=firing_rates_mat,
                activity_peaks=activity_peaks,
                folder=config.sample_dir)

        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           shuffled_index=shuffled_index,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           correlation_mat=correlations_mat,
                           activity_peaks=activity_peaks)

        # Dev samples drawn with the same parameters.
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            shuffled_index=shuffled_index,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat,
            activity_peaks=activity_peaks)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            saved = np.load(recovery_dir + '/stats_real.npz')
            real_samples = saved['samples']
            firing_rates_mat = saved['firing_rate_mat']
            shuffled_index = saved['shuffled_index']
        else:
            # Shuffle the neurons to test whether the network still learns
            # the packets.
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # One firing rate per neuron, uniformly within +/-50% of the
            # configured value.
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                refr_per=config.ref_period,
                num_neurons=config.num_neurons,
                group_size=config.group_size,
                firing_rates_mat=firing_rates_mat,
                packets_on=True,
                prob_packets=config.packet_prob,
                shuffled_index=shuffled_index,
                folder=config.sample_dir)
        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           shuffled_index=shuffled_index)
        # Dev samples drawn with the same parameters.
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            refr_per=config.ref_period,
            num_neurons=config.num_neurons,
            group_size=config.group_size,
            firing_rates_mat=firing_rates_mat,
            packets_on=True,
            prob_packets=config.packet_prob,
            shuffled_index=shuffled_index)

    elif config.dataset == 'retina':
        # NOTE(review): unlike other call sites, no folder= is passed here;
        # this relies on retinal_data.get_samples' default -- confirm.
        real_samples = retinal_data.get_samples(num_bins=config.num_bins,
                                                num_neurons=config.num_neurons,
                                                instance=config.data_instance)
        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           instance=config.data_instance)
        dev_samples = []

    else:
        raise ValueError(
            "unsupported dataset %r; expected 'uniform', 'packets' or 'retina'"
            % (config.dataset,))
    return real_samples, dev_samples
# Example #4
# 0
def generate_spike_trains(config, recovery_dir):
    """Return the training and dev sets corresponding to the parameters in ``config``.

    Args:
        config: run configuration. ``config.dataset`` selects the source:
            'uniform' and 'packets' are simulated with sim_pop_activity, and
            'retina' is loaded from ``<cwd>/data/retinal data/`` with
            retinal_data.
        recovery_dir: folder containing a previously saved ``stats_real.npz``.
            For 'uniform' and 'packets', a non-empty value reloads the
            training samples and their generative parameters from it instead
            of simulating new ones. It is not used for 'retina'.

    Returns:
        ``(real_samples, dev_samples)``. The statistics of ``real_samples``
        are passed to ``analysis.get_stats(..., name='real')`` with
        ``folder=config.sample_dir``. For the simulated datasets, the dev set
        is ``int(config.num_samples / 4)`` fresh samples drawn with the same
        parameters. For 'retina', ``dev_samples`` is an empty list.

    Raises:
        ValueError: if ``config.dataset`` is not 'uniform', 'packets' or
            'retina'.
    """
    if config.dataset == 'uniform':
        if recovery_dir != "":
            saved = np.load(recovery_dir + '/stats_real.npz')
            real_samples = saved['samples']
            firing_rates_mat = saved['firing_rate_mat']
            correlations_mat = saved['correlation_mat']
            shuffled_index = saved['shuffled_index']
        else:
            # Shuffle the neurons.
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            num_groups = int(config.num_neurons / config.group_size)
            # One firing rate and one correlation per group, drawn uniformly
            # within +/-50% of the configured values.
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups) - 0.5) * config.correlation / 2
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                num_neurons=config.num_neurons,
                correlations_mat=correlations_mat,
                group_size=config.group_size,
                shuffled_index=shuffled_index,
                refr_per=config.ref_period,
                firing_rates_mat=firing_rates_mat,
                folder=config.sample_dir)

        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           shuffled_index=shuffled_index,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           correlation_mat=correlations_mat)

        # Dev samples drawn with the same parameters.
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            shuffled_index=shuffled_index,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            saved = np.load(recovery_dir + '/stats_real.npz')
            real_samples = saved['samples']
            firing_rates_mat = saved['firing_rate_mat']
            shuffled_index = saved['shuffled_index']
        else:
            # Shuffle the neurons.
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # One firing rate per neuron, uniformly within +/-50% of the
            # configured value.
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                refr_per=config.ref_period,
                num_neurons=config.num_neurons,
                group_size=config.group_size,
                firing_rates_mat=firing_rates_mat,
                packets_on=True,
                prob_packets=config.packet_prob,
                shuffled_index=shuffled_index,
                folder=config.sample_dir,
                noise_in_packet=config.noise_in_packet,
                number_of_modes=config.number_of_modes)
        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           shuffled_index=shuffled_index)
        # Dev samples drawn with the same parameters.
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            refr_per=config.ref_period,
            num_neurons=config.num_neurons,
            group_size=config.group_size,
            firing_rates_mat=firing_rates_mat,
            packets_on=True,
            prob_packets=config.packet_prob,
            shuffled_index=shuffled_index,
            noise_in_packet=config.noise_in_packet,
            number_of_modes=config.number_of_modes)

    elif config.dataset == 'retina':
        real_samples = retinal_data.get_samples(
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            instance=config.data_instance,
            folder=os.getcwd() + '/data/retinal data/')
        # Save the original statistics.
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           instance=config.data_instance)
        dev_samples = []

    else:
        raise ValueError(
            "unsupported dataset %r; expected 'uniform', 'packets' or 'retina'"
            % (config.dataset,))
    return real_samples, dev_samples
# Example #5
# 0
def main(_):
    """Build the run folders, train or restore the WGAN, then analyse it.

    All settings come from the module-level ``FLAGS``. The sample folder name
    encodes the dataset, architecture and hyper-parameters of the run, and the
    checkpoint folder lives inside it. In train mode the real/dev spike trains
    come from ``data_provider.generate_spike_trains`` and the model is
    trained; otherwise a checkpoint is loaded. Afterwards filters, unit
    activations, statistics of generated samples and histograms of the
    critic's output (real / fake / random binary noise) are written to
    ``FLAGS.sample_dir``.

    Raises:
        ValueError: if ``FLAGS.dataset`` is not 'uniform', 'packets' or
            'retina', or ``FLAGS.architecture`` is not 'fc' or 'conv'
            (previously no sample folder was assigned in that case).
        Exception: in test mode, when no trained model can be loaded.
    """
    #print parameters
    pp.pprint(flags.FLAGS.__flags)

    #folders: validate up front so an unsupported combination fails clearly
    if FLAGS.architecture not in ('fc', 'conv'):
        raise ValueError("unsupported architecture: " + str(FLAGS.architecture))

    #dataset-specific part of the folder name
    if FLAGS.dataset == 'uniform':
        dataset_part = ('_ref_period_' + str(FLAGS.ref_period) +
                        '_firing_rate_' + str(FLAGS.firing_rate) +
                        '_correlation_' + str(FLAGS.correlation) +
                        '_group_size_' + str(FLAGS.group_size))
    elif FLAGS.dataset == 'packets':
        dataset_part = ('_packet_prob_' + str(FLAGS.packet_prob) +
                        '_firing_rate_' + str(FLAGS.firing_rate) +
                        '_group_size_' + str(FLAGS.group_size))
    elif FLAGS.dataset == 'retina':
        dataset_part = ''
    else:
        raise ValueError("unsupported dataset: " + str(FLAGS.dataset))

    #historical quirk: retina + fc folder names do not contain num_samples.
    #Kept as is so folders of existing runs are still found.
    if FLAGS.dataset == 'retina' and FLAGS.architecture == 'fc':
        samples_part = ''
    else:
        samples_part = '_num_samples_' + str(FLAGS.num_samples)

    #architecture-specific part of the folder name
    if FLAGS.architecture == 'fc':
        architecture_part = '_num_units_' + str(FLAGS.num_units)
    else:
        architecture_part = ('_num_layers_' + str(FLAGS.num_layers) +
                             '_num_features_' + str(FLAGS.num_features) +
                             '_kernel_' + str(FLAGS.kernel_width))

    FLAGS.sample_dir = ('samples ' + FLAGS.architecture + '/' +
                        'dataset_' + FLAGS.dataset + samples_part +
                        '_num_neurons_' + str(FLAGS.num_neurons) +
                        '_num_bins_' + str(FLAGS.num_bins) + dataset_part +
                        '_critic_iters_' + str(FLAGS.critic_iters) +
                        '_lambda_' + str(FLAGS.lambd) + architecture_part +
                        '_iteration_' + FLAGS.iteration + '/')

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    #if a previous run already stored stats_real.npz here, use this folder as
    #recovery_dir (it is passed to data_provider.generate_spike_trains below)
    if FLAGS.recovery_dir == "" and os.path.exists(FLAGS.sample_dir + '/stats_real.npz'):
        FLAGS.recovery_dir = FLAGS.sample_dir

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN_conv(sess, architecture=FLAGS.architecture,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers, num_units=FLAGS.num_units,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            training_samples, dev_samples = data_provider.generate_spike_trains(FLAGS, FLAGS.recovery_dir)
            wgan.training_samples = training_samples
            wgan.dev_samples = dev_samples
            wgan.train(FLAGS)
        else:
            if not wgan.load(FLAGS.training_stage):
                raise Exception("[!] Train a model first, then run test mode")

        original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')
        #neuron ordering used by the plots: identity for retina, otherwise the
        #inverse of the stored shuffling permutation
        if FLAGS.dataset == 'retina':
            index = np.arange(FLAGS.num_neurons)
        else:
            index = np.argsort(original_dataset['shuffled_index'])
        print('get filters -----------------------------------')
        filters = wgan.get_filters(num_samples=64)
        visualize_filters_and_units.plot_filters(filters, sess, FLAGS, index)
        #(a stray undefined name used to stop execution here with a NameError)

        #get generated samples and their statistics
        fake_samples = wgan.get_samples(num_samples=FLAGS.num_samples)
        fake_samples = fake_samples.eval(session=sess)
        _, _, _, _, _ = analysis.get_stats(X=fake_samples.T, num_neurons=FLAGS.num_neurons,
                                           num_bins=FLAGS.num_bins, folder=FLAGS.sample_dir,
                                           name='fake', instance=FLAGS.data_instance)

        #evaluate high order features approximation (only for small uniform problems)
        if FLAGS.dataset == 'uniform' and FLAGS.num_bins * FLAGS.num_neurons < 40:
            analysis.evaluate_approx_distribution(X=fake_samples.T, folder=FLAGS.sample_dir,
                                                  num_samples_theoretical_distr=2**21,
                                                  num_bins=FLAGS.num_bins,
                                                  num_neurons=FLAGS.num_neurons,
                                                  group_size=FLAGS.group_size,
                                                  refr_per=FLAGS.ref_period)

        #get units activity
        print('get activations -----------------------------------')
        output, units, inputs = wgan.get_units(num_samples=2**13)
        if FLAGS.architecture == 'conv':
            visualize_filters_and_units.plot_untis_rf_conv(units, output, inputs, sess, FLAGS, index)
        elif FLAGS.architecture == 'fc':
            visualize_filters_and_units.plot_untis_rf(units, output, inputs, sess, FLAGS, index)

        real_samples = original_dataset['samples']
        #get critic's output distribution for real, generated and random input.
        #noise: i.i.d. binary entries, 1 with probability 0.5, of shape
        #(num_samples, num_neurons*num_bins)
        noise_shape = (FLAGS.num_samples, FLAGS.num_neurons * FLAGS.num_bins)
        noise = (np.random.random(noise_shape) < 0.5).astype('float32')
        output_real = wgan.get_critics_output(real_samples.T)
        output_fake = wgan.get_critics_output(fake_samples)
        output_noise = wgan.get_critics_output(noise)
        output_real = output_real.eval(session=sess)
        output_fake = output_fake.eval(session=sess)
        output_noise = output_noise.eval(session=sess)
        visualize_filters_and_units.plot_histogram(output_real, FLAGS.sample_dir, 'real')
        visualize_filters_and_units.plot_histogram(output_fake, FLAGS.sample_dir, 'fake')
        visualize_filters_and_units.plot_histogram(output_noise, FLAGS.sample_dir, 'noise')

        #plot samples
        analysis.plot_samples(fake_samples.T, FLAGS.num_neurons, FLAGS.sample_dir, 'fake')
        analysis.plot_samples(real_samples, FLAGS.num_neurons, FLAGS.sample_dir, 'real')

        #compare with k-pairwise model samples
        if FLAGS.dataset == 'retina':
            k_pairwise_samples = retinal_data.load_samples_from_k_pairwise_model(
                num_samples=FLAGS.num_samples, num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons, instance=FLAGS.data_instance)
            print(k_pairwise_samples.shape)
            _, _, _, _, _ = analysis.get_stats(X=k_pairwise_samples, num_neurons=FLAGS.num_neurons,
                                               num_bins=FLAGS.num_bins, folder=FLAGS.sample_dir,
                                               name='k_pairwise', instance=FLAGS.data_instance)
Beispiel #6
0
    def train(self, config):
        """Train the WGAN on simulated ('uniform') or recorded ('retina') spike trains.

        Builds Adam optimizers for the generator and the critic, initializes
        the graph variables, tries to restore a checkpoint via ``self.load()``
        and builds the real data (plus a held-out dev set for 'uniform').

        Then it runs ``config.num_iter`` iterations. Each iteration does one
        generator step (skipped on the very first iteration) followed by
        ``config.critic_iters`` critic steps. At iteration 500, at every
        iteration with ``iteration % 20000 == 19999`` and during the last 9
        iterations, it does the following:

        * saves the parameters;
        * evaluates generated samples with ``analysis.get_stats``;
        * plots the fitting errors to ``self.sample_dir + 'fitting_errors.svg'``.

        Args:
            config: flags object with the optimizer settings (learning_rate,
                beta1, beta2), dataset settings and training schedule
                (num_iter, critic_iters, ...).

        Raises:
            ValueError: if ``config.dataset`` is neither 'uniform' nor 'retina'.
        """
        #define optimizers: each one only updates the variables returned by
        #params_with_name for its own network
        self.g_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.gen_cost,
                var_list=params_with_name('Generator'),
                colocate_gradients_with_ops=True)
        self.d_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.disc_cost,
                var_list=params_with_name('Discriminator.'),
                colocate_gradients_with_ops=True)

        #initialize variables. Older TensorFlow releases only provide
        #initialize_all_variables; only a missing attribute triggers the
        #fallback, so errors raised while running the op are not masked.
        try:
            init_op = tf.global_variables_initializer()
        except AttributeError:
            init_op = tf.initialize_all_variables()
        init_op.run()

        #try to load trained parameters
        self.load()

        #get real samples
        dev_samples = None  # only built (and used) for the 'uniform' dataset
        if config.dataset == 'uniform':
            num_groups = int(self.num_neurons / config.group_size)
            #per-group firing rates / correlations drawn uniformly in [0.5*x, 1.5*x)
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups) - 0.5) * config.correlation / 2
            #every neuron of group k gets the value k*group_size*num_bins/num_neurons
            group_ids = np.arange(num_groups)
            activity_peaks = np.repeat(group_ids, config.group_size)
            activity_peaks = activity_peaks * config.group_size * self.num_bins / self.num_neurons
            activity_peaks = activity_peaks.reshape(self.num_neurons, 1)
            self.real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples, num_bins=self.num_bins,
                num_neurons=self.num_neurons, correlations_mat=correlations_mat,
                group_size=config.group_size, refr_per=config.ref_period,
                firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks)
            #get dev samples (used below to check whether the critic overfits)
            dev_samples = sim_pop_activity.get_samples(
                num_samples=int(config.num_samples / 4), num_bins=self.num_bins,
                num_neurons=self.num_neurons, correlations_mat=correlations_mat,
                group_size=config.group_size, refr_per=config.ref_period,
                firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks)
            #save original statistics
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               firing_rate_mat=firing_rates_mat,
                               correlation_mat=correlations_mat,
                               activity_peaks=activity_peaks)
        elif config.dataset == 'retina':
            self.real_samples = retinal_data.get_samples(
                num_bins=self.num_bins,
                num_neurons=self.num_neurons,
                instance=config.data_instance)
            #save original statistics
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               instance=config.data_instance)
        else:
            raise ValueError("unsupported dataset: " + str(config.dataset))

        #count number of variables
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('number of variables: ' + str(total_parameters))

        #start training. Real samples are stored column-wise, so batches are
        #column slices (transposed before feeding). The same batch size
        #(self.batch_size) is used for slicing and for the epoch wrap-around.
        batch_size = self.batch_size
        num_batches = int(self.real_samples.shape[1] / batch_size)
        counter_batch = 0
        epoch = 0
        #fitting errors figure (2x2 panels, updated at every evaluation)
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        for iteration in range(config.num_iter):
            start_time = time.time()
            # Train generator (only after the critic has been trained, at least once)
            if iteration > 0:
                self.sess.run(self.g_optim)

            # Train critic. NOTE(review): _disc_cost is only defined if
            # config.critic_iters >= 1; it is used after this loop.
            for _critic_step in range(config.critic_iters):
                _data = self.real_samples[:, counter_batch * batch_size:
                                          (counter_batch + 1) * batch_size].T
                _disc_cost, _ = self.sess.run([self.disc_cost, self.d_optim],
                                              feed_dict={self.inputs: _data})
                #at the end of the real samples set, start over and count an epoch
                if counter_batch == num_batches - 1:
                    counter_batch = 0
                    epoch += 1
                else:
                    counter_batch += 1
            iteration_time = time.time() - start_time
            #plot the critic's loss and the iteration time
            plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
            plot.plot(self.sample_dir, 'time', iteration_time)

            if (iteration == 500 or iteration % 20000 == 19999
                    or iteration > config.num_iter - 10):
                print('epoch ' + str(epoch))
                if config.dataset == 'uniform':
                    #critic cost on held-out data, to detect overfitting
                    dev_disc_costs = []
                    for ind_dev in range(int(dev_samples.shape[1] / batch_size)):
                        images = dev_samples[:, ind_dev * batch_size:
                                             (ind_dev + 1) * batch_size].T
                        dev_disc_costs.append(self.sess.run(
                            self.disc_cost, feed_dict={self.inputs: images}))
                    #plot the dev loss
                    plot.plot(self.sample_dir, 'dev disc cost',
                              -np.mean(dev_disc_costs))

                #save the network parameters
                self.save(iteration)

                #get simulated samples, calculate their statistics and compare them with the original ones
                fake_samples = self.get_samples(num_samples=2**13)
                fake_samples = fake_samples.eval(session=self.sess)
                fake_samples = self.binarize(samples=fake_samples)
                acf_error, mean_error, corr_error, time_course_error, _ = analysis.get_stats(
                    X=fake_samples.T, num_neurons=config.num_neurons,
                    num_bins=config.num_bins, folder=config.sample_dir,
                    name='fake' + str(iteration), critic_cost=-_disc_cost,
                    instance=config.data_instance)

                #plot the fitting errors, one panel per error measure
                x_limits = [0 - config.num_iter / 4,
                            config.num_iter + config.num_iter / 4]
                panels = ((sbplt[0][0], mean_error, 'spk-count mean error'),
                          (sbplt[0][1], time_course_error, 'time course error'),
                          (sbplt[1][0], acf_error, 'AC error'),
                          (sbplt[1][1], corr_error, 'corr error'))
                for axis, error, title in panels:
                    axis.plot(iteration, error, '+b')
                    axis.set_title(title)
                    axis.set_xlabel('iterations')
                    axis.set_ylabel('L1 error')
                    axis.set_xlim(x_limits)
                f.savefig(self.sample_dir + 'fitting_errors.svg',
                          dpi=600,
                          bbox_inches='tight')
                plt.close(f)
                plot.flush(self.sample_dir)

            plot.tick()
Beispiel #7
0
def main(_):
    """Build the run folders, train or restore the WGAN, then analyse it.

    All settings come from the module-level ``FLAGS``. The sample folder name
    encodes the dataset and hyper-parameters of the run, and the checkpoint
    folder lives inside it. In train mode the model is trained; otherwise a
    checkpoint is loaded. Afterwards the following are written to
    ``FLAGS.sample_dir``:

    * the filters and unit activations;
    * plots and statistics of generated samples (raw and binarized);
    * for 'retina', statistics of k-pairwise model samples.

    Raises:
        ValueError: if ``FLAGS.dataset`` is neither 'uniform' nor 'retina'
            (previously no sample folder was assigned and training failed later).
        Exception: in test mode, when no trained model can be loaded.
    """
    #print parameters
    pp.pprint(flags.FLAGS.__flags)

    #folders: the sample folder name encodes the run's settings
    if FLAGS.dataset == 'uniform':
        FLAGS.sample_dir = ('samples/' + 'dataset_' + FLAGS.dataset +
                            '_num_samples_' + str(FLAGS.num_samples) +
                            '_num_neurons_' + str(FLAGS.num_neurons) +
                            '_num_bins_' + str(FLAGS.num_bins) +
                            '_ref_period_' + str(FLAGS.ref_period) +
                            '_firing_rate_' + str(FLAGS.firing_rate) +
                            '_correlation_' + str(FLAGS.correlation) +
                            '_group_size_' + str(FLAGS.group_size) +
                            '_critic_iters_' + str(FLAGS.critic_iters) +
                            '_lambda_' + str(FLAGS.lambd) +
                            '_num_features_' + str(FLAGS.num_features) +
                            '_kernel_' + str(FLAGS.kernel_width) +
                            '_iteration_' + FLAGS.iteration + '/')
    elif FLAGS.dataset == 'retina':
        FLAGS.sample_dir = ('samples/' + 'dataset_' + FLAGS.dataset +
                            '_instance_' + FLAGS.data_instance +
                            '_num_neurons_' + str(FLAGS.num_neurons) +
                            '_num_bins_' + str(FLAGS.num_bins) +
                            '_critic_iters_' + str(FLAGS.critic_iters) +
                            '_lambda_' + str(FLAGS.lambd) +
                            '_num_features_' + str(FLAGS.num_features) +
                            '_kernel_' + str(FLAGS.kernel_width) +
                            '_iteration_' + FLAGS.iteration + '/')
    else:
        raise ValueError("unsupported dataset: " + str(FLAGS.dataset))

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN(sess,
                    num_neurons=FLAGS.num_neurons,
                    num_bins=FLAGS.num_bins,
                    lambd=FLAGS.lambd,
                    num_features=FLAGS.num_features,
                    kernel_width=FLAGS.kernel_width,
                    batch_size=FLAGS.batch_size,
                    checkpoint_dir=FLAGS.checkpoint_dir,
                    sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            wgan.train(FLAGS)
        else:
            if not wgan.load(FLAGS.training_stage):
                raise Exception("[!] Train a model first, then run test mode")

        print('get filters -----------------------------------')
        filters = wgan.get_filters()
        visualize_filters_and_units.plot_filters(filters, sess, FLAGS)
        #get units activity
        print('get units activity -----------------------------------')
        output, units, noise = wgan.get_units(num_samples=2**13)
        visualize_filters_and_units.plot_untis_rf(units, output, noise, sess,
                                                  FLAGS)
        #get samples and their statistics
        fake_samples = wgan.get_samples(num_samples=FLAGS.num_samples)
        fake_samples = fake_samples.eval(session=sess)
        fake_samples_binnarized = wgan.binarize(samples=fake_samples)
        #plot samples (raw and binarized)
        analysis.plot_samples(fake_samples.T, fake_samples_binnarized.T,
                              FLAGS.num_neurons, FLAGS.num_bins,
                              FLAGS.sample_dir)
        #plot stats of the binarized samples
        _, _, _, _, _ = analysis.get_stats(X=fake_samples_binnarized.T,
                                           num_neurons=FLAGS.num_neurons,
                                           num_bins=FLAGS.num_bins,
                                           folder=FLAGS.sample_dir,
                                           name='fake',
                                           instance=FLAGS.data_instance)

        #compare with k-pairwise model samples
        if FLAGS.dataset == 'retina':
            k_pairwise_samples = retinal_data.load_samples_from_k_pairwise_model(
                num_samples=FLAGS.num_samples,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance)
            print(k_pairwise_samples.shape)
            _, _, _, _, _ = analysis.get_stats(X=k_pairwise_samples,
                                               num_neurons=FLAGS.num_neurons,
                                               num_bins=FLAGS.num_bins,
                                               folder=FLAGS.sample_dir,
                                               name='k_pairwise',
                                               instance=FLAGS.data_instance)

        #high-order feature evaluation is deliberately disabled (it was guarded
        #by `and False`); set this to True to re-enable it for 'uniform' runs
        evaluate_high_order_features = False
        if evaluate_high_order_features and FLAGS.dataset == 'uniform':
            analysis.evaluate_approx_distribution(X=fake_samples_binnarized.T,
                                                  folder=FLAGS.sample_dir,
                                                  num_samples_theoretical_distr=2**21,
                                                  num_bins=FLAGS.num_bins,
                                                  num_neurons=FLAGS.num_neurons,
                                                  group_size=FLAGS.group_size,
                                                  refr_per=FLAGS.ref_period)