def _grab_RNN(self, initial_states):
    '''Builds the TF objects used to search RNN state space for fixed points.

    Two things are created: 1) the optimization variable, seeded with the
    caller-supplied initial_states, which will hold fixed points of the RNN
    once optimization has run; and 2) graph hooks into that variable that
    are needed to construct the rest of the TF graph.

    Args:
        initial_states: Either an [n_inits x n_dims] numpy array or an
            LSTMStateTuple whose .c and .h fields are [n_inits x n_dims/2]
            numpy arrays. These are the starting points from which the
            optimization searches for fixed points. The type must agree
            with the state type of self.rnn_cell.

    Returns:
        x: An [n_inits x n_dims] tf.Variable (the optimization variable)
            holding RNN states, initialized from initial_states. For an
            LSTM, n_dims spans the concatenated hidden and cell states.

        F: An [n_inits x n_dims] tf op: the RNN state transition function
            applied to x.

        states: The same data as x, in the format self.rnn_cell expects
            (e.g., an LSTMStateTuple when rnn_cell is an LSTMCell).

        new_states: The same data as F, in the format self.rnn_cell
            produces.

    Side effects:
        Runs the initializer for x in self.session.
    '''
    # A single tf.Variable holds the whole state. For an LSTM, the (c, h)
    # pair is first flattened into one concatenated array.
    if self.is_lstm:
        flat_init = tf_utils.convert_from_LSTMStateTuple(initial_states)
    else:
        flat_init = initial_states
    x = tf.Variable(flat_init, dtype=tf.float32)

    # View of x in the layout rnn_cell consumes.
    states = tf_utils.convert_to_LSTMStateTuple(x) if self.is_lstm else x

    # Repeat self.inputs once per state row of x. This assumes self.inputs
    # is a single input row -- TODO confirm against callers.
    n_rows = x.shape[0]
    inputs_tf = tf.constant(np.tile(self.inputs, [n_rows, 1]),
                            dtype=tf.float32)

    # Only the next state matters here; the cell output is discarded.
    _, new_states = self.rnn_cell(inputs_tf, states)

    if self.is_lstm:
        F = tf_utils.convert_from_LSTMStateTuple(new_states)
    else:
        F = new_states

    self.session.run(tf.variables_initializer(var_list=[x]))

    return x, F, states, new_states
def _grab_RNN(self, initial_states, inputs):
    '''Creates objects for interfacing with the RNN.

    These objects include 1) the optimization variables (initialized to the
    user-specified initial_states) which will, after optimization, contain
    fixed points of the RNN, and 2) hooks into those optimization variables
    that are required for building the TF graph.

    Args:
        initial_states: Either an [n x n_states] numpy array or an
            LSTMStateTuple with initial_states.c and initial_states.h as
            [n x n_states/2] numpy arrays. These data specify the initial
            states of the RNN, from which the optimization will search for
            fixed points. The choice of type must be consistent with state
            type of rnn_cell.

        inputs: A [n x n_inputs] numpy array specifying the inputs to the
            RNN for this fixed point optimization.

    Returns:
        x: An [n x n_states] tf.Variable (the optimization variable)
            representing RNN states, initialized to the values in
            initial_states. If the RNN is an LSTM, n_states represents the
            concatenated hidden and cell states.

        F: An [n x n_states] tf op representing the state transition
            function of the RNN applied to x.

    Side effects:
        Runs the initializer for x in self.session.
    '''
    if self.is_lstm:
        # Flatten the (c, h) tuple so one tf.Variable holds the full state,
        # then build the tuple view that rnn_cell expects.
        c_h_init = tf_utils.convert_from_LSTMStateTuple(initial_states)
        x = tf.Variable(c_h_init, dtype=self.tf_dtype)
        x_rnncell = tf_utils.convert_to_LSTMStateTuple(x)
    else:
        x = tf.Variable(initial_states, dtype=self.tf_dtype)
        x_rnncell = x

    inputs_tf = tf.constant(inputs, dtype=self.tf_dtype)

    # Only the updated state is needed; the cell output is discarded.
    _, F_rnncell = self.rnn_cell(inputs_tf, x_rnncell)

    if self.is_lstm:
        F = tf_utils.convert_from_LSTMStateTuple(F_rnncell)
    else:
        F = F_rnncell

    init = tf.variables_initializer(var_list=[x])
    self.session.run(init)

    return x, F
def identify_distance_non_outliers(fps, initial_states, dist_thresh):
    '''Returns indices of fixed points that are not distance outliers.

    Every point's Euclidean distance to the centroid (mean) of
    initial_states is divided by the mean such distance over
    initial_states. A point is kept when this normalized distance is
    strictly below dist_thresh. The number of outliers among the initial
    states and among the fixed points is printed.

    Args:
        fps: A FixedPoints object; its .n (number of fixed points) and
            .xstar (fixed point locations, one per row) are read.
        initial_states: Either an [n_inits x n_states] numpy array or an
            LSTMStateTuple, which is converted to a concatenated array first.
        dist_thresh: Threshold on the normalized distance.

    Returns:
        A 1-D numpy integer array of row indices into fps.xstar for the
        fixed points that are not outliers.
    '''
    if tf_utils.is_lstm(initial_states):
        initial_states = tf_utils.convert_from_LSTMStateTuple(initial_states)

    n_inits = initial_states.shape[0]
    n_fps = fps.n

    centroid = np.mean(initial_states, axis=0)

    def _dist_to_centroid(points):
        # Row-wise Euclidean distance from the centroid of initial_states.
        return np.linalg.norm(points - centroid, axis=1)

    init_dists = _dist_to_centroid(initial_states)
    # Normalization scale: mean distance of the initial states.
    scale = np.mean(init_dists)
    scaled_init = np.true_divide(init_dists, scale)
    scaled_fps = np.true_divide(_dist_to_centroid(fps.xstar), scale)

    n_init_kept = np.count_nonzero(scaled_init < dist_thresh)
    print('\t\tinitial_states: %d outliers detected (of %d).'
          % (n_inits - n_init_kept, n_inits))

    kept_fp_idx = np.where(scaled_fps < dist_thresh)[0]
    print('\t\tfixed points: %d outliers detected (of %d).'
          % (n_fps - kept_fp_idx.size, n_fps))

    return kept_fp_idx
def sample_states(self, state_traj, n_inits, noise_scale=0.0,
                  rng=npr.RandomState(0)):
    '''Draws random samples from trajectories of the RNN state.

    Samples can optionally be corrupted by independent and identically
    distributed (IID) Gaussian noise. These samples are intended to be used
    as initial states for fixed point optimizations.

    Args:
        state_traj: [n_batch x n_time x n_states] numpy array or
            LSTMStateTuple with .c and .h as numpy arrays (presumably each
            [n_batch x n_time x n_states/2] -- TODO confirm), which is
            converted to a single concatenated array before sampling.
            Contains example trajectories of the RNN state.

        n_inits: int specifying the number of sampled states to return.

        noise_scale (optional): non-negative float specifying the standard
            deviation of IID Gaussian noise samples added to the sampled
            states. Default: 0.0 (no noise).

        rng (optional): numpy RandomState used both to pick samples and to
            draw noise. NOTE: the default is one RandomState(0) created when
            this function is defined and shared by every call that omits
            rng. Successive default calls therefore continue the same random
            stream instead of restarting it. Pass an explicit RandomState
            for reproducible draws.

    Returns:
        initial_states: Sampled RNN states as a [n_inits x n_states] numpy
            array, or as an LSTMStateTuple (via
            tf_utils.convert_to_LSTMStateTuple) when self.is_lstm, matching
            the type of state_traj.

    Raises:
        ValueError if noise_scale is negative. The check happens before any
        random numbers are drawn, so rng is left untouched on error.
    '''
    # Validate first so an invalid call does not advance rng (which, by
    # default, is shared across calls).
    if noise_scale < 0.0:
        raise ValueError('noise_scale must be non-negative,'
                         ' but was %f' % noise_scale)

    if self.is_lstm:
        state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(state_traj)
    else:
        state_traj_bxtxd = state_traj

    [n_batch, n_time, n_states] = state_traj_bxtxd.shape

    # Draw random samples from the state trajectories. One (trial, time)
    # pair is drawn per sample, trial index first; this order fixes the
    # sequence of values consumed from rng.
    states = np.zeros([n_inits, n_states])
    for init_idx in range(n_inits):
        trial_idx = rng.randint(n_batch)
        time_idx = rng.randint(n_time)
        states[init_idx, :] = state_traj_bxtxd[trial_idx, time_idx, :]

    # Add IID Gaussian noise to the sampled states (noise_scale == 0 means
    # no noise and no extra draws from rng).
    if noise_scale > 0.0:
        states += noise_scale * rng.randn(n_inits, n_states)

    if self.is_lstm:
        return tf_utils.convert_to_LSTMStateTuple(states)
    return states
def plot_fps(fps,
             state_traj=None,
             plot_batch_idx=None,
             plot_start_time=0,
             plot_stop_time=None,
             mode_scale=0.25,
             fig=None):
    '''Plots a visualization and analysis of the unique fixed points.

    1) Finds a low-dimensional subspace for visualization via PCA. If
    state_traj is provided, PCA is fit to [all of] those RNN state
    trajectories (every trial and every timestep, regardless of the plotted
    time window). Otherwise, PCA is fit to the identified unique fixed
    points. This subspace is 3-dimensional if the RNN state dimensionality
    is >= 3; for lower-dimensional states the raw hidden units are plotted
    without PCA.

    2) Plots the PCA representation of the stable unique fixed points as
    black dots.

    3) Plots the PCA representation of the unstable unique fixed points as
    red dots.

    4) Plots the PCA representation of the modes of the Jacobian at each
    fixed point. By default, only unstable modes are plotted.

    5) (optional) Plots example RNN state trajectories as blue lines.

    Items 2-4 are drawn by plot_fixed_point (defined elsewhere), which is
    called once per fixed point with fps[i].

    Args:
        fps: a FixedPoints object. See FixedPoints.py.

        state_traj (optional): [n_batch x n_time x n_states] numpy array or
            LSTMStateTuple with .c and .h as [n_batch x n_time x n_states/2]
            numpy arrays. Contains example trials of RNN state trajectories.

        plot_batch_idx (optional): Indices specifying which trials in
            state_traj to plot on top of the fixed points. Default: plot all
            trials.

        plot_start_time (optional): int specifying the first timestep to
            plot in the example trials of state_traj (clamped to >= 0).
            Default: 0.

        plot_stop_time (optional): int specifying the exclusive end of the
            plotted time window in the example trials of state_traj
            (clamped to <= n_time). Default: n_time.

        mode_scale (optional): Non-negative float specifying the scaling of
            the plotted eigenmodes. A value of 1.0 results in each mode
            plotted as a set of diametrically opposed line segments
            originating at a fixed point, with each segment's length
            specified by the magnitude of the corresponding eigenvalue.

        fig (optional): Matplotlib figure upon which to plot. Default: a new
            6 x 6 inch figure.

    Returns:
        None.
    '''
    FONT_WEIGHT = 'bold'
    if fig is None:
        FIG_WIDTH = 6  # inches
        FIG_HEIGHT = 6  # inches
        fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT), tight_layout=True)

    if state_traj is not None:
        if tf_utils.is_lstm(state_traj):
            state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(
                state_traj)
        else:
            state_traj_bxtxd = state_traj

        [n_batch, n_time, n_states] = state_traj_bxtxd.shape

        # Ensure plot_start_time >= 0
        plot_start_time = np.max([plot_start_time, 0])
        if plot_stop_time is None:
            plot_stop_time = n_time
        else:
            # Ensure plot_stop_time <= n_time
            plot_stop_time = np.min([plot_stop_time, n_time])

        plot_time_idx = range(plot_start_time, plot_stop_time)

    n_inits = fps.n
    n_states = fps.n_states

    if n_states >= 3:
        pca = PCA(n_components=3)

        if state_traj is not None:
            state_traj_btxd = np.reshape(state_traj_bxtxd,
                                         (n_batch * n_time, n_states))
            pca.fit(state_traj_btxd)
        else:
            pca.fit(fps.xstar)

        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlabel('PC 1', fontweight=FONT_WEIGHT)
        ax.set_zlabel('PC 3', fontweight=FONT_WEIGHT)
        ax.set_ylabel('PC 2', fontweight=FONT_WEIGHT)

        # For generating figure in paper.md
        # ax.set_xticks([-2, -1, 0, 1, 2])
        # ax.set_yticks([-1, 0, 1])
        # ax.set_zticks([-1, 0, 1])
    else:
        # 1D or 2D state: plot raw hidden units without PCA.
        pca = None
        ax = fig.add_subplot(111)
        # Axes objects expose set_xlabel/set_ylabel; they have no
        # xlabel/ylabel methods (those exist only on pyplot).
        ax.set_xlabel('Hidden 1', fontweight=FONT_WEIGHT)
        if n_states == 2:
            ax.set_ylabel('Hidden 2', fontweight=FONT_WEIGHT)

    if state_traj is not None:
        if plot_batch_idx is None:
            plot_batch_idx = range(n_batch)

        for batch_idx in plot_batch_idx:
            x_idx = state_traj_bxtxd[batch_idx]

            if n_states >= 3:
                z_idx = pca.transform(x_idx[plot_time_idx, :])
            else:
                z_idx = x_idx[plot_time_idx, :]
            plot_123d(ax, z_idx, color='b', linewidth=0.2)

    for init_idx in range(n_inits):
        plot_fixed_point(
            ax,
            fps[init_idx],
            pca,
            scale=mode_scale)

    plt.ion()
    plt.show()
    plt.pause(1e-10)
def plot(self,
         state_traj=None,
         plot_batch_idx=None,
         plot_start_time=0,
         plot_stop_time=None,
         mode_scale=0.25,
         stim_config=None,
         gng_time=0,
         block=False,
         title='',
         dpa2_time=None):
    '''Plots RNN state trajectories in a low-dimensional subspace.

    1) Finds a low-dimensional subspace for visualization via PCA. If
    state_traj is provided, PCA is fit to [all of] those RNN state
    trajectories (every trial and timestep). Otherwise, PCA is fit to
    self.xstar. This subspace is 3-dimensional if the RNN state
    dimensionality is >= 3; otherwise raw hidden units are plotted.

    2) (optional) Plots example RNN state trajectories (blue lines, or
    per-trial colors from stim_config) with markers: green '+' at the first
    plotted timestep, magenta 'x' at the last, cyan '+' at dpa2_time
    (if given), and blue '+' at gng_time (if nonzero).

    NOTE: plotting of the fixed points and their Jacobian eigenmodes (via
    the nested plot_fixed_point) is currently disabled -- see the
    commented-out loop at the end -- so mode_scale is unused.

    Args:
        state_traj (optional): [n_batch x n_time x n_states] numpy array or
            LSTMStateTuple with .c and .h as [n_batch x n_time x n_states/2]
            numpy arrays. Contains example trials of RNN state trajectories.

        plot_batch_idx (optional): Indices specifying which trials in
            state_traj to plot. Default: plot all trials.

        plot_start_time (optional): int specifying the first timestep to
            plot in the example trials of state_traj (clamped to >= 0).
            Default: 0.

        plot_stop_time (optional): int specifying the exclusive end of the
            plotted time window (clamped to <= n_time). Default: n_time.

        mode_scale (optional): Non-negative float scaling the plotted
            eigenmodes (only used if fixed-point plotting is re-enabled).

        stim_config (optional): Indexed as stim_config[batch_idx, :] to
            give the color of each trial's trajectory (presumably an array
            of RGB/RGBA rows -- TODO confirm). When given, each trial is
            shifted by a small random Gaussian offset (std 0.005), the same
            at every timestep. Default: None (all trials blue).

        gng_time (optional): int index into the plotted window; if nonzero,
            a blue '+' marks each trial at that timestep. Default: 0 (off).

        block (optional): Passed to plt.show(block=...). Default: False.

        title (optional): Figure title passed to plt.title. Default: ''.

        dpa2_time (optional): Sequence with one entry per trial; a cyan '+'
            marks trial batch_idx at plotted row
            int(dpa2_time[batch_idx]) - 1. Default: None (marker skipped).
            An empty sequence is treated the same as None.

    Returns:
        The matplotlib figure that was created.
    '''
    def plot_123d(ax, z, **kwargs):
        '''Plots in 1D, 2D, or 3D.

        Args:
            ax: Matplotlib figure axis on which to plot everything.

            z: [n x n_states] numpy array containing data to be plotted,
                where n_states is 1, 2, or 3. Any other width plots nothing.

            kwargs: any keyword arguments that can be passed to
                ax.plot(...).

        Returns:
            None.
        '''
        n_states = z.shape[1]
        if n_states == 3:
            ax.plot(z[:, 0], z[:, 1], z[:, 2], **kwargs)
        elif n_states == 2:
            ax.plot(z[:, 0], z[:, 1], **kwargs)
        elif n_states == 1:
            ax.plot(z, **kwargs)

    def plot_fixed_point(ax, xstar, J, pca,
                         scale=1.0,
                         max_n_modes=3,
                         do_plot_stable_modes=False):
        '''Plots a single fixed point and its dominant eigenmodes.

        Args:
            ax: Matplotlib figure axis on which to plot everything.

            xstar: [1 x n_states] numpy array representing the fixed point
                to be plotted.

            J: [n_states x n_states] numpy array containing the Jacobian of
                the RNN transition function at fixed point xstar.

            pca: PCA object as returned by sklearn.decomposition.PCA. This
                is used to transform the high-d state space representations
                into 3-d for visualization. Unused (may be None) when
                n_states < 3.

            scale (optional): Scale factor for stretching (>1) or shrinking
                (<1) lines representing eigenmodes of the Jacobian.
                Default: 1.0 (unity).

            max_n_modes (optional): Maximum number of eigenmodes to plot,
                capped at the number of eigenvalues. Default: 3.

            do_plot_stable_modes (optional): bool indicating whether or not
                to plot lines representing stable modes (i.e., eigenvectors
                of the Jacobian whose eigenvalue magnitude is less than
                one).

        Returns:
            None.
        '''
        n_states = xstar.shape[1]
        e_vals, e_vecs = np.linalg.eig(J)
        sorted_e_val_idx = np.argsort(np.abs(e_vals))

        # Cannot plot more modes than there are eigenvalues.
        max_n_modes = min(max_n_modes, len(e_vals))

        for mode_idx in range(max_n_modes):
            # Walk from the largest-magnitude eigenvalue downward.
            idx = sorted_e_val_idx[-(mode_idx + 1)]

            # Magnitude of (possibly complex) eigenvalue
            e_val_mag = np.abs(e_vals[idx])

            if e_val_mag > 1.0 or do_plot_stable_modes:
                # Eigenvectors may be complex; only the real part is plotted.
                e_vec = np.real(e_vecs[:, idx])

                # [1 x d] numpy arrays
                xstar_plus = xstar + scale * e_val_mag * e_vec
                xstar_minus = xstar - scale * e_val_mag * e_vec

                # [3 x d] numpy array: segment through the fixed point
                xstar_mode = np.vstack((xstar_minus, xstar, xstar_plus))

                color = 'k' if e_val_mag < 1.0 else 'r'

                if n_states >= 3:
                    # [3 x 3] numpy array
                    zstar_mode = pca.transform(xstar_mode)
                else:
                    zstar_mode = xstar_mode

                plot_123d(ax, zstar_mode, color=color)

        is_stable = all(np.abs(e_vals) < 1.0)
        color = 'k' if is_stable else 'r'

        if n_states >= 3:
            zstar = pca.transform(xstar)
        else:
            zstar = xstar

        plot_123d(ax, zstar, color=color, marker='.', markersize=12)

    FIG_WIDTH = 6  # inches
    FIG_HEIGHT = 6  # inches
    FONT_WEIGHT = 'bold'

    xstar = self.xstar

    if state_traj is not None:
        if tf_utils.is_lstm(state_traj):
            state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(
                state_traj)
        else:
            state_traj_bxtxd = state_traj

        [n_batch, n_time, n_states] = state_traj_bxtxd.shape

        # Ensure plot_start_time >= 0
        plot_start_time = np.max([plot_start_time, 0])
        if plot_stop_time is None:
            plot_stop_time = n_time
        else:
            # Ensure plot_stop_time <= n_time
            plot_stop_time = np.min([plot_stop_time, n_time])

        plot_time_idx = range(plot_start_time, plot_stop_time)

    _, n_states = np.shape(xstar)

    fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT), tight_layout=True)
    if n_states >= 3:
        pca = PCA(n_components=3)

        if state_traj is not None:
            state_traj_btxd = np.reshape(state_traj_bxtxd,
                                         (n_batch * n_time, n_states))
            pca.fit(state_traj_btxd)
        else:
            pca.fit(xstar)

        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlabel('PC 1', fontweight=FONT_WEIGHT)
        ax.set_zlabel('PC 3', fontweight=FONT_WEIGHT)
        ax.set_ylabel('PC 2', fontweight=FONT_WEIGHT)

        # Hard-coded tick locations used for the figure in paper.md; they
        # may not suit other data.
        ax.set_xticks([-2, -1, 0, 1, 2])
        ax.set_yticks([-1, 0, 1])
        ax.set_zticks([-1, 0, 1])
    else:
        # 1D or 2D state: plot raw hidden units without PCA.
        pca = None
        ax = fig.add_subplot(111)
        # Axes objects expose set_xlabel/set_ylabel (no xlabel/ylabel).
        ax.set_xlabel('Hidden 1', fontweight=FONT_WEIGHT)
        if n_states == 2:
            ax.set_ylabel('Hidden 2', fontweight=FONT_WEIGHT)

    if state_traj is not None:
        if plot_batch_idx is None:
            plot_batch_idx = range(n_batch)

        plot_dpa2 = dpa2_time is not None and len(dpa2_time) > 0

        for batch_idx in plot_batch_idx:
            x_idx = state_traj_bxtxd[batch_idx]

            if n_states >= 3:
                z_idx = pca.transform(x_idx[plot_time_idx, :])
            else:
                z_idx = x_idx[plot_time_idx, :]

            if stim_config is not None:
                # Constant random offset for this trial (presumably to keep
                # overlapping trials visually distinct).
                some_noise = np.random.normal(scale=0.005,
                                              size=(1, z_idx.shape[1]))
                plot_123d(ax, z_idx + some_noise,
                          color=stim_config[batch_idx, :], linewidth=0.2)
            else:
                plot_123d(ax, z_idx, color='b', linewidth=0.2)

            # Markers use z_idx[[i], :] so each is a [1 x d] row for any d
            # (including negative i), instead of hard-coding reshape((1, 3)).
            plot_123d(ax, z_idx[[0], :],
                      color='g', marker='+', markersize=8)
            if plot_dpa2:
                dpa2_row = int(dpa2_time[batch_idx]) - 1
                plot_123d(ax, z_idx[[dpa2_row], :],
                          color='c', marker='+', markersize=8)
            plot_123d(ax, z_idx[[-1], :],
                      color='m', marker='x', markersize=8)
            if gng_time != 0:
                plot_123d(ax, z_idx[[gng_time], :],
                          color='b', marker='+', markersize=8)

    # Fixed-point plotting is disabled. To re-enable:
    # for init_idx in range(len(xstar)):
    #     plot_fixed_point(
    #         ax,
    #         xstar[init_idx:(init_idx+1)],
    #         self.J_xstar[init_idx],
    #         pca,
    #         scale=mode_scale)

    plt.ion()
    plt.title(title)
    plt.show(block=block)

    return fig