def __init__(self,
             activation,
             checkpoint_dir,
             trainingset_dir,
             w_reconstruction=[1., 1.],
             w_decision=[1., 1.],
             w_reg=1e-10,
             w_rate=0,
             beta0=None,
             beta_steps=200000,
             sampling_off=False,
             dataset='attention_search_both2',
             input_distribution_dim=-1,
             batch_size=64,
             dropout_prob=0.5,
             regenerate_steps=1000,
             dataset_size=5000):
    """Store the hyperparameters of the "popout" network.

    Nothing is built here: the constructor only records settings, resolves
    the checkpoint directory via `make_memnet_checkpoint_dir`, and leaves
    `self.sess` as None. Image geometry is fixed (32 px wide, RGB, three
    channels) and the latent size is fixed at 500.

    Args:
        activation: Key selecting the activation function; one of "relu",
            "sigmoid", "tanh", "elu".
        checkpoint_dir: Top-level checkpoint directory.
        trainingset_dir: Directory holding the training set.
        w_reconstruction: Pair (encoder, decoder) of reconstruction loss
            weights.
        w_decision: Pair (encoder, decoder) of decision loss weights.
        w_reg: Regularisation weight.
        w_rate: Rate loss weight; also stored as `beta1`, the final value
            of the rate loss weight.
        beta0: Optional initial value of the rate loss weight.
        beta_steps: Stored as `self.beta_steps`.
        sampling_off: Stored as `self.sampling_off`.
        dataset: Task name; must contain the substring "attention".
        input_distribution_dim: Accepted but not used by this constructor.
        batch_size: Stored as `self.batch_size`.
        dropout_prob: Stored as `self.dropout_prob`.
        regenerate_steps: Stored as `self.regenerate_steps`.
        dataset_size: Converted to int and stored as `self.dataset_size`.

    Raises:
        Exception: If `dataset` does not contain "attention", or if `beta0`
            is given and exceeds `w_rate`.
        KeyError: If `activation` is not one of the supported keys.
    """
    task_is_supported = "attention" in dataset
    if not task_is_supported:
        raise Exception(
            "{} is not an acceptable training task.".format(dataset))

    self.name = "popout"
    known_activations = {
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "tanh": tf.nn.tanh,
        "elu": tf.nn.elu,
    }

    # Dataset and input-pipeline bookkeeping
    self.dataset = dataset
    self.dataset_size = int(dataset_size)
    self.iterator = None
    self.next_element = None
    self.batch_size = batch_size
    self.regenerate_steps = regenerate_steps
    self.decision_dim = 2

    # Activation function
    self.activation = known_activations[activation]

    # Fixed image geometry
    self.image_width = 32
    self.RGB = True
    self.image_channels = 3

    # Size of hidden layer(s)
    self.latent_size = 500

    # Rate loss weight: starts at beta0 (if given) and ends at beta1
    self.w_rate = w_rate  # Initialize to this value, but possibly change
    self.beta1 = w_rate  # Final value for rate_loss_weight
    self.beta0 = beta0  # Initial value for rate_loss_weight
    schedule_is_inverted = (self.beta0 is not None
                            and self.beta0 > self.beta1)
    if schedule_is_inverted:
        raise Exception("beta1 must be greater than beta0")
    self.beta_steps = beta_steps

    # Each loss weight comes as an (encoder, decoder) pair
    recon_enc, recon_dec = w_reconstruction
    self.w_reconstruction_enc = recon_enc
    self.w_reconstruction_dec = recon_dec
    decision_enc, decision_dec = w_decision
    self.w_decision_enc = decision_enc
    self.w_decision_dec = decision_dec
    self.w_reg = w_reg

    self.dropout_prob = dropout_prob
    self.sampling_off = sampling_off

    # Filesystem locations; the helper receives the configured instance
    self.checkpoint_top_dir = Path(checkpoint_dir)
    self.trainingset_dir = Path(trainingset_dir)
    self.checkpoint_dir = make_memnet_checkpoint_dir(checkpoint_dir, self)

    print("Using {} activation".format(activation))
    print("Loss weights (rate, reconstruction, decision): ",
          (self.w_rate, w_reconstruction, w_decision))
    self.sess = None
def visualize_network_samples_grid_1D(net,
                                      start=5,
                                      stop=100,
                                      step=10,
                                      f_ext="pdf"):
    """
    [PLANTS STIMULI ONLY]
    Plot a single row of network outputs along one stimulus dimension.

    Like the 2D grid visualisation, but only one stimulus dimension varies
    (over `range(start, stop, step)`) while the other is held at 50.

    Args:
        net: Network object; must provide `image_width`, `trainingset_dir`,
            `layer_type`, `input_distribution_dim` and `predict`.
        start, stop, step: Define the values of the varying dimension.
        f_ext: File extension of the saved figure.

    The figure is written to `grid.<f_ext>` inside the directory returned
    by `make_memnet_checkpoint_dir`.
    """
    import os

    save_dir = make_memnet_checkpoint_dir(
        Path("plots/vis_memnet_samples/grid/"), net)
    # Recursive and idempotent; does not depend on which Path class is used.
    os.makedirs(str(save_dir), exist_ok=True)

    relevant_vals = range(start, stop, step)
    grid_width = len(relevant_vals)

    print("Loading plants.")
    inputs0, stim_vals0 = load_plants(net.image_width,
                                      net.trainingset_dir,
                                      layer_type=net.layer_type,
                                      normalize=True)
    # NOTE(review): the images are loaded with normalize=True and then
    # divided by 255 again here; the 2D grid function does not do this.
    # Behaviour kept as-is -- confirm this is intended.
    inputs0 = inputs0 / 255.  # Normalize
    print("Finished.")

    # The dimension that is NOT being swept is pinned to the value 50.
    irrelevant_dim = int(net.input_distribution_dim == 0)
    if irrelevant_dim == 0:
        inds = list(product([50], relevant_vals))
    else:
        inds = list(product(relevant_vals, [50]))

    on_slice = stim_vals0[:, irrelevant_dim] == 50
    inputs_valid = inputs0[on_slice]
    stim_vals = stim_vals0[on_slice]

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so indexing below also works when only one value is requested.
    fig, axes = plt.subplots(1,
                             grid_width,
                             figsize=(50, 50 / len(relevant_vals)),
                             squeeze=False)
    axes = axes[0]

    inputs = []
    print("Making plots...")
    for vals, image in zip(stim_vals, inputs_valid):
        if tuple(vals) in inds:
            print("Making ", vals)
            inputs.append(image)

    preds = net.predict(inputs, keep_session=True)
    for i in range(len(preds)):
        ax = axes[i]
        ax.imshow(preds[i:i + 1].reshape(net.image_width, net.image_width),
                  cmap="gray_r")
        ax.set_xticks([])
        ax.set_xticklabels("")
        ax.set_yticks([])
        ax.set_yticklabels("")

    save_pth = save_dir.joinpath('grid.' + f_ext)
    fig.savefig(save_pth)
    print("Saved to ", save_pth)
    return
def plants_average_reconstruction(net,
                                  plant_coords=[[0, 0], [0, 99], [99, 99],
                                                [99, 0]],
                                  n_samples=100,
                                  f_ext="pdf"):
    """For the specified plants (given by `plant_coords`), plot the target
    image alongside the mean reconstructed image, averaged over `n_samples`
    samples.

    Args:
        net: Network object; must provide `image_width`, `trainingset_dir`,
            `layer_type` and `predict`.
        plant_coords: Sequence of (leaf width, leaf angle) stimulus values,
            one figure per entry.
        n_samples: Number of times each target image is repeated in the
            batch passed to `net.predict`.
        f_ext: File extension of the saved figures.

    One file named `width<w>_angle<a>.<f_ext>` is written per coordinate
    into the directory returned by `make_memnet_checkpoint_dir`.
    """
    import os

    images, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)
    save_dir = make_memnet_checkpoint_dir(
        Path("plots/plants_average_reconstruction/"), net)
    # The target is a nested path, so create missing parents as well.
    os.makedirs(str(save_dir), exist_ok=True)

    side = net.image_width
    for coord in plant_coords:
        # Select the stored image whose stimulus values match this coord
        i_target = np.logical_and(coord[0] == stim_vals[:, 0],
                                  coord[1] == stim_vals[:, 1])
        x = images[i_target]
        X = np.repeat(x, n_samples, axis=0)  # Take n samples
        Xhat = net.predict(X, keep_session=True)
        mean_reconstruction = Xhat.mean(axis=0)  # Take mean over samples

        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(5.6, 3.1))
        ax[0].imshow(x.reshape(side, side), cmap="gray_r")
        ax[1].imshow(mean_reconstruction.reshape(side, side), cmap="gray_r")
        ax[0].set_xlabel("Target", fontsize=14)
        ax[1].set_xlabel("Mean reconstruction", fontsize=14)
        for panel in ax:
            panel.set_xticks([])
            panel.set_xticklabels("")
            panel.set_yticks([])
            panel.set_yticklabels("")
        fig.suptitle("Leaf width {}, leaf angle {}".format(
            coord[0], coord[1]))
        fig.tight_layout()

        save_pth = save_dir.joinpath("width{}_angle{}.{}".format(
            coord[0], coord[1], f_ext))
        fig.savefig(save_pth, dpi=300)
        # One figure per coordinate: release it so they do not pile up.
        plt.close(fig)

    # str() keeps the concatenation valid for non-str path objects too.
    print("Saved to " + str(save_dir))
    return
def visualize_network_samples_grid(net,
                                   start=5,
                                   stop=100,
                                   step=10,
                                   f_ext="pdf"):
    """
    [PLANTS STIMULI ONLY]
    Plot grid of images, where (x, y) position corresponds to the stimulus
    values (width, droop) of the input to the network. Images are the
    outputs of the network given the input.

    Args:
        net: Network object; must provide `image_width`, `trainingset_dir`,
            `layer_type` and `predict`.
        start, stop, step: `range(start, stop, step)` gives the stimulus
            values shown along both axes of the grid.
        f_ext: File extension of the saved figure.

    The figure is written to `grid.<f_ext>` inside the directory returned
    by `make_memnet_checkpoint_dir`.
    """
    import os

    save_dir = make_memnet_checkpoint_dir(
        Path("plots/vis_memnet_samples/grid/"), net)
    # Recursive and idempotent; does not depend on which Path class is used.
    os.makedirs(str(save_dir), exist_ok=True)

    grid_vals = range(start, stop, step)
    inds = list(product(grid_vals, grid_vals))
    grid_width = len(grid_vals)

    print("Loading plants.")
    inputs, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)
    print("Finished.")

    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
    fig, axes = plt.subplots(grid_width,
                             grid_width,
                             figsize=(50, 50),
                             squeeze=False)
    print("Making plots...")
    for i in range(len(stim_vals)):
        coord = tuple(stim_vals[i])
        if coord not in inds:
            continue
        print("Making ", stim_vals[i])
        pred = net.predict(inputs[i:i + 1],
                           keep_session=True).reshape(net.image_width,
                                                      net.image_width)
        # Position within the grid is measured from `start`; dividing the
        # raw value by `step` is only correct while start < step.
        col = int((coord[0] - start) // step)
        row = int((coord[1] - start) // step)
        ax = axes[row, col]
        ax.imshow(pred, cmap="gray_r")
        ax.set_xticks([])
        ax.set_xticklabels("")
        ax.set_yticks([])
        ax.set_yticklabels("")

    save_pth = save_dir.joinpath('grid.' + f_ext)
    fig.savefig(save_pth)
    print("Saved to ", save_pth)
    return
def correlation(net, start=0, stop=100, step=10, n_samples=2, f_ext="pdf"):
    """An analysis for networks trained with non-uniform input distributions.
    (Designed for 'plants' experiment)

    Correlate reconstructed images with all original stimulus images. Since
    one plant dimension will be uniform and one will be non-uniform,
    marginalize over the non-uniform dimension

    Args:
        net: Network object; must provide `input_distribution_dim`,
            `input_mean`, `input_std`, `image_width`, `trainingset_dir`,
            `layer_type` and `predict`.
        start, stop, step: `range(start, stop, step)` gives the target
            values along the dimension `net.input_distribution_dim`. Every
            target value must lie in the probe range 0..99.
        n_samples: Number of network samples to average over
        f_ext: File extension of the saved figure.

    Raises:
        Exception: If `net.input_distribution_dim` is -1.
        ValueError: If a target value is not among the probe values.
    """
    import os

    from scipy.stats import pearsonr
    from sklearn.metrics import auc

    # Save directory (nested path, so create missing parents as well)
    save_dir = make_memnet_checkpoint_dir(Path("plots/correlation/"), net)
    os.makedirs(str(save_dir), exist_ok=True)

    dim = net.input_distribution_dim
    if dim == -1:
        raise Exception(
            "correlation() needs net.input_distribution_dim to be 0 or 1, "
            "got -1")
    other_dim = 1 if dim == 0 else 0

    images, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)

    target_selection = range(start, stop, step)  # Select subset of values
    probe_selection = range(0, 100, 1)
    unique_stim_vals = 100

    # Each target is later looked up in probe_selection; check up front so
    # a bad range fails before any expensive prediction is run.
    missing = [t for t in target_selection if t not in probe_selection]
    if missing:
        raise ValueError(
            "target values {} are not among the probe values".format(
                missing))

    corrs = np.zeros((len(target_selection), len(probe_selection)))
    xhats = np.zeros((len(target_selection), len(probe_selection),
                      unique_stim_vals, images.shape[1]))

    # NOTE(review): A_left / A_right are filled in below but never read or
    # returned by this function -- confirm whether they are still needed.
    A_left = []
    A_right = []
    for i, dim_t in enumerate(target_selection):
        # Target images: fixed value on `dim`, every value on `other_dim`
        inds_x = [
            np.logical_and(dim_t == stim_vals[:, dim],
                           d == stim_vals[:, other_dim])
            for d in range(unique_stim_vals)
        ]
        x = np.vstack([images[ix] for ix in inds_x])

        for j, dim_p in enumerate(probe_selection):
            # Probe images: same construction for the probe value
            inds_y = [
                np.logical_and(dim_p == stim_vals[:, dim],
                               d == stim_vals[:, other_dim])
                for d in range(unique_stim_vals)
            ]
            y = np.vstack([images[iy] for iy in inds_y])

            # Average reconstruction of the targets over n_samples passes
            xhats[i, j] = np.mean(
                [net.predict(x, keep_session=True) for _ in range(n_samples)],
                axis=0)
            # Mean per-image correlation between reconstruction and probe
            corrs[i, j] = np.mean([
                pearsonr(xhats[i, j][k], y[k])[0]
                for k in range(unique_stim_vals)
            ])
            print(i, j)

        # Position of the target value within the probe values
        i_true_target_val = probe_selection.index(dim_t)

        # Area under the correlation curve up to 10 probes either side
        rAleft = corrs[i,
                       max(0, i_true_target_val - 10):i_true_target_val + 1]
        rAright = corrs[i, i_true_target_val:i_true_target_val + 11]
        if len(rAleft) > 1:
            A_left.append(auc(range(len(rAleft)), rAleft))
        else:
            A_left.append(0)
        if len(rAright) > 1:
            A_right.append(auc(range(len(rAright)), rAright))
        else:
            A_right.append(0)

    cutoff = 100
    fig, ax = plt.subplots()
    for i in range(len(target_selection)):
        ind_t = probe_selection.index(target_selection[i])
        lo, hi = max(0, ind_t - cutoff), ind_t + cutoff
        line = ax.plot(probe_selection[lo:hi], corrs[i][lo:hi])
        # Mark the target value with a dashed line in the curve's colour
        color = line[0].get_color()
        ax.axvline(x=target_selection[i], color=color, linestyle="--")
    ax.set_ylim(0.6, 1)
    ax.set_xlim(0, 100)
    ax.set_xlabel("stimulus value", fontsize=16)
    ax.set_ylabel("correlation coefficient", fontsize=16)

    save_pth = save_dir.joinpath("correlation_dim{}_mean{}_std{}.{}".format(
        net.input_distribution_dim, net.input_mean, net.input_std, f_ext))
    fig.savefig(save_pth)
    print("Saved to ", save_pth)
    return
def visualize_network_samples(net,
                              outputs_per_example=5,
                              color_map="gray",
                              setsize=[1, 6],
                              f_ext="pdf"):
    """Make grid of sample images from decoder

    For each example in a batch, one figure is saved showing the target
    image followed by `outputs_per_example` network outputs for it.

    Args:
        net: Network object. `net.name` selects how inputs are obtained
            ("cifar", "fruits", "popout", or anything else for the generic
            `generate_training_data` path).
        outputs_per_example: Number of calls to `net.predict_both` per
            batch, i.e. number of sample panels per figure.
        color_map: Colour map passed to `imshow`.
        setsize: Passed to `generate_training_data` on the generic path and
            always used in the output file names.
        f_ext: File extension of the saved figures.
    """
    import os

    n_examples = net.batch_size
    if net.name == "cifar":
        inputs, targets = generate_training_data(
            None,
            None,
            dataset=net.dataset,
            train_test="test",
            conv=True,
            data_dir=net.trainingset_dir)[0:2]
        inputs, targets = net.get_batch(inputs, targets, None)[0:2]
        sig = True  # For predicting decisions
    elif net.name == "fruits":
        inputs, _ = net.generate_batch("test")
        targets = inputs
        sig = True
    elif net.name == "popout":
        inputs, _ = net.generate_training_data(n_examples)
        targets = inputs
        sig = False
    else:
        inputs = generate_training_data(n_examples,
                                        net.image_width,
                                        task_weights=net.task_weights,
                                        dataset=net.dataset,
                                        conv=net.layer_type == "conv",
                                        data_dir=net.trainingset_dir,
                                        setsize=setsize,
                                        mean=net.input_mean,
                                        std=net.input_std,
                                        dim=net.input_distribution_dim)[0]
        targets = inputs
        sig = False

    if net.image_channels > 1:
        shp = [net.image_width, net.image_width, net.image_channels]
    else:
        shp = [net.image_width, net.image_width]

    # Run the whole batch once per requested output; decisions are ignored.
    preds0, _ = zip(*[
        net.predict_both(
            inputs, np.zeros_like(inputs), sigmoid=sig, keep_session=True)
        for _ in range(outputs_per_example)
    ])
    # (sample, example, ...) -> (example, sample, ...) -> image shape
    preds = np.array(preds0).swapaxes(0, 1)
    preds = preds.reshape([n_examples, outputs_per_example] + shp)
    preds = preds.clip(0, 1)

    save_dir = make_memnet_checkpoint_dir(Path("plots/vis_memnet_samples/"),
                                          net)
    # Recursive and idempotent; does not depend on which Path class is used.
    os.makedirs(str(save_dir), exist_ok=True)
    print("Saving images to: ", save_dir)

    xlabs = ["Target"] + [
        "Sample {}".format(j) for j in range(1, outputs_per_example + 1)
    ]
    setsize_tag = "_".join([str(x) for x in setsize])

    for i in range(n_examples):
        fig, axes = plt.subplots(nrows=1,
                                 ncols=outputs_per_example + 1,
                                 figsize=(45, 10))
        if net.RGB:
            axes[0].imshow(targets[i], cmap=color_map)
        else:
            axes[0].imshow(targets[i].reshape(net.image_width,
                                              net.image_width),
                           cmap=color_map)
        for j in range(outputs_per_example):
            axes[j + 1].imshow(preds[i][j], cmap=color_map)

        for ax, label in zip(axes, xlabs):
            ax.set_aspect('equal')
            ax.set_xlabel(label, size=50)
            ax.set_xticks([])
            ax.set_xticklabels("")
            ax.set_yticks([])
            ax.set_yticklabels("")

        fig.tight_layout()
        fig.savefig(
            save_dir.joinpath('setsize_{}_sample{}.{}'.format(
                setsize_tag,
                str(i).zfill(2), f_ext)))
        # One large figure per example: release it so they do not pile up.
        plt.close(fig)
def __init__(self,
             hidden_size,
             latent_size,
             activation,
             checkpoint_dir,
             trainingset_dir,
             decision_size=100,
             encoder_layers=1,
             decoder_layers=1,
             decision_layers=1,
             kernel_size=3,
             load_decision_weights=False,
             task_weights=None,
             w_reconstruction=[1., 1.],
             w_decision=[1., 1.],
             w_reg=1e-10,
             w_rate=0,
             beta0=None,
             beta_steps=200000,
             sampling_off=False,
             encode_probe=False,
             dataset='rectangle',
             image_width=100,
             RGB=False,
             layer_type="MLP",
             decision_target="same_different",
             loss_func_dec="squared_error",
             loss_func_recon=None,
             dataset_size=int(1e4),
             regenerate_steps=None,
             input_std=10,
             input_mean=50,
             input_distribution_dim=-1,
             sample_distribution="gaussian",
             decision_dim=1,
             image_channels=3,
             batch_size=20,
             dropout_prob=1.0):
    """Store the hyperparameters of the "VAE" network.

    Nothing is built here: the constructor validates a few argument
    combinations, records settings as attributes, resolves the checkpoint
    directory via `make_memnet_checkpoint_dir`, and leaves `self.sess` and
    `self.decision_checkpoint_dir` as None.

    Args:
        hidden_size, latent_size, decision_size: Layer sizes, stored as-is.
        activation: One of "relu", "sigmoid", "tanh", "elu".
        checkpoint_dir, trainingset_dir: Filesystem locations.
        encoder_layers, decoder_layers, decision_layers: Layer counts.
        kernel_size: Used if layers are convolutional.
        load_decision_weights: Accepted but not used by this constructor.
        task_weights: Optional list/tuple of task weights. When omitted,
            `[1] * decision_dim` is used. The "letter" dataset requires
            exactly four.
        w_reconstruction, w_decision: (encoder, decoder) loss-weight pairs,
            each converted to float.
        w_reg: Regularisation weight, converted to float.
        w_rate: Rate loss weight; also stored as `beta1`, the final value
            of the rate loss weight.
        beta0: Optional initial value of the rate loss weight.
        beta_steps, sampling_off, encode_probe, sample_distribution,
            dropout_prob, batch_size, regenerate_steps, input_std,
            input_mean, input_distribution_dim, layer_type, loss_func_dec:
            Stored as attributes of the same name.
        dataset: Dataset name. Some names override `decision_dim` and
            `decision_target` (see the body).
        image_width: Stored as `self.image_width`.
        RGB: If true, `self.image_channels` is 3, otherwise 1.
        decision_target: Stored; "tp_dist" turns `logistic_decision` off.
        loss_func_recon: Defaults to "squared_error" when None.
        dataset_size: Converted to int.
        decision_dim: Number of decision outputs; must match
            `len(task_weights)` when both are given, or may be None to be
            inferred from `task_weights`.
        image_channels: Accepted but ignored; the stored value is derived
            from `RGB`.

    Raises:
        NotImplementedError: For a "float" dataset with conv layers.
        Exception: For invalid argument combinations (see the body).
        KeyError: If `activation` is not one of the supported keys.
    """
    if "float" in dataset and layer_type == "conv":
        raise NotImplementedError
    # Check for None explicitly so the default task_weights yields this
    # validation error rather than a TypeError from len(None).
    if dataset == "letter" and (task_weights is None
                                or len(task_weights) != 4):
        raise Exception(
            "The 'letter' dataset requires exactly 4 task_weights")
    # NOTE(review): this tests hidden_size (not kernel_size) and fires when
    # it is odd, while the message says the kernel size must be odd.
    # Behaviour kept as-is -- confirm the intended condition.
    if layer_type == "conv" and hidden_size % 2:
        raise Exception("Kernel size must be odd number")

    self.name = "VAE"
    activation_dict = {
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "tanh": tf.nn.tanh,
        "elu": tf.nn.elu
    }
    self.dataset = dataset
    self.dataset_size = int(dataset_size)
    self.regenerate_steps = regenerate_steps
    self.batch_size = batch_size
    self.input_mean = input_mean
    self.input_std = input_std
    self.input_distribution_dim = input_distribution_dim

    # Reconcile task_weights with decision_dim: either may be omitted, but
    # when both are given their lengths must agree.
    if task_weights is None:
        if decision_dim is None:
            raise Exception(
                "Need to specify either task_weights or decision_dim")
        task_weights = [1] * decision_dim
    elif not hasattr(task_weights, '__len__'):
        raise Exception("task_weights must be list or tuple")
    elif decision_dim is None:
        decision_dim = len(task_weights)
    elif decision_dim != len(task_weights):
        raise Exception(
            "decision_dim and task_weights mismatched for length")
    self.task_weights = np.array(task_weights)
    self.decision_dim = decision_dim

    self.layer_type = layer_type
    self.loss_func_dec = loss_func_dec
    self.decision_target = decision_target
    # Evaluated before the dataset-specific overrides of decision_target
    self.logistic_decision = self.decision_target != "tp_dist"

    if dataset in ["gabor_array", "plants_setsize"]:
        # Hardcode for now to avoid errors
        self.decision_dim = 6
        self.decision_target = "recall"
    for i in range(1, 7):
        # Hardcode for now to avoid errors
        if dataset in [
                "gabor_array{}".format(i), "plants_setsize{}".format(i)
        ]:
            self.decision_dim = i
            self.decision_target = "recall"

    if loss_func_recon is None:
        self.loss_func_recon = "squared_error"
    else:
        self.loss_func_recon = loss_func_recon

    if dataset == "episodic_shape_race":
        self.decision_dim = 10
    elif dataset == "episodic_setsize":
        self.decision_dim = 12

    # Activation function
    self.activation = activation_dict[activation]

    self.image_width = image_width
    self.RGB = RGB
    # The channel count follows RGB; the image_channels argument is unused
    self.image_channels = 3 if self.RGB else 1

    # Size of hidden layer(s)
    self.hidden_size = hidden_size
    self.kernel_size = kernel_size  # Used if layers are convolutional
    self.encoder_layers = encoder_layers
    self.decoder_layers = decoder_layers
    self.decision_layers = decision_layers
    self.latent_size = latent_size
    self.decision_size = decision_size

    self.w_rate = float(
        w_rate)  # Initialize to this value, but possibly change
    self.beta1 = w_rate  # Final value for rate_loss_weight
    self.beta0 = beta0  # Initial value for rate_loss_weight
    if self.beta0 is not None and self.beta0 > self.beta1:
        raise Exception("beta1 must be greater than beta0")
    self.beta_steps = beta_steps

    # Each loss weight comes as an (encoder, decoder) pair
    recon_enc, recon_dec = w_reconstruction
    self.w_reconstruction_enc = float(recon_enc)
    self.w_reconstruction_dec = float(recon_dec)
    decision_enc, decision_dec = w_decision
    self.w_decision_enc = float(decision_enc)
    self.w_decision_dec = float(decision_dec)
    self.w_reg = float(w_reg)

    self.dropout_prob = dropout_prob
    self.sampling_off = sampling_off
    self.sample_distribution = sample_distribution
    self.encode_probe = encode_probe

    # Filesystem locations; the helper receives the configured instance
    self.checkpoint_top_dir = Path(checkpoint_dir)
    self.trainingset_dir = Path(trainingset_dir)
    self.checkpoint_dir = make_memnet_checkpoint_dir(checkpoint_dir, self)
    self.decision_checkpoint_dir = None

    print("Using {} activation".format(activation))
    print("Using {} loss function for reconstruction error".format(
        self.loss_func_recon))
    print("Loss weights (rate, reconstruction, decision): ",
          (self.w_rate, w_reconstruction, w_decision))
    self.sess = None