def main(_):
    """Set up the output folders from FLAGS, then train the WGAN or evaluate it.

    The positional argument is ignored. The sample folder name is built from
    the dataset/architecture flags; the checkpoint folder lives inside it.
    With ``FLAGS.is_train`` set the model is trained, otherwise a stored model
    is loaded and the analyses below are run on generated samples.
    """
    # print parameters
    pp.pprint(tf.app.flags.FLAGS.flag_values_dict())

    # folders
    # The folder name is 'samples <arch>/dataset_<dataset>' followed by one
    # '_<label>_<value>' entry per flag and a final '_iteration_<iteration>/'.
    # Two flags appear under a label that differs from the flag name.
    path_labels = {'lambd': 'lambda', 'kernel_width': 'kernel'}

    dataset_flags = None
    if FLAGS.dataset == 'uniform':
        dataset_flags = ['num_samples', 'num_neurons', 'num_bins',
                         'ref_period', 'firing_rate', 'correlation',
                         'group_size']
    elif FLAGS.dataset == 'packets' and FLAGS.number_of_modes == 1:
        dataset_flags = ['num_samples', 'num_neurons', 'num_bins',
                         'packet_prob', 'firing_rate', 'group_size']
    elif FLAGS.dataset == 'packets' and FLAGS.number_of_modes == 2:
        dataset_flags = ['num_samples', 'num_neurons', 'num_bins',
                         'packet_prob', 'firing_rate', 'group_size',
                         'noise_in_packet', 'number_of_modes']
    elif FLAGS.dataset == 'retina':
        dataset_flags = ['num_samples', 'num_neurons', 'num_bins']

    if dataset_flags is not None:
        if FLAGS.architecture == 'fc':
            architecture_flags = ['num_units']
        elif FLAGS.architecture == 'conv':
            architecture_flags = ['num_layers', 'num_features', 'kernel_width']
        else:
            architecture_flags = None

        # FLAGS.sample_dir is left untouched for any other combination
        if architecture_flags is not None:
            if FLAGS.dataset == 'retina' and FLAGS.architecture == 'fc':
                # the retina/fc folder name has no num_samples entry
                dataset_flags = dataset_flags[1:]
            flag_names = dataset_flags + ['critic_iters', 'lambd'] + architecture_flags
            pieces = ['samples ' + FLAGS.architecture + '/', 'dataset_', FLAGS.dataset]
            for flag_name in flag_names:
                pieces.append('_' + path_labels.get(flag_name, flag_name) + '_')
                pieces.append(str(getattr(FLAGS, flag_name)))
            pieces.append('_iteration_' + FLAGS.iteration + '/')
            FLAGS.sample_dir = ''.join(pieces)

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'
    for needed_dir in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(needed_dir):
            os.makedirs(needed_dir)

    # reuse the stored training set when one is already in the sample folder
    if FLAGS.recovery_dir == "" and os.path.exists(FLAGS.sample_dir + '/stats_real.npz'):
        FLAGS.recovery_dir = FLAGS.sample_dir

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        wgan = WGAN_conv(sess,
                         architecture=FLAGS.architecture,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers,
                         num_units=FLAGS.num_units,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)

        if FLAGS.is_train:
            training_samples, dev_samples = data_provider.generate_spike_trains(
                FLAGS, FLAGS.recovery_dir)
            wgan.training_samples = training_samples
            wgan.dev_samples = dev_samples
            print('data loaded')
            wgan.train(FLAGS)
            return

        # ---- test mode ----
        if not wgan.load(FLAGS.training_stage):
            raise Exception("[!] Train a model first, then run test mode")

        # LOAD TRAINING DATASET (and its statistics)
        original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')

        # PLOT FILTERS
        if FLAGS.dataset == 'retina':
            index = np.arange(FLAGS.num_neurons)
        else:
            index = np.argsort(original_dataset['shuffled_index'])
        if FLAGS.architecture == 'conv':
            print('get filters -----------------------------------')
            filters = wgan.get_filters(num_samples=64)
            visualize_filters_and_units.plot_filters(filters, sess, FLAGS, index)

        # GET GENERATED SAMPLES AND COMPUTE THEIR STATISTICS
        print('compute stats -----------------------------------')
        if 'samples' in original_dataset:
            real_samples = original_dataset['samples']
        else:
            real_samples = retinal_data.get_samples(
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
        sim_pop_activity.plot_samples(real_samples, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'real')

        fake_samples = wgan.get_samples(num_samples=FLAGS.num_samples)
        fake_samples = fake_samples.eval(session=sess)
        sim_pop_activity.plot_samples(fake_samples.T, FLAGS.num_neurons,
                                      FLAGS.sample_dir, 'fake')

        def stats_for(samples, label):
            # statistics of one sample set, compared against the real data
            _, _, _, _, _ = analysis.get_stats(X=samples,
                                               num_neurons=FLAGS.num_neurons,
                                               num_bins=FLAGS.num_bins,
                                               folder=FLAGS.sample_dir,
                                               name=label,
                                               instance=FLAGS.data_instance)

        stats_for(fake_samples.T, 'fake')

        # EVALUATE HIGH ORDER FEATURES (only when dimensionality is low)
        if FLAGS.dataset == 'uniform' and FLAGS.num_bins * FLAGS.num_neurons < 40:
            print(
                'compute high order statistics-----------------------------------'
            )
            num_trials = int(2**8)
            num_samples_per_trial = 2**13
            fake_samples_mat = np.zeros((num_trials * num_samples_per_trial,
                                         FLAGS.num_neurons * FLAGS.num_bins))
            for ind_tr in range(num_trials):
                first_row = ind_tr * num_samples_per_trial
                last_row = first_row + num_samples_per_trial
                fake_samples = wgan.sess.run([wgan.ex_samples])[0]
                fake_samples_mat[first_row:last_row, :] = fake_samples

            analysis.evaluate_approx_distribution(
                X=fake_samples_mat.T,
                folder=FLAGS.sample_dir,
                num_samples_theoretical_distr=2**21,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                group_size=FLAGS.group_size,
                refr_per=FLAGS.ref_period)

        # COMPARISON WITH K-PAIRWISE AND DG MODELS (only for retinal data)
        if FLAGS.dataset == 'retina':
            print(
                'nearest sample analysis -----------------------------------')
            num_samples = 100  # this is for the 'nearest sample' analysis (Fig. S5)

            def nearest_to_real(candidates, label):
                # 'nearest sample' analysis of one sample set against the real data
                analysis.nearest_sample(X_real=real_samples,
                                        fake_samples=candidates,
                                        num_neurons=FLAGS.num_neurons,
                                        num_bins=FLAGS.num_bins,
                                        folder=FLAGS.sample_dir,
                                        name=label,
                                        num_samples=num_samples)

            print('real_samples')
            nearest_to_real(real_samples, 'real')

            print('fake_samples')
            nearest_to_real(fake_samples.T, 'spikeGAN')

            k_pairwise_samples = retinal_data.load_samples_from_k_pairwise_model(
                num_samples=FLAGS.num_samples,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
            print('k_pairwise_samples')
            stats_for(k_pairwise_samples, 'k_pairwise')
            nearest_to_real(k_pairwise_samples, 'k_pairwise')

            DDG_samples = retinal_data.load_samples_from_DDG_model(
                num_samples=FLAGS.num_samples,
                num_bins=FLAGS.num_bins,
                num_neurons=FLAGS.num_neurons,
                instance=FLAGS.data_instance,
                folder=os.getcwd() + '/data/retinal data/')
            print('DDG_samples')
            stats_for(DDG_samples, 'DDG')
            nearest_to_real(DDG_samples, 'DDG')
def get_stats(X,
              num_neurons,
              num_bins,
              folder,
              name,
              firing_rate_mat=[],
              correlation_mat=[],
              activity_peaks=[],
              critic_cost=np.nan,
              instance='1',
              shuffled_index=[]):
    '''
    Compute, plot and save the statistics of a set of spike trains.

    X is thresholded against uniform random numbers and the result is passed
    to get_stats_aux, which provides the covariance matrix, k-probabilities,
    mean spike counts, autocorrelogram, average firing time course and lag
    covariance matrix. The figures and the .npz files are written to `folder`.

    If name != 'real' the statistics are compared with the ones stored in
    folder + '/stats_real.npz'. When that file lacks some of them they are
    recomputed from the real samples and the file is saved again.

    Parameters
    ----------
    X : array of samples (binarized internally).
    num_neurons, num_bins : passed on to get_stats_aux; num_neurons is also
        used to mask the diagonal of the covariance matrix.
    folder : output folder; must contain stats_real.npz when name != 'real'.
    name : tag used in the output file names; 'real' marks the reference data.
    firing_rate_mat, correlation_mat, activity_peaks, shuffled_index :
        ground-truth parameters, stored next to the statistics when
        name == 'real' and firing_rate_mat is not empty. The list defaults
        are only read, never modified.
    critic_cost : stored in the statistics file of non-ground-truth data.
    instance : retinal data instance, used if the real samples must be reloaded.

    Returns
    -------
    (acf_error, mean_error, corr_error, time_course_error, k_probs_error)
    when name != 'real'; nothing (None) when name == 'real', since there is
    nothing to compare against.
    '''
    X_binnarized = (X > np.random.random(X.shape)).astype(float)

    resave_real_data = False
    if name != 'real':
        original_data = np.load(folder + '/stats_real.npz')
        if any(k not in original_data
               for k in ("mean", "acf", "cov_mat", "k_probs", "lag_cov_mat",
                         "firing_average_time_course")):
            # the stored file is incomplete: recompute the real statistics
            if 'samples' not in original_data:
                samples = retinal_data.get_samples(
                    num_bins=num_bins,
                    num_neurons=num_neurons,
                    instance=instance,
                    folder=os.getcwd() + '/data/retinal data/')
            else:
                samples = original_data['samples']
            cov_mat_real, k_probs_real, mean_spike_count_real, autocorrelogram_mat_real, firing_average_time_course_real, lag_cov_mat_real =\
                get_stats_aux(samples, num_neurons, num_bins)
            # the recomputed values must agree with the ones already stored
            assert np.all(autocorrelogram_mat_real == original_data['acf'])
            assert np.all(mean_spike_count_real == original_data['mean'])
            resave_real_data = True
        else:
            mean_spike_count_real, autocorrelogram_mat_real, firing_average_time_course_real, cov_mat_real, k_probs_real, lag_cov_mat_real = \
                [original_data["mean"], original_data["acf"], original_data["firing_average_time_course"],
                 original_data["cov_mat"], original_data["k_probs"], original_data["lag_cov_mat"]]

    cov_mat, k_probs, mean_spike_count, autocorrelogram_mat, firing_average_time_course, lag_cov_mat = get_stats_aux(
        X_binnarized, num_neurons, num_bins)
    # variances (diagonal) and covariances (off-diagonal) are shown separately
    variances = np.diag(cov_mat)
    only_cov_mat = cov_mat.copy()
    only_cov_mat[np.diag_indices(num_neurons)] = np.nan

    #PLOT
    index = np.linspace(-10, 10, 2 * 10 + 1)
    #figure for all training error across epochs (supp. figure 2)
    f, sbplt = plt.subplots(2, 3, figsize=(8, 8), dpi=250)
    matplotlib.rcParams.update({'font.size': 8})
    # left, bottom, right, top, wspace and hspace are module-level layout values
    plt.subplots_adjust(left=left,
                        bottom=bottom,
                        right=right,
                        top=top,
                        wspace=wspace,
                        hspace=hspace)

    #plot autocorrelogram(s)
    sbplt[1][1].plot(index, autocorrelogram_mat, 'r')
    if name != 'real':
        sbplt[1][1].plot(index, autocorrelogram_mat_real, 'b')
        acf_error = np.sum(
            np.abs(autocorrelogram_mat - autocorrelogram_mat_real))
    sbplt[1][1].set_title('Autocorrelogram')
    sbplt[1][1].set_xlabel('time (ms)')
    sbplt[1][1].set_ylabel('number of spikes')

    #plot mean firing rates
    if name != 'real':
        sbplt[0][0].plot([0, np.max(mean_spike_count_real)],
                         [0, np.max(mean_spike_count_real)], 'k')
        sbplt[0][0].plot(mean_spike_count_real, mean_spike_count, '.g')
        mean_error = np.sum(np.abs(mean_spike_count - mean_spike_count_real))
        sbplt[0][0].set_xlabel('mean firing rate expt')
        sbplt[0][0].set_ylabel('mean firing rate model')
    else:
        sbplt[0][0].plot(mean_spike_count, 'b')
        sbplt[0][0].set_xlabel('neuron')
        sbplt[0][0].set_ylabel('firing probability')
    sbplt[0][0].set_title('mean firing rates')

    #plot covariances
    if name != 'real':
        variances_real = np.diag(cov_mat_real)
        only_cov_mat_real = cov_mat_real.copy()
        only_cov_mat_real[np.diag_indices(num_neurons)] = np.nan
        sbplt[0][1].plot([np.nanmin(only_cov_mat_real.flatten()), np.nanmax(only_cov_mat_real.flatten())],
                         [np.nanmin(only_cov_mat_real.flatten()), np.nanmax(only_cov_mat_real.flatten())], 'k')
        sbplt[0][1].plot(only_cov_mat_real.flatten(), only_cov_mat.flatten(),
                         '.g')
        sbplt[0][1].set_title('pairwise covariances')
        sbplt[0][1].set_xlabel('covariances expt')
        sbplt[0][1].set_ylabel('covariances model')
        corr_error = np.nansum(
            np.abs(only_cov_mat - only_cov_mat_real).flatten())
    else:
        map_aux = sbplt[0][1].imshow(only_cov_mat, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[0][1])
        sbplt[0][1].set_title('covariance mat')
        sbplt[0][1].set_xlabel('neuron')
        sbplt[0][1].set_ylabel('neuron')

    #plot k-statistics
    if name != 'real':
        sbplt[1][0].plot([0, np.max(k_probs_real)], [0, np.max(k_probs_real)],
                         'k')
        sbplt[1][0].plot(k_probs_real, k_probs, '.g')
        k_probs_error = np.sum(np.abs(k_probs - k_probs_real))
        sbplt[1][0].set_xlabel('k-probs expt')
        sbplt[1][0].set_ylabel('k-probs model')
    else:
        sbplt[1][0].plot(k_probs)
        sbplt[1][0].set_xlabel('K')
        sbplt[1][0].set_ylabel('probability')
    sbplt[1][0].set_title('k statistics')

    #plot average time course
    #firing_average_time_course[firing_average_time_course>0.048] = 0.048
    map_aux = sbplt[0][2].imshow(firing_average_time_course,
                                 interpolation='nearest')
    f.colorbar(map_aux, ax=sbplt[0][2])
    sbplt[0][2].set_title('sim firing time course')
    sbplt[0][2].set_xlabel('time (ms)')
    sbplt[0][2].set_ylabel('neuron')
    if name != 'real':
        map_aux = sbplt[1][2].imshow(firing_average_time_course_real,
                                     interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[1][2])
        sbplt[1][2].set_title('real firing time course')
        sbplt[1][2].set_xlabel('time (ms)')
        sbplt[1][2].set_ylabel('neuron')
        time_course_error = np.sum(
            np.abs(firing_average_time_course -
                   firing_average_time_course_real).flatten())

    f.savefig(folder + 'stats_' + name + '_II.svg',
              dpi=600,
              bbox_inches='tight')
    plt.close(f)

    if name != 'real':
        #PLOT LAG COVARIANCES
        #figure for all training error across epochs (supp. figure 2)
        f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
        matplotlib.rcParams.update({'font.size': 8})
        plt.subplots_adjust(left=left,
                            bottom=bottom,
                            right=right,
                            top=top,
                            wspace=wspace,
                            hspace=hspace)
        map_aux = sbplt[0][0].imshow(lag_cov_mat_real, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[0][0])
        sbplt[0][0].set_title('lag covariance mat expt')
        sbplt[0][0].set_xlabel('neuron')
        sbplt[0][0].set_ylabel('neuron shifted')
        map_aux = sbplt[1][0].imshow(lag_cov_mat, interpolation='nearest')
        f.colorbar(map_aux, ax=sbplt[1][0])
        sbplt[1][0].set_title('lag covariance mat model')
        sbplt[1][0].set_xlabel('neuron')
        sbplt[1][0].set_ylabel('neuron shifted')
        lag_corr_error = np.nansum(
            np.abs(lag_cov_mat - lag_cov_mat_real).flatten())
        sbplt[0][1].plot([np.min(lag_cov_mat_real.flatten()), np.max(lag_cov_mat_real.flatten())],
                         [np.min(lag_cov_mat_real.flatten()), np.max(lag_cov_mat_real.flatten())], 'k')
        sbplt[0][1].plot(lag_cov_mat_real, lag_cov_mat, '.g')
        sbplt[0][1].set_xlabel('lag cov real')
        sbplt[0][1].set_ylabel('lag cov model')
        # identity line: both axes span the range of the real variances, as
        # in the other panels (the x end used the model variances before,
        # which tilted the reference line)
        sbplt[1][1].plot([np.min(variances_real.flatten()), np.max(variances_real.flatten())],
                         [np.min(variances_real.flatten()), np.max(variances_real.flatten())], 'k')
        sbplt[1][1].plot(variances_real.flatten(), variances.flatten(), '.g')
        sbplt[1][1].set_title('variances')
        sbplt[1][1].set_xlabel('variances expt')
        sbplt[1][1].set_ylabel('variances model')
        variance_error = np.nansum(
            np.abs(variances_real - variances).flatten())
        f.savefig(folder + 'lag_covs_' + name + '_II.svg',
                  dpi=600,
                  bbox_inches='tight')
        plt.close(f)

    if name == 'real' and len(firing_rate_mat) > 0:
        #ground truth data but not real (retinal) data
        data = {'mean': mean_spike_count, 'acf': autocorrelogram_mat, 'cov_mat': cov_mat, 'samples': X,
                'k_probs': k_probs, 'lag_cov_mat': lag_cov_mat,
                'firing_rate_mat': firing_rate_mat, 'correlation_mat': correlation_mat,
                'activity_peaks': activity_peaks, 'shuffled_index': shuffled_index,
                'firing_average_time_course': firing_average_time_course}
        np.savez(folder + '/stats_' + name + '.npz', **data)
    else:
        data = {'mean': mean_spike_count, 'acf': autocorrelogram_mat, 'cov_mat': cov_mat, 'k_probs': k_probs,
                'firing_average_time_course': firing_average_time_course,
                'critic_cost': critic_cost, 'lag_cov_mat': lag_cov_mat}
        np.savez(folder + '/stats_' + name + '.npz', **data)

    # resave_real_data can only be set when name != 'real' (see above)
    if resave_real_data:
        if 'firing_rate_mat' in original_data:
            data = {'mean': mean_spike_count_real, 'acf': autocorrelogram_mat_real, 'cov_mat': cov_mat_real,
                    'samples': samples, 'k_probs': k_probs_real, 'lag_cov_mat': lag_cov_mat_real,
                    'firing_rate_mat': original_data['firing_rate_mat'],
                    'correlation_mat': original_data['correlation_mat'],
                    'activity_peaks': original_data['activity_peaks'],
                    'shuffled_index': original_data['shuffled_index'],
                    'firing_average_time_course': firing_average_time_course_real}
        else:
            data = {'mean': mean_spike_count_real, 'acf': autocorrelogram_mat_real, 'cov_mat': cov_mat_real,
                    'samples': samples, 'k_probs': k_probs_real, 'lag_cov_mat': lag_cov_mat_real,
                    'firing_average_time_course': firing_average_time_course_real}
        np.savez(folder + '/stats_real.npz', **data)

    if name != 'real':
        errors_mat = {'acf_error': acf_error, 'mean_error': mean_error, 'corr_error': corr_error,
                      'time_course_error': time_course_error, 'k_probs_error': k_probs_error,
                      'variance_error': variance_error, 'lag_corr_error': lag_corr_error}
        np.savez(folder + '/errors_' + name + '.npz', **errors_mat)
        samples_fake = {'samples': X}
        # only the first four characters of the name go into this file name
        np.savez(folder + '/samples_' + name[0:4] + '.npz', **samples_fake)
        return acf_error, mean_error, corr_error, time_course_error, k_probs_error
def main(_):
    """Compute importance maps of the trained critic for a set of samples.

    The positional argument is ignored. The sample/checkpoint folders are
    derived from FLAGS, a stored model is loaded, samples are obtained
    (retinal data, simulated packets or the stored training set) and
    patterns_relevance is evaluated on each of them. The results are saved
    to an 'importance_vectors_<step>_<pattern_size>_<num_samples>.npz' file
    in FLAGS.sample_dir.

    NOTE(review): the nesting of the sample-preparation section was
    reconstructed from the data flow; see the notes below and confirm.
    """
    #print parameters
    pp.pprint(flags.FLAGS.__flags)
    #folders
    if FLAGS.dataset == 'uniform':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size) + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_ref_period_' + str(FLAGS.ref_period) + '_firing_rate_' + str(FLAGS.firing_rate) + '_correlation_' + str(FLAGS.correlation) +\
            '_group_size_' + str(FLAGS.group_size) + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers) + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'packets':
        if FLAGS.architecture == 'fc':
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes) + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) + '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins) + '_packet_prob_' + str(FLAGS.packet_prob)\
            + '_firing_rate_' + str(FLAGS.firing_rate) + '_group_size_' + str(FLAGS.group_size) + '_noise_in_packet_' + str(FLAGS.noise_in_packet) + '_number_of_modes_' + str(FLAGS.number_of_modes) + '_critic_iters_' +\
            str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers) + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'
    elif FLAGS.dataset == 'retina':
        if FLAGS.architecture == 'fc':
            # unlike the conv case, this folder name has no num_samples entry
            FLAGS.sample_dir = 'samples fc/' + 'dataset_' + FLAGS.dataset +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_units_' + str(FLAGS.num_units) +\
            '_iteration_' + FLAGS.iteration + '/'
        elif FLAGS.architecture == 'conv':
            FLAGS.sample_dir = 'samples conv/' + 'dataset_' + FLAGS.dataset + '_num_samples_' + str(FLAGS.num_samples) +\
            '_num_neurons_' + str(FLAGS.num_neurons) + '_num_bins_' + str(FLAGS.num_bins)\
            + '_critic_iters_' + str(FLAGS.critic_iters) + '_lambda_' + str(FLAGS.lambd) +\
            '_num_layers_' + str(FLAGS.num_layers) + '_num_features_' + str(FLAGS.num_features) + '_kernel_' + str(FLAGS.kernel_width) +\
            '_iteration_' + FLAGS.iteration + '/'

    FLAGS.checkpoint_dir = FLAGS.sample_dir + 'checkpoint/'

    with tf.Session() as sess:
        wgan = WGAN_conv(sess,
                         num_neurons=FLAGS.num_neurons,
                         num_bins=FLAGS.num_bins,
                         num_layers=FLAGS.num_layers,
                         num_features=FLAGS.num_features,
                         kernel_width=FLAGS.kernel_width,
                         lambd=FLAGS.lambd,
                         batch_size=FLAGS.batch_size,
                         checkpoint_dir=FLAGS.checkpoint_dir,
                         sample_dir=FLAGS.sample_dir)

        # a stored model is required
        if not wgan.load(FLAGS.training_stage):
            raise Exception("[!] Train a model first, then run test mode")

        # upper bound on the number of samples that are analysed
        num_samples = 8000
        if FLAGS.dataset == 'retina':
            samples = retinal_data.get_samples(num_bins=FLAGS.num_bins,
                                               num_neurons=FLAGS.num_neurons,
                                               instance=FLAGS.data_instance,
                                               folder=os.getcwd() +
                                               '/data/retinal data/').T
        else:
            original_dataset = np.load(FLAGS.sample_dir + '/stats_real.npz')
            if FLAGS.number_of_modes == 1:
                # first call: result discarded, made with save_packet=True
                _ = sim_pop_activity.spike_train_transient_packets(num_samples=num_samples, num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size,\
                                                                     prob_packets=FLAGS.packet_prob,firing_rates_mat=original_dataset['firing_rate_mat'], refr_per=FLAGS.ref_period,\
                                                                     shuffled_index=original_dataset['shuffled_index'], limits=[0,64], groups=[0,1,2,3], folder=FLAGS.sample_dir, save_packet=True).T
                # second call: the samples that are actually analysed
                samples = sim_pop_activity.spike_train_transient_packets(num_samples=num_samples, num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size,\
                                                                     prob_packets=0.2,firing_rates_mat=original_dataset['firing_rate_mat'], refr_per=FLAGS.ref_period,\
                                                                     shuffled_index=original_dataset['shuffled_index'], limits=[16,32], groups=[0], folder=FLAGS.sample_dir, save_packet=False).T
            elif FLAGS.number_of_modes == 2:
                # NOTE(review): everything down to the two noise=0 calls is
                # assumed to belong to this branch - confirm against the
                # original layout
                samples = original_dataset['samples'].T
                num_samples = np.min([num_samples, samples.shape[0]])
                # one array of example packets per mode (mode=0 / mode=1)
                stim1_samples = np.zeros(
                    (int(num_samples / 2), FLAGS.num_neurons, FLAGS.num_bins))
                stim2_samples = np.zeros(
                    (int(num_samples / 2), FLAGS.num_neurons, FLAGS.num_bins))
                for ind_s in range(int(num_samples / 2)):
                    stim1_samples[ind_s,:,:] ,_ = sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=FLAGS.noise_in_packet, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=0)
                    stim2_samples[ind_s,:,:],_ = sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=FLAGS.noise_in_packet, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=1)

                packets = {
                    'packets_stim1': stim1_samples,
                    'packets_stim2': stim2_samples
                }
                np.savez(FLAGS.sample_dir + 'packets.npz', **packets)
                # same call once per mode with noise=0; the return values are not used
                sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=0, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=0)
                sim_pop_activity.spike_train_packets(num_bins=FLAGS.num_bins, num_neurons=FLAGS.num_neurons, group_size=FLAGS.group_size, firing_rates_mat=original_dataset['firing_rate_mat'], \
                                                     refr_per=FLAGS.ref_period, noise=0, number_of_modes=FLAGS.number_of_modes, save_sample=True, folder=FLAGS.sample_dir, mode=1)

        # critic score as a function of a placeholder input
        inputs = tf.placeholder(
            tf.float32,
            name='inputs_to_discriminator',
            shape=[None, FLAGS.num_neurons * FLAGS.num_bins])
        score = wgan.get_critics_output(inputs)

        step = FLAGS.step
        pattern_size = FLAGS.pattern_size
        # multiples of step, dropping those beyond num_bins - pattern_size
        times = step * np.arange(int(FLAGS.num_bins / step))
        times = np.delete(times,
                          np.nonzero(times > FLAGS.num_bins - pattern_size))

        # per-sample results
        importance_time_vector = np.zeros((num_samples, FLAGS.num_bins))
        importance_neuron_vector = np.zeros((num_samples, FLAGS.num_neurons))
        grad_maps = np.zeros((num_samples, FLAGS.num_neurons, FLAGS.num_bins))
        # accumulated over all samples
        activity_map = np.zeros((FLAGS.num_neurons, FLAGS.num_bins))
        importance_time_vector_surr = np.zeros((num_samples, FLAGS.num_bins))
        importance_neuron_vector_surr = np.zeros(
            (num_samples, FLAGS.num_neurons))
        sample_diff = np.zeros((FLAGS.num_neurons, FLAGS.num_bins))
        samples = samples[0:num_samples, :]
        for i in range(num_samples):
            sample = samples[i, :]
            time0 = time.time()
            #get importance map for sample and compute the averages across time and space
            grad_maps[i, :, :], _, sample_diff_aux = patterns_relevance(
                sample, FLAGS.num_neurons, score, inputs, sess, pattern_size,
                times)
            time1 = time.time()
            importance_time_vector[i, :] = np.mean(grad_maps[i, :, :], axis=0)
            importance_neuron_vector[i, :] = np.mean(grad_maps[i, :, :],
                                                     axis=1)
            sample_diff += sample_diff_aux
            #compute surrogate data (not used in the ICLR paper)
            aux, _, _ = patterns_relevance(sample,
                                           FLAGS.num_neurons,
                                           score,
                                           inputs,
                                           sess,
                                           pattern_size,
                                           times,
                                           shuffle=True)
            # time1 is overwritten here, so the time printed below covers
            # both patterns_relevance calls
            time1 = time.time()
            importance_time_vector_surr[i, :] = np.mean(aux, axis=0)
            importance_neuron_vector_surr[i, :] = np.mean(aux, axis=1)
            sample = sample.reshape(FLAGS.num_neurons, -1)
            activity_map += sample
            print(str(i) + ' time ' + str(time1 - time0))

        # NOTE(review): stimulus_id is loaded and truncated but not used
        # afterwards in this function
        stimulus_id = np.load(FLAGS.sample_dir + '/stim.npz')['stimulus']
        stimulus_id = stimulus_id[0:num_samples]
        predicted_packets = analysis.find_packets(grad_maps, samples,
                                                  FLAGS.num_neurons,
                                                  FLAGS.num_bins,
                                                  FLAGS.sample_dir,
                                                  num_samples)
        # NOTE(review): stim1_samples / stim2_samples are only assigned in
        # the number_of_modes == 2 branch above, so this part presumably
        # fails for the other configurations - confirm
        ground_truth_packets_stim1 = np.mean(analysis.find_packets(
            stim1_samples,
            stim1_samples - 1,
            FLAGS.num_neurons,
            FLAGS.num_bins,
            FLAGS.sample_dir,
            int(num_samples / 2),
            plot_fig=False),
                                             axis=0)
        ground_truth_packets_stim2 = np.mean(analysis.find_packets(
            stim2_samples,
            stim2_samples - 1,
            FLAGS.num_neurons,
            FLAGS.num_bins,
            FLAGS.sample_dir,
            int(num_samples / 2),
            plot_fig=False),
                                             axis=0)

        importance_vectors = {'time':importance_time_vector, 'neurons':importance_neuron_vector, 'grad_maps':grad_maps, 'samples':samples, 'activity_map':activity_map,\
                              'time_surr':importance_time_vector_surr, 'neurons_surr':importance_neuron_vector_surr, 'sample_diff':sample_diff, 'predicted_packets':predicted_packets,\
                              'ground_truth_packets_stim1':ground_truth_packets_stim1, 'ground_truth_packets_stim2':ground_truth_packets_stim2}
        np.savez(
            FLAGS.sample_dir + 'importance_vectors_' + str(step) + '_' +
            str(pattern_size) + '_' + str(num_samples) + '.npz',
            **importance_vectors)
def generate_spike_trains(config, recovery_dir):
    '''
    Return the training and dev sets for the dataset named in config.dataset.

    'uniform' and 'packets': the training samples and their generative
    parameters are read from recovery_dir + '/stats_real.npz' when
    recovery_dir is not empty, and drawn with sim_pop_activity.get_samples
    otherwise. The dev set is always drawn anew, with a quarter of
    config.num_samples.
    'retina': the samples come from retinal_data.get_samples and the dev set
    is an empty list.

    In all cases analysis.get_stats is called with name='real' on the
    training samples, with config.sample_dir as its folder.

    Parameters
    ----------
    config : object holding the dataset flags (dataset, num_samples,
        num_neurons, num_bins, ...).
    recovery_dir : folder of a previous run, or "" to generate new data.

    Returns
    -------
    (real_samples, dev_samples)

    Raises
    ------
    ValueError if config.dataset is not 'uniform', 'packets' or 'retina'.
    '''
    if config.dataset == 'uniform':
        if recovery_dir != "":
            # Close the archive once the arrays are read: get_stats below
            # may write to this very file.
            with np.load(recovery_dir + '/stats_real.npz') as recovered:
                real_samples = recovered['samples']
                firing_rates_mat = recovered['firing_rate_mat']
                correlations_mat = recovered['correlation_mat']
                activity_peaks = recovered['activity_peaks']
                shuffled_index = recovered['shuffled_index']
        else:
            #we shuffle neurons to test if the network still learns the packets
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # one firing rate and one correlation per group, each drawn
            # uniformly within +-50% of the configured value
            num_groups = int(config.num_neurons / config.group_size)
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups) - 0.5) * config.correlation / 2
            #peaks of activity
            #sequence response: every neuron of a group gets the group index,
            #which is then scaled by group_size * num_bins / num_neurons
            activity_peaks = np.repeat(np.arange(num_groups), config.group_size)
            activity_peaks = activity_peaks * config.group_size * config.num_bins / config.num_neurons
            activity_peaks = activity_peaks.reshape(config.num_neurons, 1)
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                num_neurons=config.num_neurons,
                correlations_mat=correlations_mat,
                group_size=config.group_size,
                shuffled_index=shuffled_index,
                refr_per=config.ref_period,
                firing_rates_mat=firing_rates_mat,
                activity_peaks=activity_peaks,
                folder=config.sample_dir)

        #save original statistics
        # NOTE(review): kept at dataset level (runs for recovered data too),
        # so that stats_real.npz is written to config.sample_dir - confirm
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           shuffled_index=shuffled_index,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           correlation_mat=correlations_mat,
                           activity_peaks=activity_peaks)
        #get dev samples
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            shuffled_index=shuffled_index,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat,
            activity_peaks=activity_peaks)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            with np.load(recovery_dir + '/stats_real.npz') as recovered:
                real_samples = recovered['samples']
                firing_rates_mat = recovered['firing_rate_mat']
                shuffled_index = recovered['shuffled_index']
        else:
            #we shuffle neurons to test if the network still learns the packets
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # one firing rate per neuron, within +-50% of the configured value
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                num_bins=config.num_bins,
                refr_per=config.ref_period,
                num_neurons=config.num_neurons,
                group_size=config.group_size,
                firing_rates_mat=firing_rates_mat,
                packets_on=True,
                prob_packets=config.packet_prob,
                shuffled_index=shuffled_index,
                folder=config.sample_dir)

        #save original statistics
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           firing_rate_mat=firing_rates_mat,
                           shuffled_index=shuffled_index)
        #get dev samples
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4),
            num_bins=config.num_bins,
            refr_per=config.ref_period,
            num_neurons=config.num_neurons,
            group_size=config.group_size,
            firing_rates_mat=firing_rates_mat,
            packets_on=True,
            prob_packets=config.packet_prob,
            shuffled_index=shuffled_index)

    elif config.dataset == 'retina':
        real_samples = retinal_data.get_samples(num_bins=config.num_bins,
                                                num_neurons=config.num_neurons,
                                                instance=config.data_instance)
        #save original statistics
        analysis.get_stats(X=real_samples,
                           num_neurons=config.num_neurons,
                           num_bins=config.num_bins,
                           folder=config.sample_dir,
                           name='real',
                           instance=config.data_instance)
        dev_samples = []

    else:
        # fail with a clear message instead of an UnboundLocalError at return
        raise ValueError("unknown dataset %r: expected 'uniform', 'packets' "
                         "or 'retina'" % (config.dataset,))

    return real_samples, dev_samples
def generate_spike_trains(config, recovery_dir):
    '''
    Return the training and dev sets corresponding to the parameters in config.

    config: object exposing the run parameters as attributes. The attributes
        read depend on config.dataset ('uniform', 'packets' or 'retina').
    recovery_dir: folder holding a previously saved 'stats_real.npz'. When it
        is not the empty string, the real samples and the parameters used to
        generate them are read from that file instead of being drawn again
        (only used by the 'uniform' and 'packets' datasets).

    Returns the tuple (real_samples, dev_samples). For the 'retina' dataset
    dev_samples is an empty list.

    Raises ValueError when config.dataset is not one of the datasets above.
    '''
    if config.dataset == 'uniform':
        if recovery_dir != "":
            # the npz archive stays open until it is closed, so everything
            # needed is read inside the with block
            with np.load(recovery_dir + '/stats_real.npz') as stats:
                real_samples = stats['samples']
                firing_rates_mat = stats['firing_rate_mat']
                correlations_mat = stats['correlation_mat']
                shuffled_index = stats['shuffled_index']
        else:
            # shuffle neurons
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # one firing rate and one correlation value per group of neurons,
            # drawn around the values requested in config
            num_groups = int(config.num_neurons / config.group_size)
            firing_rates_mat = config.firing_rate + 2 * (
                np.random.random(num_groups, ) - 0.5) * config.firing_rate / 2
            correlations_mat = config.correlation + 2 * (
                np.random.random(num_groups, ) - 0.5) * config.correlation / 2

        # keyword arguments shared by the real and the dev sampling calls
        sim_args = dict(
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            shuffled_index=shuffled_index,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat)

        if recovery_dir == "":
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                folder=config.sample_dir,
                **sim_args)
            # save original statistics
            analysis.get_stats(
                X=real_samples,
                num_neurons=config.num_neurons,
                num_bins=config.num_bins,
                folder=config.sample_dir,
                shuffled_index=shuffled_index,
                name='real',
                firing_rate_mat=firing_rates_mat,
                correlation_mat=correlations_mat)

        # get dev samples (always drawn again, also when recovering)
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4), **sim_args)

    elif config.dataset == 'packets':
        if recovery_dir != "":
            with np.load(recovery_dir + '/stats_real.npz') as stats:
                real_samples = stats['samples']
                firing_rates_mat = stats['firing_rate_mat']
                shuffled_index = stats['shuffled_index']
        else:
            # shuffle the neurons
            shuffled_index = np.arange(config.num_neurons)
            np.random.shuffle(shuffled_index)
            # one firing rate per neuron, drawn around config.firing_rate
            firing_rates_mat = config.firing_rate + 2 * (np.random.random(
                size=(config.num_neurons, 1)) - 0.5) * config.firing_rate / 2

        # keyword arguments shared by the real and the dev sampling calls
        sim_args = dict(
            num_bins=config.num_bins,
            refr_per=config.ref_period,
            num_neurons=config.num_neurons,
            group_size=config.group_size,
            firing_rates_mat=firing_rates_mat,
            packets_on=True,
            prob_packets=config.packet_prob,
            shuffled_index=shuffled_index,
            noise_in_packet=config.noise_in_packet,
            number_of_modes=config.number_of_modes)

        if recovery_dir == "":
            real_samples = sim_pop_activity.get_samples(
                num_samples=config.num_samples,
                folder=config.sample_dir,
                **sim_args)
            # save original statistics
            analysis.get_stats(
                X=real_samples,
                num_neurons=config.num_neurons,
                num_bins=config.num_bins,
                folder=config.sample_dir,
                name='real',
                firing_rate_mat=firing_rates_mat,
                shuffled_index=shuffled_index)

        # get dev samples (always drawn again, also when recovering)
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4), **sim_args)

    elif config.dataset == 'retina':
        dirpath = os.getcwd()
        real_samples = retinal_data.get_samples(
            num_bins=config.num_bins,
            num_neurons=config.num_neurons,
            instance=config.data_instance,
            folder=dirpath + '/data/retinal data/')
        # save original statistics
        analysis.get_stats(
            X=real_samples,
            num_neurons=config.num_neurons,
            num_bins=config.num_bins,
            folder=config.sample_dir,
            name='real',
            instance=config.data_instance)
        # no dev set is built for this dataset
        dev_samples = []

    else:
        # fail with a clear message instead of an UnboundLocalError at return
        raise ValueError('unknown dataset: ' + str(config.dataset))

    return real_samples, dev_samples
def train(self, config):
    """Train the generator and the critic and periodically evaluate the model.

    The method builds the two Adam optimizers, initializes the variables,
    tries to restore saved parameters through self.load(), creates the real
    samples ('uniform' or 'retina' dataset) and then runs config.num_iter
    training iterations. Each iteration runs one generator step (skipped on
    the very first iteration) followed by config.critic_iters critic steps.

    At iteration 500, every 20000 iterations and during the last iterations
    the parameters are saved, fake samples are generated and their statistics
    are compared with the real ones; the fitting errors are added to a figure
    saved as 'fitting_errors.svg' in self.sample_dir.

    config: object exposing the run parameters as attributes.
    """
    # define optimizer
    self.g_optim = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2).minimize(
            self.gen_cost,
            var_list=params_with_name('Generator'),
            colocate_gradients_with_ops=True)
    self.d_optim = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
        beta2=config.beta2).minimize(
            self.disc_cost,
            var_list=params_with_name('Discriminator.'),
            colocate_gradients_with_ops=True)

    # initialize variables; only fall back to the older initializer when the
    # newer one does not exist, so that real initialization errors are not
    # hidden
    try:
        tf.global_variables_initializer().run()
    except AttributeError:
        tf.initialize_all_variables().run()

    # try to load trained parameters
    self.load()

    # get real samples
    if config.dataset == 'uniform':
        num_groups = int(self.num_neurons / config.group_size)
        firing_rates_mat = config.firing_rate + 2 * (
            np.random.random(num_groups, ) - 0.5) * config.firing_rate / 2
        correlations_mat = config.correlation + 2 * (
            np.random.random(num_groups, ) - 0.5) * config.correlation / 2
        # every neuron of a group gets the group index, which is then scaled
        # by group_size * num_bins / num_neurons
        group_ids = np.arange(num_groups)
        activity_peaks = [[x] * config.group_size for x in group_ids]
        activity_peaks = np.asarray(activity_peaks)
        activity_peaks = activity_peaks.flatten()
        activity_peaks = activity_peaks * config.group_size * self.num_bins / self.num_neurons
        activity_peaks = activity_peaks.reshape(self.num_neurons, 1)

        # keyword arguments shared by the real and the dev sampling calls
        sim_args = dict(
            num_bins=self.num_bins,
            num_neurons=self.num_neurons,
            correlations_mat=correlations_mat,
            group_size=config.group_size,
            refr_per=config.ref_period,
            firing_rates_mat=firing_rates_mat,
            activity_peaks=activity_peaks)
        self.real_samples = sim_pop_activity.get_samples(
            num_samples=config.num_samples, **sim_args)
        # get dev samples
        dev_samples = sim_pop_activity.get_samples(
            num_samples=int(config.num_samples / 4), **sim_args)
        # save original statistics
        analysis.get_stats(
            X=self.real_samples,
            num_neurons=self.num_neurons,
            num_bins=self.num_bins,
            folder=self.sample_dir,
            name='real',
            firing_rate_mat=firing_rates_mat,
            correlation_mat=correlations_mat,
            activity_peaks=activity_peaks)
    elif config.dataset == 'retina':
        self.real_samples = retinal_data.get_samples(
            num_bins=self.num_bins,
            num_neurons=self.num_neurons,
            instance=config.data_instance)
        # save original statistics
        analysis.get_stats(
            X=self.real_samples,
            num_neurons=self.num_neurons,
            num_bins=self.num_bins,
            folder=self.sample_dir,
            name='real',
            instance=config.data_instance)

    # count number of variables
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('number of varaibles: ' + str(total_parameters))

    # start training
    counter_batch = 0
    epoch = 0
    # figure collecting the fitting errors (the margins come from
    # module-level names)
    f, sbplt = plt.subplots(2, 2, figsize=(8, 8), dpi=250)
    matplotlib.rcParams.update({'font.size': 8})
    plt.subplots_adjust(left=left, bottom=bottom, right=right, top=top,
                        wspace=wspace, hspace=hspace)
    for iteration in range(config.num_iter):
        start_time = time.time()
        # Train generator (only after the critic has been trained, at least once)
        if iteration > 0:
            self.sess.run(self.g_optim)

        # Train critic
        disc_iters = config.critic_iters
        for _ in range(disc_iters):
            # get batch and train critic
            _data = self.real_samples[:, counter_batch * config.batch_size:
                                      (counter_batch + 1) * config.batch_size].T
            _disc_cost, _ = self.sess.run(
                [self.disc_cost, self.d_optim],
                feed_dict={self.inputs: _data})
            # if we have reached the end of the real samples set, we start
            # over and increment the number of epochs
            if counter_batch == int(self.real_samples.shape[1] / self.batch_size) - 1:
                counter_batch = 0
                epoch += 1
            else:
                counter_batch += 1

        elapsed = time.time() - start_time
        # plot the critic's loss and the iteration time
        plot.plot(self.sample_dir, 'train disc cost', -_disc_cost)
        plot.plot(self.sample_dir, 'time', elapsed)

        if (iteration == 500) or iteration % 20000 == 19999 or iteration > config.num_iter - 10:
            print('epoch ' + str(epoch))
            if config.dataset == 'uniform':
                # this is to evaluate whether the discriminator has overfit
                dev_disc_costs = []
                for ind_dev in range(int(dev_samples.shape[1] / self.batch_size)):
                    images = dev_samples[:, ind_dev * config.batch_size:
                                         (ind_dev + 1) * config.batch_size].T
                    _dev_disc_cost = self.sess.run(
                        self.disc_cost, feed_dict={self.inputs: images})
                    dev_disc_costs.append(_dev_disc_cost)
                # plot the dev loss
                plot.plot(self.sample_dir, 'dev disc cost', -np.mean(dev_disc_costs))

            # save the network parameters
            self.save(iteration)

            # get simulated samples, calculate their statistics and compare
            # them with the original ones
            fake_samples = self.get_samples(num_samples=2**13)
            fake_samples = fake_samples.eval(session=self.sess)
            fake_samples = self.binarize(samples=fake_samples)
            acf_error, mean_error, corr_error, time_course_error, _ = analysis.get_stats(
                X=fake_samples.T,
                num_neurons=config.num_neurons,
                num_bins=config.num_bins,
                folder=config.sample_dir,
                name='fake' + str(iteration),
                critic_cost=-_disc_cost,
                instance=config.data_instance)

            # plot the fitting errors, one panel per statistic
            panels = (
                (sbplt[0][0], mean_error, 'spk-count mean error'),
                (sbplt[0][1], time_course_error, 'time course error'),
                (sbplt[1][0], acf_error, 'AC error'),
                (sbplt[1][1], corr_error, 'corr error'),
            )
            for axis, error, title in panels:
                axis.plot(iteration, error, '+b')
                axis.set_title(title)
                axis.set_xlabel('iterations')
                axis.set_ylabel('L1 error')
                axis.set_xlim([
                    0 - config.num_iter / 4,
                    config.num_iter + config.num_iter / 4
                ])

            f.savefig(self.sample_dir + 'fitting_errors.svg', dpi=600, bbox_inches='tight')
            plt.close(f)

        plot.flush(self.sample_dir)
        plot.tick()