Пример #1
0
 def __init__(self,
              activation,
              checkpoint_dir,
              trainingset_dir,
              w_reconstruction=[1., 1.],
              w_decision=[1., 1.],
              w_reg=1e-10,
              w_rate=0,
              beta0=None,
              beta_steps=200000,
              sampling_off=False,
              dataset='attention_search_both2',
              input_distribution_dim=-1,
              batch_size=64,
              dropout_prob=0.5,
              regenerate_steps=1000,
              dataset_size=5000):
     """Store the configuration of the "popout" (attention task) network.

     This method only records hyper-parameters and bookkeeping attributes on
     ``self``; ``self.sess``, ``self.iterator`` and ``self.next_element`` are
     left as ``None``.

     Args:
         activation: Key into the activation table below ("relu", "sigmoid",
             "tanh" or "elu"). Any other key raises ``KeyError``.
         checkpoint_dir: Top-level checkpoint directory. Kept as
             ``self.checkpoint_top_dir`` and also passed to
             ``make_memnet_checkpoint_dir``.
         trainingset_dir: Kept as ``self.trainingset_dir``.
         w_reconstruction: Two-element sequence, unpacked into
             ``self.w_reconstruction_enc`` and ``self.w_reconstruction_dec``.
         w_decision: Two-element sequence, unpacked into
             ``self.w_decision_enc`` and ``self.w_decision_dec``.
         w_reg: Stored unchanged as ``self.w_reg``.
         w_rate: Rate-loss weight. Stored as ``self.w_rate`` and as the final
             value ``self.beta1``.
         beta0: Optional initial value for the rate-loss weight; must not
             exceed ``w_rate``.
         beta_steps: Stored unchanged as ``self.beta_steps``.
         sampling_off: Stored unchanged as ``self.sampling_off``.
         dataset: Name of the training task; must contain "attention".
         input_distribution_dim: Accepted but not read by this method.
         batch_size: Stored unchanged as ``self.batch_size``.
         dropout_prob: Stored unchanged as ``self.dropout_prob``.
         regenerate_steps: Stored unchanged as ``self.regenerate_steps``.
         dataset_size: Converted with ``int()`` and stored.

     Raises:
         Exception: If ``dataset`` does not contain "attention", or if
             ``beta0`` is given and is larger than ``w_rate``.
     """
     # Validate the task name before touching anything else.
     if "attention" not in dataset:
         raise Exception(
             "{} is not an acceptable training task.".format(dataset))
     self.name = "popout"
     # Lookup table from activation name to TensorFlow activation function.
     activation_dict = {
         "relu": tf.nn.relu,
         "sigmoid": tf.nn.sigmoid,
         "tanh": tf.nn.tanh,
         "elu": tf.nn.elu
     }
     self.dataset = dataset
     self.dataset_size = int(dataset_size)
     self.iterator = None
     self.next_element = None
     self.batch_size = batch_size
     self.regenerate_steps = regenerate_steps
     self.decision_dim = 2
     # Activation function
     self.activation = activation_dict[activation]
     # Image geometry and latent size are fixed for this network.
     self.image_width = 32
     self.RGB = True
     self.image_channels = 3
     self.latent_size = 500
     # Rate-loss weight and its (beta0 -> beta1) schedule end points
     self.w_rate = w_rate  # Initialize to this value, but possibly change
     self.beta1 = w_rate  # Final value for rate_loss_weight
     self.beta0 = beta0  # Initial value for rate_loss_weight
     # Only beta0 > beta1 is rejected; beta0 == beta1 is accepted despite the
     # wording of the message.
     if self.beta0 is not None and self.beta0 > self.beta1:
         raise Exception("beta1 must be greater than beta0")
     self.beta_steps = beta_steps
     # Each weight argument must hold exactly two values (enc, dec).
     self.w_reconstruction_enc, self.w_reconstruction_dec = w_reconstruction
     self.w_decision_enc, self.w_decision_dec = w_decision
     self.w_reg = w_reg
     self.dropout_prob = dropout_prob
     self.sampling_off = sampling_off
     self.checkpoint_top_dir = Path(checkpoint_dir)
     self.trainingset_dir = Path(trainingset_dir)
     # NOTE(review): the helper receives `self`, so it presumably derives the
     # directory from the attributes assigned above - confirm before
     # reordering or retyping any of them.
     self.checkpoint_dir = make_memnet_checkpoint_dir(checkpoint_dir, self)
     print("Using {} activation".format(activation))
     print("Loss weights (rate, reconstruction, decision): ",
           (self.w_rate, w_reconstruction, w_decision))
     self.sess = None
Пример #2
0
def visualize_network_samples_grid_1D(net,
                                      start=5,
                                      stop=100,
                                      step=10,
                                      f_ext="pdf"):
    """Plot one row of network outputs along a single stimulus dimension.

    [PLANTS STIMULI ONLY]
    One-dimensional counterpart of the 2D grid plot. One stimulus column is
    held at the value 50 while the other takes the values
    ``range(start, stop, step)``. If ``net.input_distribution_dim == 0``,
    column 1 is the fixed one; otherwise column 0 is fixed. The network
    output for every matching plant image is drawn in its own panel and the
    figure is written to ``grid.<f_ext>`` in the network's plot directory.

    Args:
        net: Network object providing ``image_width``, ``trainingset_dir``,
            ``layer_type``, ``input_distribution_dim`` and
            ``predict(inputs, keep_session=True)``.
        start: First stimulus value to plot.
        stop: Exclusive upper bound of the stimulus values.
        step: Spacing between plotted stimulus values.
        f_ext: File extension (and therefore format) of the saved figure.
    """
    save_dir = make_memnet_checkpoint_dir(
        Path("plots/vis_memnet_samples/grid/"), net)
    if not save_dir.exists():
        save_dir.makedirs()
    relevant_vals = range(start, stop, step)
    grid_width = len(relevant_vals)
    print("Loading plants.")
    inputs0, stim_vals0 = load_plants(net.image_width,
                                      net.trainingset_dir,
                                      layer_type=net.layer_type,
                                      normalize=True)
    # NOTE(review): load_plants is already called with normalize=True;
    # confirm that the extra division by 255 is intended.
    inputs0 = inputs0 / 255.  # Normalize
    print("Finished.")
    irrelevant_dim = int(net.input_distribution_dim == 0)
    # A set gives O(1) membership tests in the selection loop below.
    if irrelevant_dim == 0:
        wanted = set(product([50], relevant_vals))
    else:
        wanted = set(product(relevant_vals, [50]))
    at_fixed_value = stim_vals0[:, irrelevant_dim] == 50
    inputs_valid = inputs0[at_fixed_value]
    stim_vals = stim_vals0[at_fixed_value]
    # squeeze=False keeps `axes` two-dimensional even for a single panel, so
    # the indexing below also works when the grid has only one column.
    fig, axes = plt.subplots(1,
                             grid_width,
                             figsize=(50, 50 / grid_width),
                             squeeze=False)
    inputs = []
    print("Making plots...")
    for i in range(len(stim_vals)):
        if tuple(stim_vals[i]) in wanted:
            print("Making ", stim_vals[i])
            inputs.append(inputs_valid[i])
    preds = net.predict(inputs, keep_session=True)
    for i in range(len(preds)):
        ax = axes[0, i]
        ax.imshow(preds[i:i + 1].reshape(net.image_width, net.image_width),
                  cmap="gray_r")
        ax.set_xticks([])
        ax.set_xticklabels("")
        ax.set_yticks([])
        ax.set_yticklabels("")
    save_pth = save_dir.joinpath('grid.' + f_ext)
    fig.savefig(save_pth)
    # Release the (very large) figure so repeated calls do not pile up
    # open figures in pyplot's global registry.
    plt.close(fig)
    print("Saved to ", save_pth)
    return
Пример #3
0
def plants_average_reconstruction(net,
                                  plant_coords=((0, 0), (0, 99), (99, 99),
                                                (99, 0)),
                                  n_samples=100,
                                  f_ext="pdf"):
    """Plot target images next to their mean reconstruction.

    For each plant given by `plant_coords`, the target image is fed through
    the network `n_samples` times and the reconstructions are averaged. One
    two-panel figure (target, mean reconstruction) is saved per plant as
    ``width<w>_angle<a>.<f_ext>`` in the network's plot directory.

    Args:
        net: Network object providing ``image_width``, ``trainingset_dir``,
            ``layer_type`` and ``predict(inputs, keep_session=True)``.
        plant_coords: Sequence of ``(first, second)`` stimulus-value pairs
            selecting the plants to plot; matched against the two columns of
            the stimulus values returned by ``load_plants``.
        n_samples: Number of network samples averaged per plant.
        f_ext: File extension (and therefore format) of the saved figures.
    """
    images, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)
    save_dir = make_memnet_checkpoint_dir(
        Path("plots/plants_average_reconstruction/"), net)
    if not save_dir.exists():
        # makedirs() also creates missing parent directories, which a plain
        # mkdir() cannot do for this nested plot path.
        save_dir.makedirs()

    side = net.image_width
    for coord in plant_coords:
        # Select the image whose two stimulus values match this coordinate.
        i_target = np.logical_and(coord[0] == stim_vals[:, 0],
                                  coord[1] == stim_vals[:, 1])
        x = images[i_target]
        X = np.repeat(x, n_samples, axis=0)  # Take n samples
        Xhat = net.predict(X, keep_session=True)
        mean_reconstruction = Xhat.mean(axis=0)  # Take mean over samples
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(5.6, 3.1))
        panels = ((x, "Target"),
                  (mean_reconstruction, "Mean reconstruction"))
        for panel, (image, label) in zip(ax, panels):
            panel.imshow(image.reshape(side, side), cmap="gray_r")
            panel.set_xlabel(label, fontsize=14)
            panel.set_xticks([])
            panel.set_xticklabels("")
            panel.set_yticks([])
            panel.set_yticklabels("")
        fig.suptitle("Leaf width {}, leaf angle {}".format(coord[0], coord[1]))
        fig.tight_layout()
        save_pth = save_dir.joinpath("width{}_angle{}.{}".format(
            coord[0], coord[1], f_ext))
        fig.savefig(save_pth, dpi=300)
        # Close each figure once written; otherwise one figure per plant
        # stays registered with pyplot for the life of the process.
        plt.close(fig)
    print("Saved to " + str(save_dir))
    return
Пример #4
0
def visualize_network_samples_grid(net,
                                   start=5,
                                   stop=100,
                                   step=10,
                                   f_ext="pdf"):
    """Plot a 2D grid of network outputs indexed by stimulus value.

    [PLANTS STIMULI ONLY]
    Each panel shows the network output for one input image. The panel's
    column is set by the first stimulus value and its row by the second
    (described as width and droop in the original note). Only stimuli whose
    two values both lie in ``range(start, stop, step)`` are plotted. The
    figure is written to ``grid.<f_ext>`` in the network's plot directory.

    Args:
        net: Network object providing ``image_width``, ``trainingset_dir``,
            ``layer_type`` and ``predict(inputs, keep_session=True)``.
        start: First stimulus value on each axis of the grid.
        stop: Exclusive upper bound of the stimulus values.
        step: Spacing between plotted stimulus values.
        f_ext: File extension (and therefore format) of the saved figure.
    """
    save_dir = make_memnet_checkpoint_dir(
        Path("plots/vis_memnet_samples/grid/"), net)
    if not save_dir.exists():
        save_dir.makedirs()
    grid_vals = range(start, stop, step)
    grid_width = len(grid_vals)
    # A set gives O(1) membership tests while scanning every stimulus.
    wanted = set(product(grid_vals, grid_vals))
    print("Loading plants.")
    inputs, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)
    print("Finished.")
    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
    fig, axes = plt.subplots(grid_width,
                             grid_width,
                             figsize=(50, 50),
                             squeeze=False)
    print("Making plots...")
    for i in range(len(stim_vals)):
        stim = stim_vals[i]
        if tuple(stim) not in wanted:
            continue
        print("Making ", stim)
        pred = net.predict(inputs[i:i + 1],
                           keep_session=True).reshape(net.image_width,
                                                      net.image_width)
        # Grid position is the index of the value within
        # range(start, stop, step). Subtracting `start` keeps the index in
        # [0, grid_width) for any start, not only when start < step.
        col = int((stim[0] - start) // step)
        row = int((stim[1] - start) // step)
        ax = axes[row, col]
        ax.imshow(pred, cmap="gray_r")
        ax.set_xticks([])
        ax.set_xticklabels("")
        ax.set_yticks([])
        ax.set_yticklabels("")
    save_pth = save_dir.joinpath('grid.' + f_ext)
    fig.savefig(save_pth)
    # Release the (very large) figure so repeated calls do not pile up
    # open figures in pyplot's global registry.
    plt.close(fig)
    print("Saved to ", save_pth)
    return
Пример #5
0
def correlation(net, start=0, stop=100, step=10, n_samples=2, f_ext="pdf"):
    """Correlate reconstructions of target plants with original probe images.

    An analysis for networks trained with non-uniform input distributions
    (designed for the 'plants' experiment). For every target value on the
    non-uniform stimulus dimension (``net.input_distribution_dim``), the
    target images are reconstructed by the network and correlated (Pearson r)
    with the original images at every probe value 0..99. The correlation is
    averaged over the 100 values of the other stimulus dimension. One curve
    per target value is plotted and saved in the network's plot directory.

    Args:
        net: Network object providing ``input_distribution_dim``,
            ``input_mean``, ``input_std``, ``image_width``,
            ``trainingset_dir``, ``layer_type`` and
            ``predict(inputs, keep_session=True)``.
        start: First target stimulus value.
        stop: Exclusive upper bound of the target stimulus values.
        step: Spacing between target stimulus values.
        n_samples: Number of network samples to average over.
        f_ext: File extension (and therefore format) of the saved figure.

    Raises:
        ValueError: If ``net.input_distribution_dim`` is -1 (no non-uniform
            dimension), or if a target value lies outside the probe range
            0..99.
    """
    # Validate before importing, creating directories or loading data.
    dim = net.input_distribution_dim
    if dim == -1:
        raise ValueError(
            "correlation() requires a non-uniform input dimension, but "
            "net.input_distribution_dim is -1")
    target_selection = range(start, stop, step)  # Select only subset of values
    probe_selection = range(0, 100, 1)
    # Every target must also be a probe value, because the plot below looks
    # up each target's position in probe_selection. Fail now rather than
    # after the expensive network passes.
    missing = [t for t in target_selection if t not in probe_selection]
    if missing:
        raise ValueError(
            "target values {} are outside the probe range 0..99".format(
                missing))

    from scipy.stats import pearsonr
    # Save directory
    save_dir = make_memnet_checkpoint_dir(Path("plots/correlation/"), net)
    if not save_dir.exists():
        # makedirs() also creates missing parent directories, which a plain
        # mkdir() cannot do for this nested plot path.
        save_dir.makedirs()
    other_dim = 1 if dim == 0 else 0
    images, stim_vals = load_plants(net.image_width,
                                    net.trainingset_dir,
                                    layer_type=net.layer_type,
                                    normalize=True)
    unique_stim_vals = 100
    corrs = np.zeros((len(target_selection), len(probe_selection)))

    def _images_at(value):
        # Stack the images whose `dim` value equals `value`, ordered by the
        # value (0..99) of the other stimulus dimension.
        return np.vstack([
            images[np.logical_and(value == stim_vals[:, dim],
                                  d == stim_vals[:, other_dim])]
            for d in range(unique_stim_vals)
        ])

    for i, dim_t in enumerate(target_selection):
        # Iterate over all stimulus values for target image
        x = _images_at(dim_t)
        for j, dim_p in enumerate(probe_selection):
            # Iterate over all stimulus values for probe image
            y = _images_at(dim_p)
            # NOTE(review): x does not depend on the probe, yet the network
            # is re-sampled for every probe. Kept as is, so each probe is
            # still compared against its own fresh samples.
            # The mean is held in a local instead of a
            # (targets x probes x 100 x pixels) buffer that was never read.
            xhat = np.mean(
                [net.predict(x, keep_session=True) for _ in range(n_samples)],
                axis=0).astype(np.float64)
            corrs[i, j] = np.mean([
                pearsonr(xhat[k], y[k])[0] for k in range(unique_stim_vals)
            ])
            print(i, j)
    cutoff = 100
    fig, ax = plt.subplots()
    for i in range(len(target_selection)):
        ind_t = probe_selection.index(target_selection[i])
        lo = max(0, ind_t - cutoff)
        hi = ind_t + cutoff
        line = ax.plot(probe_selection[lo:hi], corrs[i][lo:hi])
        # Mark the true target value with a dashed line in the curve's color.
        color = line[0].get_color()
        ax.axvline(x=target_selection[i], color=color, linestyle="--")
    ax.set_ylim(0.6, 1)
    ax.set_xlim(0, 100)
    ax.set_xlabel("stimulus value", fontsize=16)
    ax.set_ylabel("correlation coefficient", fontsize=16)
    save_pth = save_dir.joinpath("correlation_dim{}_mean{}_std{}.{}".format(
        net.input_distribution_dim, net.input_mean, net.input_std, f_ext))
    fig.savefig(save_pth)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Saved to ", save_pth)
    return
Пример #6
0
def visualize_network_samples(net,
                              outputs_per_example=5,
                              color_map="gray",
                              setsize=[1, 6],
                              f_ext="pdf"):
    """Save, for each example input, a row of sampled network outputs.

    A batch of inputs is obtained in a way that depends on ``net.name``
    ("cifar", "fruits", "popout", or anything else). The network is then run
    `outputs_per_example` times on that batch. For every example one figure
    is saved showing the target image followed by the sampled outputs,
    clipped to the range [0, 1].

    Args:
        net: Network object providing ``name``, ``batch_size``,
            ``image_width``, ``image_channels``, ``RGB``, ``dataset``,
            ``trainingset_dir`` and ``predict_both``, plus whatever the
            branch selected by ``net.name`` needs to generate data.
        outputs_per_example: Number of network samples drawn per input.
        color_map: Matplotlib colormap name passed to ``imshow``.
        setsize: Passed to ``generate_training_data`` in the default branch
            and used in the output file names.
        f_ext: File extension (and therefore format) of the saved figures.
    """
    n_examples = net.batch_size
    if net.name == "cifar":
        inputs, targets = generate_training_data(
            None,
            None,
            dataset=net.dataset,
            train_test="test",
            conv=True,
            data_dir=net.trainingset_dir)[0:2]
        inputs, targets = net.get_batch(inputs, targets, None)[0:2]
        sig = True  # For predicting decisions
    elif net.name == "fruits":
        inputs, _ = net.generate_batch("test")
        targets = inputs
        sig = True
    elif net.name == "popout":
        inputs, _ = net.generate_training_data(n_examples)
        targets = inputs
        sig = False
    else:
        inputs = generate_training_data(n_examples,
                                        net.image_width,
                                        task_weights=net.task_weights,
                                        dataset=net.dataset,
                                        conv=net.layer_type == "conv",
                                        data_dir=net.trainingset_dir,
                                        setsize=setsize,
                                        mean=net.input_mean,
                                        std=net.input_std,
                                        dim=net.input_distribution_dim)[0]
        targets = inputs
        sig = False
    if net.image_channels > 1:
        shp = [net.image_width, net.image_width, net.image_channels]
    else:
        shp = [net.image_width, net.image_width]
    # One full pass over the batch per requested sample; only the
    # reconstructions are used, the decisions are discarded.
    preds0, _ = zip(*[
        net.predict_both(
            inputs, np.zeros_like(inputs), sigmoid=sig, keep_session=True)
        for _ in range(outputs_per_example)
    ])
    # (sample, example, ...) -> (example, sample, ...) -> image-shaped.
    preds = np.array(preds0).swapaxes(0, 1)
    preds = preds.reshape([n_examples, outputs_per_example] + shp)
    preds = preds.clip(0, 1)
    save_dir = make_memnet_checkpoint_dir(Path("plots/vis_memnet_samples/"),
                                          net)
    if not save_dir.exists():
        save_dir.makedirs()
    print("Saving images to: ", save_dir)
    xlabs = ["Target"] + [
        "Sample {}".format(j) for j in range(1, outputs_per_example + 1)
    ]
    setsize_tag = "_".join([str(x) for x in setsize])
    for i in range(n_examples):
        fig, axes = plt.subplots(nrows=1,
                                 ncols=outputs_per_example + 1,
                                 figsize=(45, 10))
        if net.RGB:
            axes[0].imshow(targets[i], cmap=color_map)
        else:
            axes[0].imshow(targets[i].reshape(net.image_width,
                                              net.image_width),
                           cmap=color_map)
        for j in range(outputs_per_example):
            axes[j + 1].imshow(preds[i][j], cmap=color_map)
        for ax, label in zip(axes, xlabs):
            ax.set_aspect('equal')
            ax.set_xlabel(label, size=50)
            ax.set_xticks([])
            ax.set_xticklabels("")
            ax.set_yticks([])
            ax.set_yticklabels("")
        fig.tight_layout()
        fig.savefig(
            save_dir.joinpath('setsize_{}_sample{}.{}'.format(
                setsize_tag,
                str(i).zfill(2), f_ext)))
        # Close each figure once written; otherwise batch_size large figures
        # stay registered with pyplot for the life of the process.
        plt.close(fig)
 def __init__(self,
              hidden_size,
              latent_size,
              activation,
              checkpoint_dir,
              trainingset_dir,
              decision_size=100,
              encoder_layers=1,
              decoder_layers=1,
              decision_layers=1,
              kernel_size=3,
              load_decision_weights=False,
              task_weights=None,
              w_reconstruction=[1., 1.],
              w_decision=[1., 1.],
              w_reg=1e-10,
              w_rate=0,
              beta0=None,
              beta_steps=200000,
              sampling_off=False,
              encode_probe=False,
              dataset='rectangle',
              image_width=100,
              RGB=False,
              layer_type="MLP",
              decision_target="same_different",
              loss_func_dec="squared_error",
              loss_func_recon=None,
              dataset_size=int(1e4),
              regenerate_steps=None,
              input_std=10,
              input_mean=50,
              input_distribution_dim=-1,
              sample_distribution="gaussian",
              decision_dim=1,
              image_channels=3,
              batch_size=20,
              dropout_prob=1.0):
     """Store the configuration of the "VAE" network.

     This method only validates arguments and records hyper-parameters on
     ``self``; ``self.sess`` and ``self.decision_checkpoint_dir`` are left as
     ``None``.

     Args:
         hidden_size: Size of hidden layer(s).
         latent_size: Stored unchanged as ``self.latent_size``.
         activation: Key into the activation table below ("relu", "sigmoid",
             "tanh" or "elu"). Any other key raises ``KeyError``.
         checkpoint_dir: Top-level checkpoint directory. Kept as
             ``self.checkpoint_top_dir`` and also passed to
             ``make_memnet_checkpoint_dir``.
         trainingset_dir: Kept as ``self.trainingset_dir``.
         decision_size, encoder_layers, decoder_layers, decision_layers:
             Stored unchanged under the same names.
         kernel_size: Used if layers are convolutional.
         load_decision_weights: Accepted but not read by this method.
         task_weights: Optional sequence of per-task weights. Defaults to
             ``[1] * decision_dim``. Stored as a NumPy array.
         w_reconstruction: Two-element sequence, unpacked into
             ``self.w_reconstruction_enc`` / ``self.w_reconstruction_dec``
             (each converted to float).
         w_decision: Two-element sequence, unpacked into
             ``self.w_decision_enc`` / ``self.w_decision_dec`` (each
             converted to float).
         w_reg: Converted to float and stored.
         w_rate: Rate-loss weight. Stored as float in ``self.w_rate`` and
             unconverted as the final value ``self.beta1``.
         beta0: Optional initial value for the rate-loss weight; must not
             exceed ``w_rate``.
         beta_steps, sampling_off, encode_probe, sample_distribution,
         dropout_prob, batch_size, regenerate_steps, input_std, input_mean,
         input_distribution_dim, image_width, layer_type, loss_func_dec:
             Stored unchanged under the same names.
         dataset: Dataset name. Some names override ``decision_dim`` and
             ``decision_target`` (see the body).
         RGB: If true, ``self.image_channels`` is 3, otherwise 1.
         decision_target: Stored; "tp_dist" disables ``logistic_decision``.
         loss_func_recon: Defaults to "squared_error" when ``None``.
         dataset_size: Converted with ``int()`` and stored.
         decision_dim: Number of decision outputs. May be ``None`` when
             ``task_weights`` is given, in which case its length is used.
         image_channels: Accepted but not read; the stored value is derived
             from ``RGB``.

     Raises:
         NotImplementedError: If ``dataset`` contains "float" and
             ``layer_type`` is "conv".
         Exception: For the "letter" dataset without exactly four task
             weights, for the conv parity check below, for inconsistent
             ``task_weights`` / ``decision_dim``, or if ``beta0`` is larger
             than ``w_rate``.
     """
     if "float" in dataset and layer_type == "conv":
         raise NotImplementedError
     # NOTE(review): with the default task_weights=None and
     # dataset == "letter", len(None) raises TypeError before this Exception.
     if dataset == "letter" and len(task_weights) != 4:
         raise Exception
     # NOTE(review): this raises when hidden_size is ODD, while the message
     # talks about the kernel size having to be odd. Presumably kernel_size
     # (and the opposite parity) was intended - confirm.
     if layer_type == "conv" and hidden_size % 2:
         raise Exception("Kernel size must be odd number")
     self.name = "VAE"
     # Lookup table from activation name to TensorFlow activation function.
     activation_dict = {
         "relu": tf.nn.relu,
         "sigmoid": tf.nn.sigmoid,
         "tanh": tf.nn.tanh,
         "elu": tf.nn.elu
     }
     self.dataset = dataset
     self.dataset_size = int(dataset_size)
     self.regenerate_steps = regenerate_steps
     self.batch_size = batch_size
     self.input_mean = input_mean
     self.input_std = input_std
     self.input_distribution_dim = input_distribution_dim
     # Reconcile task_weights and decision_dim: either one may be omitted,
     # but when both are given their lengths must agree.
     if task_weights is None:
         if decision_dim is None:
             raise Exception(
                 "Need to specify either task_weights or decision_dim")
         else:
             task_weights = [1] * decision_dim
     else:
         if hasattr(task_weights, '__len__'):
             if decision_dim is None:
                 decision_dim = len(task_weights)
             elif not decision_dim == len(task_weights):
                 raise Exception(
                     "decision_dim and task_weights mismatched for length")
         else:
             raise Exception("task_weights must be list or tuple")
     self.task_weights = np.array(task_weights)
     self.decision_dim = decision_dim
     self.layer_type = layer_type
     self.loss_func_dec = loss_func_dec
     self.decision_target = decision_target
     self.logistic_decision = (False if self.decision_target == "tp_dist"
                               else True)
     # The dataset-specific overrides below replace decision_dim after
     # task_weights has been stored, so the two may then differ in length.
     if dataset in ["gabor_array", "plants_setsize"]:
         # Hardcode for now to avoid errors
         self.decision_dim = 6
         self.decision_target = "recall"
     for i in range(1, 7):
         # Hardcode for now to avoid errors
         if dataset in [
                 "gabor_array{}".format(i), "plants_setsize{}".format(i)
         ]:
             self.decision_dim = i
             self.decision_target = "recall"
     if loss_func_recon is None:
         self.loss_func_recon = "squared_error"
     else:
         self.loss_func_recon = loss_func_recon
     if dataset == "episodic_shape_race":
         self.decision_dim = 10
     elif dataset == "episodic_setsize":
         self.decision_dim = 12
     # Activation function
     self.activation = activation_dict[activation]
     self.image_width = image_width
     self.RGB = RGB
     # The image_channels argument is ignored; RGB alone decides the value.
     if self.RGB:
         self.image_channels = 3
     else:
         self.image_channels = 1
     # Size of hidden layer(s)
     self.hidden_size = hidden_size
     self.kernel_size = kernel_size  # Used if layers are convolutional
     self.encoder_layers = encoder_layers
     self.decoder_layers = decoder_layers
     self.decision_layers = decision_layers
     self.latent_size = latent_size
     self.decision_size = decision_size
     self.w_rate = float(
         w_rate)  # Initialize to this value, but possibly change
     self.beta1 = w_rate  # Final value for rate_loss_weight
     self.beta0 = beta0  # Initial value for rate_loss_weight
     # Only beta0 > beta1 is rejected; beta0 == beta1 is accepted despite the
     # wording of the message.
     if self.beta0 is not None and self.beta0 > self.beta1:
         raise Exception("beta1 must be greater than beta0")
     self.beta_steps = beta_steps
     # Each weight argument must hold exactly two values (enc, dec).
     self.w_reconstruction_enc, self.w_reconstruction_dec = w_reconstruction
     self.w_reconstruction_enc = float(self.w_reconstruction_enc)
     self.w_reconstruction_dec = float(self.w_reconstruction_dec)
     self.w_decision_enc, self.w_decision_dec = w_decision
     self.w_decision_enc = float(self.w_decision_enc)
     self.w_decision_dec = float(self.w_decision_dec)
     self.w_reg = float(w_reg)
     self.dropout_prob = dropout_prob
     self.sampling_off = sampling_off
     self.sample_distribution = sample_distribution
     self.encode_probe = encode_probe
     self.checkpoint_top_dir = Path(checkpoint_dir)
     self.trainingset_dir = Path(trainingset_dir)
     # NOTE(review): the helper receives `self`, so it presumably derives the
     # directory from the attributes assigned above - confirm before
     # reordering or retyping any of them.
     self.checkpoint_dir = make_memnet_checkpoint_dir(checkpoint_dir, self)
     self.decision_checkpoint_dir = None
     print("Using {} activation".format(activation))
     print("Using {} loss function for reconstruction error".format(
         self.loss_func_recon))
     print("Loss weights (rate, reconstruction, decision): ",
           (self.w_rate, w_reconstruction, w_decision))
     self.sess = None