Example #1
def main(_):
    #print parameters
    pp.pprint(tf.app.flags.FLAGS.flag_values_dict())
    #folders
    if FLAGS.dataset == 'uniform':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size)  + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
             '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size)  + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'packets' and FLAGS.number_of_modes == 1:
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) + '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'packets' and FLAGS.number_of_modes == 2:
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes)  + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) + '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes)  + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'retina':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset  +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    if FLAGS.recovery_dir == "" and os.path.exists(FLAGS.sample_dir +
                                                   '/stats_real.npz'):
        FLAGS.recovery_dir = FLAGS.sample_dir

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN_conv(sess,
                         architecture=FLAGS.architecture,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers,
                         num_units=FLAGS.num_units,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            training_samples, dev_samples = data_provider.generate_spike_trains(
                FLAGS, FLAGS.recovery_dir)
            wgan.training_samples = training_samples
            wgan.dev_samples = dev_samples
            print('data loaded')
            wgan.train(FLAGS)
        else:
            if not wgan.load(FLAGS.training_stage):
                raise Exception("[!] Train a model first, then run test mode")

        #LOAD TRAINING DATASET (and its statistics)
        original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')

        #PLOT FILTERS
        if FLAGS.dataset == 'retina':
            index = np.arange(FLAGS.num_neurons)
        else:
            index = np.argsort(original_dataset['shuffled_index'])

        if FLAGS.architecture == 'conv':
            print('get filters -----------------------------------')
            filters = wgan.get_filters(num_samples=64)
            visualize_filters_and_units.plot_filters(filters, sess, FLAGS,
                                                     index)

        #GET GENERATED SAMPLES AND COMPUTE THEIR STATISTICS
        print('compute stats -----------------------------------')
        if 'samples' not in original_dataset:
            real_samples = retinal_data.get_samples(
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
        else:
            real_samples = original_dataset['samples']
        sim_pop_activity.plot_samples(real_samples, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'real')
        fake_samples = wgan.get_samples(num_samples=FLAGS.num_samples)
        fake_samples = fake_samples.eval(session=sess)
        sim_pop_activity.plot_samples(fake_samples.T, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'fake')
        _, _, _, _, _ = analysis.get_stats(X=fake_samples.T,
                                           num_neurons=FLAGS.num_neurons,
                                           num_bins=FLAGS.num_bins,
                                           folder=FLAGS.sample_dir,
                                           name='fake',
                                           instance=FLAGS.data_instance)

        #EVALUATE HIGH ORDER FEATURES (only when dimensionality is low)
        if FLAGS.dataset == 'uniform' and FLAGS.num_bins * FLAGS.num_neurons < 40:
            print(
                'compute high order statistics-----------------------------------'
            )
            num_trials = int(2**8)
            num_samples_per_trial = 2**13
            fake_samples_mat = np.zeros((num_trials * num_samples_per_trial,
                                         FLAGS.num_neurons * FLAGS.num_bins))
            for ind_tr in range(num_trials):
                fake_samples = wgan.sess.run([wgan.ex_samples])[0]
                fake_samples_mat[ind_tr * num_samples_per_trial:(ind_tr + 1) *
                                 num_samples_per_trial, :] = fake_samples

            analysis.evaluate_approx_distribution(X=fake_samples_mat.T,
                                                  folder=FLAGS.sample_dir,
                                                  num_samples_theoretical_distr=2**21,
                                                  num_bins=FLAGS.num_bins,
                                                  num_neurons=FLAGS.num_neurons,
                                                  group_size=FLAGS.group_size,
                                                  refr_per=FLAGS.ref_period)

        #COMPARISON WITH K-PAIRWISE AND DG MODELS (only for retinal data)
        if FLAGS.dataset == 'retina':
            print(
                'nearest sample analysis -----------------------------------')
            num_samples = 100  #this is for the 'nearest sample' analysis (Fig. S5)
            print('real_samples')
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=real_samples,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='real',
                                    num_samples=num_samples)
            ###################
            print('fake_samples')
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=fake_samples.T,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='spikeGAN',
                                    num_samples=num_samples)
            ###################
            k_pairwise_samples = retinal_data.load_samples_from_k_pairwise_model(
                num_samples=FLAGS.num_samples,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
            print('k_pairwise_samples')
            _, _, _, _, _ = analysis.get_stats(X=k_pairwise_samples,
                                               num_neurons=FLAGS.num_neurons,
                                               num_bins=FLAGS.num_bins,
                                               folder=FLAGS.sample_dir,
                                               name='k_pairwise',
                                               instance=FLAGS.data_instance)
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=k_pairwise_samples,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='k_pairwise',
                                    num_samples=num_samples)
            ###################
            DDG_samples = retinal_data.load_samples_from_DDG_model(
                num_samples=FLAGS.num_samples,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
            print('DDG_samples')
            _, _, _, _, _ = analysis.get_stats(X=DDG_samples,
                                               num_neurons=FLAGS.num_neurons,
                                               num_bins=FLAGS.num_bins,
                                               folder=FLAGS.sample_dir,
                                               name='DDG',
                                               instance=FLAGS.data_instance)
            analysis.nearest_sample(X_real=real_samples,
                                    fake_samples=DDG_samples,
                                    num_neurons=FLAGS.num_neurons,
                                    num_bins=FLAGS.num_bins,
                                    folder=FLAGS.sample_dir,
                                    name='DDG',
                                    num_samples=num_samples)
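
# --- A minimal sketch of how this entry point could be wired up with
# TF1-style flags. Only a few of the flags read by main() are shown, and the
# defaults are illustrative assumptions, not values from the original
# repository. ---
import pprint
import tensorflow as tf

pp = pprint.PrettyPrinter()
flags = tf.app.flags
flags.DEFINE_string('dataset', 'uniform', "'uniform', 'packets' or 'retina'")
flags.DEFINE_string('architecture', 'conv', "'fc' or 'conv'")
flags.DEFINE_integer('num_samples', 8192, 'number of training samples')
flags.DEFINE_integer('num_neurons', 16, 'neurons per sample')
flags.DEFINE_integer('num_bins', 32, 'time bins per sample')
flags.DEFINE_boolean('is_train', True, 'train a model or evaluate a saved one')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    tf.app.run()  # parses the flags and invokes main(_)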
Example #2
def get_stats(X,
              num_neurons,
              num_bins,
              folder,
              name,
              firing_rate_mat=[],
              correlation_mat=[],
              activity_peaks=[],
              critic_cost=np.nan,
              instance='1',
              shuffled_index=[]):
    '''
    Compute spike-train statistics: spike-count mean, autocorrelogram,
    covariance matrix, k-statistics, average firing time course and lag
    covariances. If name != 'real', the stats are also compared with those
    of the original dataset.
    '''
    X_binarized = (X > np.random.random(X.shape)).astype(float)
    resave_real_data = False
    if name != 'real':
        original_data = np.load(folder + '/stats_real.npz')
        if any(k not in original_data
               for k in ("mean", "acf", "cov_mat", "k_probs", "lag_cov_mat",
                         "firing_average_time_course")):
            if 'samples' not in original_data:
                samples = retinal_data.get_samples(num_bins=num_bins,
                                                   num_neurons=num_neurons,
                                                   instance=instance,
                                                   folder=os.getcwd() +
                                                   '/data/retinal data/')
            else:
                samples = original_data['samples']
            cov_mat_real, k_probs_real, mean_spike_count_real, autocorrelogram_mat_real, firing_average_time_course_real, lag_cov_mat_real =\
            get_stats_aux(samples, num_neurons, num_bins)
            assert np.all(autocorrelogram_mat_real == original_data['acf'])
            assert np.all(mean_spike_count_real == original_data['mean'])
            resave_real_data = True
        else:
            mean_spike_count_real, autocorrelogram_mat_real, firing_average_time_course_real, cov_mat_real, k_probs_real, lag_cov_mat_real = \
            [original_data["mean"], original_data["acf"], original_data["firing_average_time_course"], original_data["cov_mat"], original_data["k_probs"], original_data["lag_cov_mat"]]

    cov_mat, k_probs, mean_spike_count, autocorrelogram_mat, firing_average_time_course, lag_cov_mat = get_stats_aux(
        X_binarized, num_neurons, num_bins)
    variances = np.diag(cov_mat)
    only_cov_mat = cov_mat.copy()
    only_cov_mat[np.diag_indices(num_neurons)] = np.nan

    #PLOT
    index = np.linspace(-10, 10, 2 * 10 + 1)
    #summary figure with the main spike-train statistics
    f, sbplt = plt.subplots(2, 3, figsize=(8, 8), dpi=250)
    matplotlib.rcParams.update({'font.size': 8})
    plt.subplots_adjust(left=left,
                        bottom=bottom,
                        right=right,
                        top=top,
                        wspace=wspace,
                        hspace=hspace)

    #plot autocorrelogram(s)
    sbplt[1][1].plot(index, autocorrelogram_mat, 'r')
    if name != 'real':
        sbplt[1][1].plot(index, autocorrelogram_mat_real, 'b')
        acf_error = np.sum(
            np.abs(autocorrelogram_mat - autocorrelogram_mat_real))
    sbplt[1][1].set_title('Autocorrelogram')
    sbplt[1][1].set_xlabel('time (ms)')
    sbplt[1][1].set_ylabel('number of spikes')

    #plot mean firing rates
    if name != 'real':
        sbplt[0][0].plot([0, np.max(mean_spike_count_real)],
                         [0, np.max(mean_spike_count_real)], 'k')
        sbplt[0][0].plot(mean_spike_count_real, mean_spike_count, '.g')
        mean_error = np.sum(np.abs(mean_spike_count - mean_spike_count_real))
        sbplt[0][0].set_xlabel('mean firing rate expt')
        sbplt[0][0].set_ylabel('mean firing rate model')
    else:
        sbplt[0][0].plot(mean_spike_count, 'b')
        sbplt[0][0].set_xlabel('neuron')
        sbplt[0][0].set_ylabel('firing probability')

    sbplt[0][0].set_title('mean firing rates')

    #plot covariances
    if name != 'real':
        variances_real = np.diag(cov_mat_real)
        only_cov_mat_real = cov_mat_real.copy()
        only_cov_mat_real[np.diag_indices(num_neurons)] = np.nan
        sbplt[0][1].plot([np.nanmin(only_cov_mat_real.flatten()),np.nanmax(only_cov_mat_real.flatten())],\
                        [np.nanmin(only_cov_mat_real.flatten()),np.nanmax(only_cov_mat_real.flatten())],'k')
        sbplt[0][1].plot(only_cov_mat_real.flatten(), only_cov_mat.flatten(),
                         '.g')
        sbplt[0][1].set_title('pairwise covariances')
        sbplt[0][1].set_xlabel('covariances expt')
        sbplt[0][1].set_ylabel('covariances model')
        corr_error = np.nansum(
            np.abs(only_cov_mat - only_cov_mat_real).flatten())
    else:
        map_aux = sbplt[0][1].imshow(only_cov_mat, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[0][1])
        sbplt[0][1].set_title('covariance mat')
        sbplt[0][1].set_xlabel('neuron')
        sbplt[0][1].set_ylabel('neuron')

    #plot k-statistics
    if name != 'real':
        sbplt[1][0].plot([0, np.max(k_probs_real)],
                         [0, np.max(k_probs_real)], 'k')
        sbplt[1][0].plot(k_probs_real, k_probs, '.g')
        k_probs_error = np.sum(np.abs(k_probs - k_probs_real))
        sbplt[1][0].set_xlabel('k-probs expt')
        sbplt[1][0].set_ylabel('k-probs model')
    else:
        sbplt[1][0].plot(k_probs)
        sbplt[1][0].set_xlabel('K')
        sbplt[1][0].set_ylabel('probability')

    sbplt[1][0].set_title('k statistics')

    #plot average time course
    #firing_average_time_course[firing_average_time_course>0.048] = 0.048
    map_aux = sbplt[0][2].imshow(firing_average_time_course,
                                 interpolation='nearest')
    f.colorbar(map_aux, ax=sbplt[0][2])
    sbplt[0][2].set_title('sim firing time course')
    sbplt[0][2].set_xlabel('time (ms)')
    sbplt[0][2].set_ylabel('neuron')
    if name != 'real':
        map_aux = sbplt[1][2].imshow(firing_average_time_course_real,
                                     interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[1][2])
        sbplt[1][2].set_title('real firing time course')
        sbplt[1][2].set_xlabel('time (ms)')
        sbplt[1][2].set_ylabel('neuron')
        time_course_error = np.sum(
            np.abs(firing_average_time_course -
                   firing_average_time_course_real).flatten())

    f.savefig(folder + 'stats_' + name + '_II.svg',
              dpi=600,
              bbox_inches='tight')
    plt.close(f)

    if name != 'real':
        #PLOT LAG COVARIANCES
        #figure with lag covariances and variances
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        map_aux = sbplt[0][0].imshow(lag_cov_mat_real, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[0][0])
        sbplt[0][0].set_title('lag covariance mat expt')
        sbplt[0][0].set_xlabel('neuron')
        sbplt[0][0].set_ylabel('neuron shifted')
        map_aux = sbplt[1][0].imshow(lag_cov_mat, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[1][0])
        sbplt[1][0].set_title('lag covariance mat model')
        sbplt[1][0].set_xlabel('neuron')
        sbplt[1][0].set_ylabel('neuron shifted')
        lag_corr_error = np.nansum(
            np.abs(lag_cov_mat - lag_cov_mat_real).flatten())
        sbplt[0][1].plot([np.min(lag_cov_mat_real.flatten()),np.max(lag_cov_mat_real.flatten())],\
                        [np.min(lag_cov_mat_real.flatten()),np.max(lag_cov_mat_real.flatten())],'k')
        sbplt[0][1].plot(lag_cov_mat_real, lag_cov_mat, '.g')
        sbplt[0][1].set_xlabel('lag cov real')
        sbplt[0][1].set_ylabel('lag cov model')
        sbplt[1][1].plot([np.min(variances_real.flatten()),np.max(variances_real.flatten())],\
                        [np.min(variances_real.flatten()),np.max(variances_real.flatten())],'k')
        sbplt[1][1].plot(variances_real.flatten(), variances.flatten(), '.g')
        sbplt[1][1].set_title('variances')
        sbplt[1][1].set_xlabel('variances expt')
        sbplt[1][1].set_ylabel('variances model')
        variance_error = np.nansum(
            np.abs(variances_real - variances).flatten())
        f.savefig(folder + 'lag_covs_' + name + '_II.svg',
                  dpi=600,
                  bbox_inches='tight')
        plt.close(f)

    if name == 'real' and len(firing_rate_mat) > 0:
        #simulated ground-truth data (as opposed to real retinal data)
        data = {'mean':mean_spike_count, 'acf':autocorrelogram_mat, 'cov_mat':cov_mat, 'samples':X, 'k_probs':k_probs,'lag_cov_mat':lag_cov_mat,\
        'firing_rate_mat':firing_rate_mat, 'correlation_mat':correlation_mat, 'activity_peaks':activity_peaks, 'shuffled_index':shuffled_index, 'firing_average_time_course':firing_average_time_course}
        np.savez(folder + '/stats_' + name + '.npz', **data)
    else:
        data = {'mean':mean_spike_count, 'acf':autocorrelogram_mat, 'cov_mat':cov_mat, 'k_probs':k_probs, 'firing_average_time_course':firing_average_time_course,\
                'critic_cost':critic_cost, 'lag_cov_mat':lag_cov_mat}
        np.savez(folder + '/stats_' + name + '.npz', **data)
        if resave_real_data:
            if 'firing_rate_mat' in original_data:
                data = {'mean':mean_spike_count_real, 'acf':autocorrelogram_mat_real, 'cov_mat':cov_mat_real, 'samples':samples, 'k_probs':k_probs_real,'lag_cov_mat':lag_cov_mat_real,\
                'firing_rate_mat':original_data['firing_rate_mat'], 'correlation_mat':original_data['correlation_mat'], 'activity_peaks':original_data['activity_peaks'],\
                 'shuffled_index':original_data['shuffled_index'], 'firing_average_time_course':firing_average_time_course_real}
            else:
                data = {'mean':mean_spike_count_real, 'acf':autocorrelogram_mat_real, 'cov_mat':cov_mat_real, 'samples':samples, 'k_probs':k_probs_real,'lag_cov_mat':lag_cov_mat_real,\
                    'firing_average_time_course':firing_average_time_course_real}
            np.savez(folder + '/stats_real.npz', **data)
        if name != 'real':
            errors_mat = {'acf_error':acf_error, 'mean_error':mean_error, 'corr_error':corr_error, 'time_course_error':time_course_error, 'k_probs_error':k_probs_error,\
                          'variance_error':variance_error, 'lag_corr_error':lag_corr_error}
            np.savez(folder + '/errors_' + name + '.npz', **errors_mat)
            samples_fake = {'samples': X}
            np.savez(folder + '/samples_' + name[0:4] + '.npz', **samples_fake)
            return acf_error, mean_error, corr_error, time_course_error, k_probs_error
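
# --- The first line of get_stats binarizes X by thresholding against uniform
# noise, so a continuous entry p becomes a spike with probability p. A
# self-contained illustration of that stochastic binarization: ---
import numpy as np

np.random.seed(0)
probs = np.array([[0.1, 0.9],
                  [0.5, 0.5]])  # firing probability per (neuron, bin)
spikes = (probs > np.random.random(probs.shape)).astype(float)
# each entry of spikes is 1 with probability probs[i, j]; averaging many
# independent draws recovers probs
draws = np.mean([probs > np.random.random(probs.shape)
                 for _ in range(10000)], axis=0)
print(spikes)
print(np.round(draws, 2))  # close to [[0.1, 0.9], [0.5, 0.5]]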
Example #3
def main(_):
    #print parameters
    pp.pprint(flags.FLAGS.__flags)
    #folders
    if FLAGS.dataset == 'uniform':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size)  + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
             '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size)  + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'packets':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes)  + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) + '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes)  + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'retina':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset  +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers)  + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'

    with tf.Session() as sess:
        wgan = WGAN_conv(sess,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)
        if not wgan.load(FLAGS.training_stage):
            raise Exception("[!] Train a model first, then run test mode")

        num_samples = 8000
        if FLAGS.dataset == 'retina':
            samples = retinal_data.get_samples(num_bins=FLAGS.num_bins,
                                               num_neurons=FLAGS.num_neurons,
                                               instance=FLAGS.data_instance,
                                               folder=os.getcwd() +
                                               '/data/retinal data/').T
        else:
            original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')
            if FLAGS.number_of_modes == 1:
                _ = sim_pop_activity.spike_train_transient_packets(num_samples=num_samples, num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size,\
                                                                     prob_packets=FLAGS.packet_prob,firing_rates_mat=original_dataset['firing_rate_mat'], refr_per=FLAGS.ref_period,\
                                                                     shuffled_index=original_dataset['shuffled_index'], limits=[0,64], groups=[0,1,2,3], folder=FLAGS.sample_dir, save_packet=True).T

                samples = sim_pop_activity.spike_train_transient_packets(num_samples=num_samples, num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size,\
                                                                     prob_packets=0.2,firing_rates_mat=original_dataset['firing_rate_mat'], refr_per=FLAGS.ref_period,\
                                                                     shuffled_index=original_dataset['shuffled_index'], limits=[16,32], groups=[0], folder=FLAGS.sample_dir, save_packet=False).T
            elif FLAGS.number_of_modes == 2:
                samples = original_dataset['samples'].T
                num_samples = np.min([num_samples, samples.shape[0]])
                stim1_samples = np.zeros(
                    (int(num_samples / 2), FLAGS.num_neurons, FLAGS.num_bins))
                stim2_samples = np.zeros(
                    (int(num_samples / 2), FLAGS.num_neurons, FLAGS.num_bins))
                for ind_s in range(int(num_samples / 2)):
                    stim1_samples[ind_s,:,:] ,_ = sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                         refr_per=FLAGS.ref_period, noise=FLAGS.noise_in_packet, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=0)
                    stim2_samples[ind_s,:,:],_ = sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                         refr_per=FLAGS.ref_period, noise=FLAGS.noise_in_packet, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=1)

                packets = {
                    'packets_stim1': stim1_samples,
                    'packets_stim2': stim2_samples
                }
                np.savez(FLAGS.sample_dir + 'packets.npz', **packets)
                sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=0, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=0)

                sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=0, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=1)

        inputs = tf.placeholder(
            tf.float32,
            name='inputs_to_discriminator',
            shape=[None, FLAGS.num_neurons * FLAGS.num_bins])
        score = wgan.get_critics_output(inputs)

        step = FLAGS.step
        pattern_size = FLAGS.pattern_size
        times = step * np.arange(int(FLAGS.num_bins / step))
        times = np.delete(times,
                          np.nonzero(times > FLAGS.num_bins - pattern_size))
        importance_time_vector = np.zeros((num_samples, FLAGS.num_bins))
        importance_neuron_vector = np.zeros((num_samples, FLAGS.num_neurons))
        grad_maps = np.zeros((num_samples, FLAGS.num_neurons, FLAGS.num_bins))
        activity_map = np.zeros((FLAGS.num_neurons, FLAGS.num_bins))
        importance_time_vector_surr = np.zeros((num_samples, FLAGS.num_bins))
        importance_neuron_vector_surr = np.zeros(
            (num_samples, FLAGS.num_neurons))
        sample_diff = np.zeros((FLAGS.num_neurons, FLAGS.num_bins))
        samples = samples[0:num_samples, :]
        for i in range(num_samples):
            sample = samples[i, :]
            time0 = time.time()
            #get importance map for sample and compute the averages across time and space
            grad_maps[i, :, :], _, sample_diff_aux = patterns_relevance(
                sample, FLAGS.num_neurons, score, inputs, sess, pattern_size,
                times)
            time1 = time.time()
            importance_time_vector[i, :] = np.mean(grad_maps[i, :, :], axis=0)
            importance_neuron_vector[i, :] = np.mean(grad_maps[i, :, :],
                                                     axis=1)
            sample_diff += sample_diff_aux

            #compute surrogate data (not used in the ICLR paper)
            aux, _, _ = patterns_relevance(sample,
                                           FLAGS.num_neurons,
                                           score,
                                           inputs,
                                           sess,
                                           pattern_size,
                                           times,
                                           shuffle=True)
            time1 = time.time()
            importance_time_vector_surr[i, :] = np.mean(aux, axis=0)
            importance_neuron_vector_surr[i, :] = np.mean(aux, axis=1)

            sample = sample.reshape(FLAGS.num_neurons, -1)
            activity_map += sample
            print(str(i) + ' time ' + str(time1 - time0))

        stimulus_id = np.load(FLAGS.sample_dir + '/stim.npz')['stimulus']

        stimulus_id = stimulus_id[0:num_samples]
        predicted_packets = analysis.find_packets(grad_maps, samples,
                                                  FLAGS.num_neurons,
                                                  FLAGS.num_bins,
                                                  FLAGS.sample_dir,
                                                  num_samples)
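        # note: stim1_samples/stim2_samples are only defined in the two-mode
        # 'packets' branch above, so this comparison assumes that configuration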
        ground_truth_packets_stim1 = np.mean(analysis.find_packets(
            stim1_samples,
            stim1_samples - 1,
            FLAGS.num_neurons,
            FLAGS.num_bins,
            FLAGS.sample_dir,
            int(num_samples / 2),
            plot_fig=False),
                                             axis=0)
        ground_truth_packets_stim2 = np.mean(analysis.find_packets(
            stim2_samples,
            stim2_samples - 1,
            FLAGS.num_neurons,
            FLAGS.num_bins,
            FLAGS.sample_dir,
            int(num_samples / 2),
            plot_fig=False),
                                             axis=0)

        importance_vectors = {'time':importance_time_vector, 'neurons':importance_neuron_vector, 'grad_maps':grad_maps, 'samples':samples, 'activity_map':activity_map,\
                              'time_surr':importance_time_vector_surr, 'neurons_surr':importance_neuron_vector_surr, 'sample_diff':sample_diff, 'predicted_packets':predicted_packets,\
                              'ground_truth_packets_stim1':ground_truth_packets_stim1, 'ground_truth_packets_stim2':ground_truth_packets_stim2}
        np.savez(
            FLAGS.sample_dir + 'importance_vectors_' + str(step) + '_' +
            str(pattern_size) + '_' + str(num_samples) + '.npz',
            **importance_vectors)
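
# --- patterns_relevance is defined elsewhere in the repository; judging from
# how it is called above, it attributes the critic's score to small
# spatiotemporal patterns of the sample. A self-contained toy of that
# perturbation idea with a dummy scoring function; everything below is an
# assumption about the approach, not the repository's code. ---
import numpy as np

def toy_score(x):
    # stand-in for the trained critic: total activity of the sample
    return x.sum()

def toy_pattern_relevance(sample, pattern_size, step):
    # slide a window along the time axis, zero it out, and attribute the
    # score drop to the bins the window covered
    num_neurons, num_bins = sample.shape
    base = toy_score(sample)
    relevance = np.zeros(num_bins)
    for t in range(0, num_bins - pattern_size + 1, step):
        perturbed = sample.copy()
        perturbed[:, t:t + pattern_size] = 0
        relevance[t:t + pattern_size] += base - toy_score(perturbed)
    return relevance

toy_sample = (np.random.random((4, 16)) > 0.8).astype(float)
print(toy_pattern_relevance(toy_sample, pattern_size=4, step=2))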
Example #4
def generate_spike_trains(config, recovery_dir):
    if config.dataset == 'uniform':
        if recovery_dir != "":
            aux = np.load(recovery_dir + '/stats_real.npz')
            real_samples = aux['samples']
            firing_rates_mat = aux['firing_rate_mat']
            correlations_mat = aux['correlation_mat']
            activity_peaks = aux['activity_peaks']
            shuffled_index = aux['shuffled_index']
        else:
            #we shuffle neurons to test if the network still learns the packets
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(int(config.num_neurons / config.group_size), )
                - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(int(config.num_neurons / config.group_size), )
                - 0.5) * config.correlation / 2
            #peaks of activity
            #sequence response
            aux = np.arange(int(config.num_neurons / config.group_size))
            activity_peaks = [
                [x] * config.group_size for x in aux
            ]  #np.random.randint(0,high=config.num_bins,size=(1,config.num_neurons)).reshape(config.num_neurons,1)
            activity_peaks = np.asarray(activity_peaks)
            activity_peaks = activity_peaks.flatten()
            activity_peaks = activity_peaks * config.group_size * config.num_bins / config.num_neurons
            activity_peaks = activity_peaks.reshape(config.num_neurons, 1)
            #peak of activity equal for all neurons
            #activity_peaks = np.zeros((config.num_neurons,1))+config.num_bins/4
            real_samples = sim_pop_activity.get_samples(num_samples=config.num_samples, num_bins=config.num_bins,\
                                num_neurons=config.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, shuffled_index=shuffled_index,\
                                refr_per=config.ref_period,firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks, folder=config.sample_dir)

        #save original statistics
        analysis.get_stats(X=real_samples, num_neurons=config.num_neurons, num_bins=config.num_bins, folder=config.sample_dir, shuffled_index=shuffled_index,\
                           name='real',firing_rate_mat=firing_rates_mat, correlation_mat=correlations_mat, activity_peaks=activity_peaks)

        #get dev samples
        dev_samples = sim_pop_activity.get_samples(num_samples=int(config.num_samples/4), num_bins=config.num_bins,\
                       num_neurons=config.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, shuffled_index=shuffled_index,\
                       refr_per=config.ref_period,firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            aux = np.load(recovery_dir + '/stats_real.npz')
            real_samples = aux['samples']
            firing_rates_mat = aux['firing_rate_mat']
            shuffled_index = aux['shuffled_index']
        else:
            #we shuffle neurons to test if the network still learns the packets
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2
            real_samples = sim_pop_activity.get_samples(num_samples=config.num_samples, num_bins=config.num_bins, refr_per=config.ref_period,\
                                 num_neurons=config.num_neurons, group_size=config.group_size, firing_rates_mat=firing_rates_mat, packets_on=True,\
                                 prob_packets=config.packet_prob, shuffled_index=shuffled_index, folder=config.sample_dir)
        #save original statistics
        analysis.get_stats(X=real_samples, num_neurons=config.num_neurons, num_bins=config.num_bins, folder=config.sample_dir, name='real',\
                       firing_rate_mat=firing_rates_mat, shuffled_index=shuffled_index)
        #get dev samples
        dev_samples = sim_pop_activity.get_samples(num_samples=int(config.num_samples/4), num_bins=config.num_bins, refr_per=config.ref_period,\
                       num_neurons=config.num_neurons, group_size=config.group_size, firing_rates_mat=firing_rates_mat, packets_on=True,\
                       prob_packets=config.packet_prob,shuffled_index=shuffled_index)

    elif config.dataset == 'retina':
        real_samples = retinal_data.get_samples(num_bins=config.num_bins,
                                                num_neurons=config.num_neurons,
                                                instance=config.data_instance)
        #save original statistics
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           instance=config.data_instance)
        dev_samples = []
    return real_samples, dev_samples
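
# --- The per-group firing rates and correlations above are jittered uniformly
# within +/-50% of the values in config; a standalone check of that formula: ---
import numpy as np

firing_rate = 0.2
num_groups = 8
rates = firing_rate + 2 * (np.random.random(num_groups) - 0.5) * firing_rate / 2
# 2 * (u - 0.5) is uniform on [-1, 1], so rates lie in [0.5, 1.5] * firing_rate
assert np.all(rates >= 0.5 * firing_rate) and np.all(rates <= 1.5 * firing_rate)
print(rates)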
Example #5
def generate_spike_trains(config, recovery_dir):
    '''
    Return the training and dev sets corresponding to the parameters provided
    in config.
    '''
    if config.dataset == 'uniform':
        if recovery_dir != "":
            aux = np.load(recovery_dir + '/stats_real.npz')
            real_samples = aux['samples']
            firing_rates_mat = aux['firing_rate_mat']
            correlations_mat = aux['correlation_mat']
            shuffled_index = aux['shuffled_index']
        else:
            #shuffle neurons
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(int(config.num_neurons / config.group_size), )
                - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(int(config.num_neurons / config.group_size), )
                - 0.5) * config.correlation / 2
            #unlike the 'uniform' branch in Example #4, this variant does not
            #build activity peaks
            real_samples = sim_pop_activity.get_samples(num_samples=config.num_samples, num_bins=config.num_bins,\
                                num_neurons=config.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, shuffled_index=shuffled_index,\
                                refr_per=config.ref_period,firing_rates_mat=firing_rates_mat, folder=config.sample_dir)

        #save original statistics
        analysis.get_stats(X=real_samples, num_neurons=config.num_neurons, num_bins=config.num_bins, folder=config.sample_dir, shuffled_index=shuffled_index,\
                           name='real',firing_rate_mat=firing_rates_mat, correlation_mat=correlations_mat)

        #get dev samples
        dev_samples = sim_pop_activity.get_samples(num_samples=int(config.num_samples/4), num_bins=config.num_bins,\
                       num_neurons=config.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, shuffled_index=shuffled_index,\
                       refr_per=config.ref_period,firing_rates_mat=firing_rates_mat)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            aux = np.load(recovery_dir + '/stats_real.npz')
            real_samples = aux['samples']
            firing_rates_mat = aux['firing_rate_mat']
            shuffled_index = aux['shuffled_index']
        else:
            #shuffle the neurons
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2
            real_samples = sim_pop_activity.get_samples(num_samples=config.num_samples, num_bins=config.num_bins, refr_per=config.ref_period,\
                                 num_neurons=config.num_neurons, group_size=config.group_size, firing_rates_mat=firing_rates_mat, packets_on=True,\
                                 prob_packets=config.packet_prob, shuffled_index=shuffled_index, folder=config.sample_dir, noise_in_packet=config.noise_in_packet, number_of_modes=config.number_of_modes)
        #save original statistics
        analysis.get_stats(X=real_samples, num_neurons=config.num_neurons, num_bins=config.num_bins, folder=config.sample_dir, name='real',\
                       firing_rate_mat=firing_rates_mat, shuffled_index=shuffled_index)
        #get dev samples
        dev_samples = sim_pop_activity.get_samples(num_samples=int(config.num_samples/4), num_bins=config.num_bins, refr_per=config.ref_period,\
                       num_neurons=config.num_neurons, group_size=config.group_size, firing_rates_mat=firing_rates_mat, packets_on=True,\
                       prob_packets=config.packet_prob, shuffled_index=shuffled_index, noise_in_packet=config.noise_in_packet, number_of_modes=config.number_of_modes)

    elif config.dataset == 'retina':
        dirpath = os.getcwd()
        real_samples = retinal_data.get_samples(num_bins=config.num_bins,
                                                num_neurons=config.num_neurons,
                                                instance=config.data_instance,
                                                folder=dirpath +
                                                '/data/retinal data/')
        #save original statistics
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           instance=config.data_instance)
        dev_samples = []
    return real_samples, dev_samples
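
# --- A hedged sketch of how this loader might be invoked. The config object
# only needs attribute access, so a SimpleNamespace stands in for the TF FLAGS
# object the original code passes; the field values are illustrative. ---
from types import SimpleNamespace

config = SimpleNamespace(dataset='packets', num_samples=8192, num_bins=32,
                         num_neurons=32, group_size=4, firing_rate=0.1,
                         packet_prob=0.05, ref_period=2, noise_in_packet=0.0,
                         number_of_modes=1, sample_dir='samples/',
                         data_instance='1')
# training_set, dev_set = generate_spike_trains(config, recovery_dir="")
# (left commented out: it requires sim_pop_activity, analysis and
# retinal_data from the original repository)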
Example #6
    def train(self, config):
        """Train DCGAN"""
        #define optimizer
        self.g_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.gen_cost,
                var_list=params_with_name('Generator'),
                colocate_gradients_with_ops=True)
        self.d_optim = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2).minimize(
                self.disc_cost,
                var_list=params_with_name('Discriminator.'),
                colocate_gradients_with_ops=True)

        #initialize variables
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            #fall back for TensorFlow versions older than 0.12
            tf.initialize_all_variables().run()

        #try to load trained parameters
        self.load()

        #get real samples
        if config.dataset == 'uniform':
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(int(self.num_neurons / config.group_size), ) -
                0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(int(self.num_neurons / config.group_size), ) -
                0.5) * config.correlation / 2
            aux = np.arange(int(self.num_neurons / config.group_size))
            activity_peaks = [
                [x] * config.group_size for x in aux
            ]  #np.random.randint(0,high=self.num_bins,size=(1,self.num_neurons)).reshape(self.num_neurons,1)
            activity_peaks = np.asarray(activity_peaks)
            activity_peaks = activity_peaks.flatten()
            activity_peaks = activity_peaks * config.group_size * self.num_bins / self.num_neurons
            activity_peaks = activity_peaks.reshape(self.num_neurons, 1)
            #activity_peaks = np.zeros((self.num_neurons,1))+self.num_bins/4
            self.real_samples = sim_pop_activity.get_samples(num_samples=config.num_samples, num_bins=self.num_bins,\
            num_neurons=self.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, refr_per=config.ref_period,firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks)
            #get dev samples
            dev_samples = sim_pop_activity.get_samples(num_samples=int(config.num_samples/4), num_bins=self.num_bins,\
            num_neurons=self.num_neurons, correlations_mat=correlations_mat, group_size=config.group_size, refr_per=config.ref_period,firing_rates_mat=firing_rates_mat, activity_peaks=activity_peaks)
            #save original statistics
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               firing_rate_mat=firing_rates_mat,
                               correlation_mat=correlations_mat,
                               activity_peaks=activity_peaks)
        elif config.dataset == 'retina':
            self.real_samples = retinal_data.get_samples(
                num_bins=self.num_bins,
                num_neurons=self.num_neurons,
                instance=config.data_instance)
            #save original statistics
            analysis.get_stats(X=self.real_samples,
                               num_neurons=self.num_neurons,
                               num_bins=self.num_bins,
                               folder=self.sample_dir,
                               name='real',
                               instance=config.data_instance)

        #count number of variables
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('number of variables: ' + str(total_parameters))
        #start training
        counter_batch = 0
        epoch = 0
        #fitting errors
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        for iteration in range(config.num_iter):
            start_time = time.time()
            # Train generator (only after the critic has been trained at least once)
            if iteration > 0:
                _ = self.sess.run(self.g_optim)

            # Train critic
            disc_iters = config.critic_iters
            for i in range(disc_iters):
                #get batch and trained critic
                _data = self.real_samples[:, counter_batch *
                                          config.batch_size:(counter_batch +
                                                             1) *
                                          config.batch_size].T
                _disc_cost, _ = self.sess.run([self.disc_cost, self.d_optim],
                                              feed_dict={self.inputs: _data})
                #if we have reached the end of the real samples set, we start over and increment the number of epochs
                if counter_batch == int(
                        self.real_samples.shape[1] / self.batch_size) - 1:
                    counter_batch = 0
                    epoch += 1
                else:
                    counter_batch += 1
            aux = time.time() - start_time
            #plot the critic's loss and the iteration time
            plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
            plot.plot(self.sample_dir, 'time', aux)

            if iteration == 500 or iteration % 20000 == 19999 or iteration > config.num_iter - 10:
                print('epoch ' + str(epoch))
                if config.dataset == 'uniform':
                    #this is to evaluate whether the discriminator has overfit
                    dev_disc_costs = []
                    for ind_dev in range(
                            int(dev_samples.shape[1] / self.batch_size)):
                        images = dev_samples[:, ind_dev *
                                             config.batch_size:(ind_dev + 1) *
                                             config.batch_size].T
                        _dev_disc_cost = self.sess.run(
                            self.disc_cost, feed_dict={self.inputs: images})
                        dev_disc_costs.append(_dev_disc_cost)
                    #plot the dev loss
                    plot.plot(self.sample_dir, 'dev disc cost',
                              -np.mean(dev_disc_costs))

                #save the network parameters
                self.save(iteration)

                #get simulated samples, calculate their statistics and compare them with the original ones
                fake_samples = self.get_samples(num_samples=2**13)
                fake_samples = fake_samples.eval(session=self.sess)
                fake_samples = self.binarize(samples=fake_samples)
                acf_error, mean_error, corr_error, time_course_error,_ = analysis.get_stats(X=fake_samples.T, num_neurons=config.num_neurons,\
                    num_bins=config.num_bins, folder=config.sample_dir, name='fake'+str(iteration), critic_cost=-_disc_cost,instance=config.data_instance)
                #plot the fitting errors
                sbplt[0][0].plot(iteration, mean_error, '+b')
                sbplt[0][0].set_title('spk-count mean error')
                sbplt[0][0].set_xlabel('iterations')
                sbplt[0][0].set_ylabel('L1 error')
                sbplt[0][0].set_xlim([
                    0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4
                ])
                sbplt[0][1].plot(iteration, time_course_error, '+b')
                sbplt[0][1].set_title('time course error')
                sbplt[0][1].set_xlabel('iterations')
                sbplt[0][1].set_ylabel('L1 error')
                sbplt[0][1].set_xlim([
                    0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4
                ])
                sbplt[1][0].plot(iteration, acf_error, '+b')
                sbplt[1][0].set_title('AC error')
                sbplt[1][0].set_xlabel('iterations')
                sbplt[1][0].set_ylabel('L1 error')
                sbplt[1][0].set_xlim([
                    0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4
                ])
                sbplt[1][1].plot(iteration, corr_error, '+b')
                sbplt[1][1].set_title('corr error')
                sbplt[1][1].set_xlabel('iterations')
                sbplt[1][1].set_ylabel('L1 error')
                sbplt[1][1].set_xlim([
                    0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4
                ])
                f.savefig(self.sample_dir + 'fitting_errors.svg',
                          dpi=600,
                          bbox_inches='tight')
                plt.close(f)
                plot.flush(self.sample_dir)

            plot.tick()
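
# --- The loop above follows the usual WGAN schedule: one generator step per
# iteration (skipped on the very first, before the critic has been updated),
# then critic_iters critic steps on consecutive minibatches, advancing the
# epoch counter when the dataset wraps around. A framework-free sketch of just
# that bookkeeping; all numbers are illustrative. ---
num_iter, critic_iters = 100, 5
batch_size, num_samples = 8, 64
counter_batch, epoch = 0, 0
for iteration in range(num_iter):
    if iteration > 0:
        pass  # generator update would run here
    for _ in range(critic_iters):
        # critic update on samples[counter_batch * batch_size:
        #                          (counter_batch + 1) * batch_size]
        if counter_batch == num_samples // batch_size - 1:
            counter_batch = 0
            epoch += 1
        else:
            counter_batch += 1
print('epochs completed:', epoch)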