def _grab_RNN(self, initial_states):
        '''Creates objects for interfacing with the RNN.

        These objects include 1) the optimization variables (initialized to
        the user-specified initial_states) which will, after optimization,
        contain fixed points of the RNN, and 2) hooks into those optimization
        variables that are required for building the TF graph.

        The same input (self.inputs) is applied at every initial state: it is
        tiled n_inits times along the first axis before being fed to
        self.rnn_cell. The newly created variable is initialized in
        self.session before returning.

        Args:
            initial_states: Either an [n_inits x n_dims] numpy array or an
            LSTMStateTuple with initial_states.c and initial_states.h as
            [n_inits x n_dims/2] numpy arrays. These data specify the initial
            states of the RNN, from which the optimization will search for
            fixed points. The choice of type must be consistent with state
            type of rnn_cell.

        Returns:
            x: An [n_inits x n_dims] tf.Variable (the optimization variable)
            representing RNN states, initialized to the values in
            initial_states. If the RNN is an LSTM, n_dims represents the
            concatenated hidden and cell states.

            F: An [n_inits x n_dims] tf op representing the state transition
            function of the RNN applied to x.

            states: Contains the same data as in x, but formatted to interface
            with self.rnn_cell (e.g., formatted as LSTMStateTuple if rnn_cell
            is a LSTMCell)

            new_states: Contains the same data as in F, but formatted to
            interface with self.rnn_cell
        '''
        if self.is_lstm:
            # LSTMStateTuple with .c and .h as [n_inits x n_dims/2] arrays
            # --> one [n_inits x n_dims] array of concatenated states.
            c_h_init = tf_utils.convert_from_LSTMStateTuple(initial_states)

            # [n_inits x n_dims]
            x = tf.Variable(c_h_init, dtype=tf.float32)

            # Re-format x as an LSTMStateTuple so rnn_cell can consume it.
            states = tf_utils.convert_to_LSTMStateTuple(x)
        else:
            x = tf.Variable(initial_states, dtype=tf.float32)
            states = x

        # Replicate the input once per initial state. NOTE(review): this
        # presumably expects self.inputs to be [1 x n_inputs] -- confirm.
        n_inits = x.shape[0]
        tiled_inputs = np.tile(self.inputs, [n_inits, 1])
        inputs_tf = tf.constant(tiled_inputs, dtype=tf.float32)

        # One step of the RNN; the cell's output is not needed here.
        _, new_states = self.rnn_cell(inputs_tf, states)

        if self.is_lstm:
            # [n_inits x n_dims]
            F = tf_utils.convert_from_LSTMStateTuple(new_states)
        else:
            F = new_states

        # Initialize only the variable created above.
        init = tf.variables_initializer(var_list=[x])
        self.session.run(init)

        return x, F, states, new_states
    def _grab_RNN(self, initial_states, inputs):
        '''Creates objects for interfacing with the RNN.

        These objects include 1) the optimization variables (initialized to
        the user-specified initial_states) which will, after optimization,
        contain fixed points of the RNN, and 2) hooks into those optimization
        variables that are required for building the TF graph.

        The new variable is created with dtype self.tf_dtype and initialized
        in self.session before returning.

        Args:
            initial_states: Either an [n x n_states] numpy array or an
            LSTMStateTuple with initial_states.c and initial_states.h as
            [n x n_states/2] numpy arrays. These data specify the initial
            states of the RNN, from which the optimization will search for
            fixed points. The choice of type must be consistent with state
            type of rnn_cell.

            inputs: A [n x n_inputs] numpy array specifying the inputs to the
            RNN for this fixed point optimization.

        Returns:
            x: An [n x n_states] tf.Variable (the optimization variable)
            representing RNN states, initialized to the values in
            initial_states. If the RNN is an LSTM, n_states represents the
            concatenated hidden and cell states.

            F: An [n x n_states] tf op representing the state transition
            function of the RNN applied to x.
        '''

        if self.is_lstm:
            # Flatten the LSTMStateTuple into one array for optimization, then
            # re-format the variable so that rnn_cell can consume it.
            c_h_init = tf_utils.convert_from_LSTMStateTuple(initial_states)
            x = tf.Variable(c_h_init, dtype=self.tf_dtype)
            x_rnncell = tf_utils.convert_to_LSTMStateTuple(x)
        else:
            x = tf.Variable(initial_states, dtype=self.tf_dtype)
            x_rnncell = x

        inputs_tf = tf.constant(inputs, dtype=self.tf_dtype)

        # One step of the RNN; the cell's output is not needed here.
        _, F_rnncell = self.rnn_cell(inputs_tf, x_rnncell)

        if self.is_lstm:
            # Back to a single [n x n_states] array, matching x.
            F = tf_utils.convert_from_LSTMStateTuple(F_rnncell)
        else:
            F = F_rnncell

        # Initialize only the variable created above.
        init = tf.variables_initializer(var_list=[x])
        self.session.run(init)

        return x, F
Esempio n. 3
0
    def identify_distance_non_outliers(fps, initial_states, dist_thresh):
        '''Selects the fixed points that lie near the cloud of initial states.

        Every distance is measured from the centroid of initial_states and
        expressed in units of the mean distance of the initial states from
        that centroid. A point whose scaled distance is strictly below
        dist_thresh counts as a non-outlier. The number of outliers among the
        initial states and among the fixed points is printed.

        Note: if all initial states coincide, the mean distance is zero and
        the scaled distances become inf/nan (numpy emits a warning).

        Args:
            fps: Object exposing .n (number of fixed points) and .xstar (the
            fixed points, one per row).

            initial_states: Either a 2-D numpy array with one state per row,
            or an LSTM state tuple (flattened via
            tf_utils.convert_from_LSTMStateTuple).

            dist_thresh: Threshold on the scaled distance.

        Returns:
            1-D numpy array of row indices into fps.xstar for the fixed points
            that are not outliers.
        '''
        if tf_utils.is_lstm(initial_states):
            initial_states = \
                tf_utils.convert_from_LSTMStateTuple(initial_states)

        total_inits = initial_states.shape[0]
        total_fps = fps.n

        # Reference point and length scale derived from the initial states.
        center = np.mean(initial_states, axis=0)
        raw_init_dists = np.linalg.norm(initial_states - center, axis=1)
        length_scale = np.mean(raw_init_dists)

        init_scaled = np.true_divide(raw_init_dists, length_scale)
        fps_scaled = np.true_divide(
            np.linalg.norm(fps.xstar - center, axis=1), length_scale)

        kept_inits = np.count_nonzero(init_scaled < dist_thresh)
        print('\t\tinitial_states: %d outliers detected (of %d).'
            % (total_inits - kept_inits, total_inits))

        kept_fps_idx = np.nonzero(fps_scaled < dist_thresh)[0]
        print('\t\tfixed points: %d outliers detected (of %d).'
            % (total_fps - kept_fps_idx.size, total_fps))

        return kept_fps_idx
    def sample_states(self,
                      state_traj,
                      n_inits,
                      noise_scale=0.0,
                      rng=None):
        '''Draws random samples from trajectories of the RNN state. Samples
        can optionally be corrupted by independent and identically distributed
        (IID) Gaussian noise. These samples are intended to be used as initial
        states for fixed point optimizations.

        Args:
            state_traj: [n_batch x n_time x n_states] numpy array or
            LSTMStateTuple with .c and .h as [n_batch x n_time x n_states/2]
            numpy arrays. Contains example trajectories of the RNN state.

            n_inits: int specifying the number of sampled states to return.

            noise_scale (optional): non-negative float specifying the standard
            deviation of IID Gaussian noise samples added to the sampled
            states.

            rng (optional): numpy RandomState used for sampling and noise.
            Default: a new RandomState seeded with 0, created on every call,
            so results are reproducible and no random state is shared between
            calls or instances.

        Returns:
            initial_states: Sampled RNN states as a [n_inits x n_states] numpy
            array or as an LSTMStateTuple with .c and .h as [n_inits x
            n_states/2] numpy arrays (type matches that of state_traj).

        Raises:
            ValueError if noise_scale is negative.
        '''
        # Validate before doing any work or consuming random numbers.
        if noise_scale < 0.0:
            raise ValueError('noise_scale must be non-negative,'
                             ' but was %f' % noise_scale)

        # A default RandomState in the signature would be created once and
        # shared by every call; create a fresh, seeded one per call instead.
        if rng is None:
            rng = npr.RandomState(0)

        if self.is_lstm:
            state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(state_traj)
        else:
            state_traj_bxtxd = state_traj

        [n_batch, n_time, n_states] = state_traj_bxtxd.shape

        # Draw random (trial, time) samples from the state trajectories. The
        # interleaved draw order is kept so a given rng yields the same
        # samples as before.
        states = np.zeros([n_inits, n_states])
        for init_idx in range(n_inits):
            trial_idx = rng.randint(n_batch)
            time_idx = rng.randint(n_time)
            states[init_idx, :] = state_traj_bxtxd[trial_idx, time_idx, :]

        # Add IID Gaussian noise to the sampled states (none when 0).
        if noise_scale > 0.0:
            states += noise_scale * rng.randn(n_inits, n_states)

        if self.is_lstm:
            return tf_utils.convert_to_LSTMStateTuple(states)
        else:
            return states
Esempio n. 5
0
def plot_fps(fps,
             state_traj=None,
             plot_batch_idx=None,
             plot_start_time=0,
             plot_stop_time=None,
             mode_scale=0.25,
             fig=None):
    '''Plots a visualization and analysis of the unique fixed points.

    1) Finds a low-dimensional subspace for visualization via PCA. If
    state_traj is provided, PCA is fit to [all of] those RNN state
    trajectories. Otherwise, PCA is fit to the identified unique fixed
    points. This subspace is 3-dimensional if the RNN state dimensionality
    is >= 3. For fewer than 3 state dimensions, the raw states are plotted
    and no PCA is used.

    2) Plots the PCA representation of the stable unique fixed points as
    black dots.

    3) Plots the PCA representation of the unstable unique fixed points as
    red dots.

    4) Plots the PCA representation of the modes of the Jacobian at each
    fixed point. By default, only unstable modes are plotted.

    5) (optional) Plots example RNN state trajectories as blue lines.

    Args:
        fps: a FixedPoints object. See FixedPoints.py.

        state_traj (optional): [n_batch x n_time x n_states] numpy
        array or LSTMStateTuple with .c and .h as
        [n_batch x n_time x n_states/2] numpy arrays. Contains example
        trials of RNN state trajectories.

        plot_batch_idx (optional): Indices specifying which trials in
        state_traj to plot on top of the fixed points. Default: plot all
        trials.

        plot_start_time (optional): int specifying the first timestep to
        plot in the example trials of state_traj. Clipped to be >= 0.
        Default: 0.

        plot_stop_time (optional): int specifying the end (exclusive) of the
        plotted time window in the example trials of state_traj. Clipped to
        be <= n_time. Default: n_time.

        mode_scale (optional): Non-negative float specifying the scaling
        of the plotted eigenmodes. A value of 1.0 results in each mode
        plotted as a set of diametrically opposed line segments
        originating at a fixed point, with each segment's length specified
        by the magnitude of the corresponding eigenvalue.

        fig (optional): Matplotlib figure upon which to plot. Default: a new
        6x6 inch figure.

    Returns:
        None.
    '''

    FONT_WEIGHT = 'bold'
    if fig is None:
        FIG_WIDTH = 6  # inches
        FIG_HEIGHT = 6  # inches
        fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT), tight_layout=True)

    if state_traj is not None:
        if tf_utils.is_lstm(state_traj):
            state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(state_traj)
        else:
            state_traj_bxtxd = state_traj

        [n_batch, n_time, n_states] = state_traj_bxtxd.shape

        # Ensure plot_start_time >= 0
        plot_start_time = np.max([plot_start_time, 0])

        if plot_stop_time is None:
            plot_stop_time = n_time
        else:
            # Ensure plot_stop_time <= n_time
            plot_stop_time = np.min([plot_stop_time, n_time])

        plot_time_idx = range(plot_start_time, plot_stop_time)

    n_inits = fps.n
    n_states = fps.n_states

    if n_states >= 3:
        pca = PCA(n_components=3)

        if state_traj is not None:
            # Fit PCA to every timestep of every trial.
            state_traj_btxd = np.reshape(state_traj_bxtxd,
                                         (n_batch * n_time, n_states))
            pca.fit(state_traj_btxd)
        else:
            pca.fit(fps.xstar)

        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlabel('PC 1', fontweight=FONT_WEIGHT)
        ax.set_zlabel('PC 3', fontweight=FONT_WEIGHT)
        ax.set_ylabel('PC 2', fontweight=FONT_WEIGHT)

        # For generating figure in paper.md
        #ax.set_xticks([-2, -1, 0, 1, 2])
        #ax.set_yticks([-1, 0, 1])
        #ax.set_zticks([-1, 0, 1])
    else:
        # For 1D or 2D networks: plot raw states, no PCA.
        pca = None
        ax = fig.add_subplot(111)
        # Axes objects expose set_xlabel/set_ylabel (not xlabel/ylabel).
        ax.set_xlabel('Hidden 1', fontweight=FONT_WEIGHT)
        if n_states == 2:
            ax.set_ylabel('Hidden 2', fontweight=FONT_WEIGHT)

    if state_traj is not None:
        if plot_batch_idx is None:
            plot_batch_idx = range(n_batch)

        for batch_idx in plot_batch_idx:
            x_idx = state_traj_bxtxd[batch_idx]

            if n_states >= 3:
                z_idx = pca.transform(x_idx[plot_time_idx, :])
            else:
                z_idx = x_idx[plot_time_idx, :]
            plot_123d(ax, z_idx, color='b', linewidth=0.2)

    for init_idx in range(n_inits):
        plot_fixed_point(
            ax,
            fps[init_idx],
            pca,
            scale=mode_scale)

    plt.ion()
    plt.show()
    plt.pause(1e-10)
Esempio n. 6
0
    def plot(self,
             state_traj=None,
             plot_batch_idx=None,
             plot_start_time=0,
             plot_stop_time=None,
             mode_scale=0.25,
             stim_config=None,
             gng_time=0,
             block=False,
             title='',
             dpa2_time=None,
             plot_fixed_points=False):
        '''Plots RNN state trajectories and, optionally, the fixed points in a
        low-dimensional subspace.

        1) Finds a low-dimensional subspace for visualization via PCA. If
        state_traj is provided, PCA is fit to [all of] those RNN state
        trajectories. Otherwise, PCA is fit to the identified unique fixed
        points (self.xstar). This subspace is 3-dimensional if the RNN state
        dimensionality is >= 3. For fewer dimensions, raw states are plotted.

        2) (optional) Plots example RNN state trajectories, with markers at
        the first plotted timestep (green '+'), at dpa2_time (cyan '+'), at
        the last plotted timestep (magenta 'x') and at gng_time (blue '+').

        3) (only if plot_fixed_points is True) Plots each fixed point as a
        dot (black if all Jacobian eigenvalues have magnitude < 1, red
        otherwise), together with lines for its dominant unstable
        eigenmodes.

        Args:
            state_traj (optional): [n_batch x n_time x n_states] numpy
            array or LSTMStateTuple with .c and .h as
            [n_batch x n_time x n_states/2] numpy arrays. Contains example
            trials of RNN state trajectories.

            plot_batch_idx (optional): Indices specifying which trials in
            state_traj to plot. Default: plot all trials.

            plot_start_time (optional): int specifying the first timestep to
            plot in the example trials of state_traj. Clipped to be >= 0.
            Default: 0.

            plot_stop_time (optional): int specifying the end (exclusive) of
            the plotted time window in state_traj. Clipped to be <= n_time.
            Default: n_time.

            mode_scale (optional): Non-negative float specifying the scaling
            of the plotted eigenmodes (used only when plot_fixed_points is
            True). A value of 1.0 results in each mode plotted as a set of
            diametrically opposed line segments originating at a fixed point,
            with each segment's length specified by the magnitude of the
            corresponding eigenvalue.

            stim_config (optional): per-trial line colors, indexed as
            stim_config[batch_idx, :] (presumably RGB/RGBA rows -- confirm).
            When given, trajectories also receive a small Gaussian jitter
            (std 0.005, drawn from numpy's global random state) so that
            overlapping trials stay visible. Default: all trials in blue.

            gng_time (optional): timestep, relative to the plotted window, at
            which a blue '+' marker is drawn. 0 disables the marker.

            block (optional): passed to plt.show(block=...).

            title (optional): figure title.

            dpa2_time (optional): sequence with one timestep per trial; a cyan
            '+' marker is drawn at index int(dpa2_time[batch_idx]) - 1 of the
            plotted window. Default None: no marker.

            plot_fixed_points (optional): bool; if True, also plot the fixed
            points and their unstable modes. Default: False.

        Returns:
            fig: the Matplotlib figure that was created.
        '''
        def plot_123d(ax, z, **kwargs):
            '''Plots in 1D, 2D, or 3D.

            Args:
                ax: Matplotlib figure axis on which to plot everything.

                z: [n x n_states] numpy array containing data to be plotted,
                where n_states is 1, 2, or 3.

                any keyword arguments that can be passed to ax.plot(...).

            Returns:
                None.
            '''
            n_states = z.shape[1]
            if n_states == 3:
                ax.plot(z[:, 0], z[:, 1], z[:, 2], **kwargs)
            elif n_states == 2:
                ax.plot(z[:, 0], z[:, 1], **kwargs)
            elif n_states == 1:
                ax.plot(z, **kwargs)

        def plot_fixed_point(ax,
                             xstar,
                             J,
                             pca,
                             scale=1.0,
                             max_n_modes=3,
                             do_plot_stable_modes=False):
            '''Plots a single fixed point and its dominant eigenmodes.

            Args:
                ax: Matplotlib figure axis on which to plot everything.

                xstar: [1 x n_states] numpy array representing the fixed point
                to be plotted.

                J: [n_states x n_states] numpy array containing the Jacobian
                of the RNN transition function at fixed point xstar.

                pca: PCA object as returned by sklearn.decomposition.PCA. This
                is used to transform the high-d state space representations
                into 3-d for visualization. Unused (may be None) when
                n_states < 3.

                scale (optional): Scale factor for stretching (>1) or shrinking
                (<1) lines representing eigenmodes of the Jacobian. Default:
                1.0 (unity).

                max_n_modes (optional): Maximum number of eigenmodes to plot.
                Default: 3.

                do_plot_stable_modes (optional): bool indicating whether or
                not to plot lines representing stable modes (i.e.,
                eigenvectors of the Jacobian whose eigenvalue magnitude is
                less than one).

            Returns:
                None.
            '''
            n_states = xstar.shape[1]
            e_vals, e_vecs = np.linalg.eig(J)
            sorted_e_val_idx = np.argsort(np.abs(e_vals))

            # Cannot plot more modes than there are eigenvalues.
            max_n_modes = min(max_n_modes, len(e_vals))

            for mode_idx in range(max_n_modes):

                # -[1, 2, ..., max_n_modes]: largest magnitudes first
                idx = sorted_e_val_idx[-(mode_idx + 1)]

                # Magnitude of complex eigenvalue
                e_val_mag = np.abs(e_vals[idx])

                if e_val_mag > 1.0 or do_plot_stable_modes:

                    # Eigenvectors of complex eigenvalues are complex; only
                    # the real part is plotted.
                    e_vec = np.real(e_vecs[:, idx])

                    # [1 x d] numpy arrays
                    xstar_plus = xstar + scale * e_val_mag * e_vec
                    xstar_minus = xstar - scale * e_val_mag * e_vec

                    # [3 x d] numpy array
                    xstar_mode = np.vstack((xstar_minus, xstar, xstar_plus))

                    if e_val_mag < 1.0:
                        color = 'k'
                    else:
                        color = 'r'

                    if n_states >= 3:
                        # [3 x 3] numpy array
                        zstar_mode = pca.transform(xstar_mode)
                    else:
                        zstar_mode = xstar_mode

                    plot_123d(ax, zstar_mode, color=color)

            is_stable = all(np.abs(e_vals) < 1.0)
            if is_stable:
                color = 'k'
            else:
                color = 'r'

            if n_states >= 3:
                zstar = pca.transform(xstar)
            else:
                zstar = xstar
            plot_123d(ax, zstar, color=color, marker='.', markersize=12)

        FIG_WIDTH = 6  # inches
        FIG_HEIGHT = 6  # inches
        FONT_WEIGHT = 'bold'

        xstar = self.xstar
        J_xstar = self.J_xstar

        if state_traj is not None:
            if tf_utils.is_lstm(state_traj):
                state_traj_bxtxd = tf_utils.convert_from_LSTMStateTuple(
                    state_traj)
            else:
                state_traj_bxtxd = state_traj

            [n_batch, n_time, n_states] = state_traj_bxtxd.shape

            # Ensure plot_start_time >= 0
            plot_start_time = np.max([plot_start_time, 0])

            if plot_stop_time is None:
                plot_stop_time = n_time
            else:
                # Ensure plot_stop_time <= n_time
                plot_stop_time = np.min([plot_stop_time, n_time])

            plot_time_idx = range(plot_start_time, plot_stop_time)

        n_inits, n_states = np.shape(xstar)

        fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT), tight_layout=True)
        if n_states >= 3:
            pca = PCA(n_components=3)

            if state_traj is not None:
                # Fit PCA to every timestep of every trial.
                state_traj_btxd = np.reshape(state_traj_bxtxd,
                                             (n_batch * n_time, n_states))
                pca.fit(state_traj_btxd)
            else:
                pca.fit(xstar)

            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlabel('PC 1', fontweight=FONT_WEIGHT)
            ax.set_zlabel('PC 3', fontweight=FONT_WEIGHT)
            ax.set_ylabel('PC 2', fontweight=FONT_WEIGHT)

            # For generating figure in paper.md
            ax.set_xticks([-2, -1, 0, 1, 2])
            ax.set_yticks([-1, 0, 1])
            ax.set_zticks([-1, 0, 1])
        else:
            # For 1D or 2D networks: plot raw states, no PCA.
            pca = None
            ax = fig.add_subplot(111)
            # Axes objects expose set_xlabel/set_ylabel (not xlabel/ylabel).
            ax.set_xlabel('Hidden 1', fontweight=FONT_WEIGHT)
            if n_states == 2:
                ax.set_ylabel('Hidden 2', fontweight=FONT_WEIGHT)

        if state_traj is not None:
            if plot_batch_idx is None:
                plot_batch_idx = range(n_batch)

            for batch_idx in plot_batch_idx:
                x_idx = state_traj_bxtxd[batch_idx]

                if n_states >= 3:
                    z_idx = pca.transform(x_idx[plot_time_idx, :])
                else:
                    z_idx = x_idx[plot_time_idx, :]
                if stim_config is not None:
                    # Small jitter so overlapping trajectories stay visible.
                    some_noise = np.random.normal(scale=0.005,
                                                  size=(1, z_idx.shape[1]))
                    plot_123d(ax,
                              z_idx + some_noise,
                              color=stim_config[batch_idx, :],
                              linewidth=0.2)
                else:
                    plot_123d(ax, z_idx, color='b', linewidth=0.2)

                # Single-timestep markers are passed as [1 x n_dims] rows;
                # reshape((1, -1)) works for any number of plotted dims.
                plot_123d(ax,
                          z_idx[0, :].reshape((1, -1)),
                          color='g',
                          marker='+',
                          markersize=8)
                if dpa2_time is not None:
                    plot_123d(ax,
                              z_idx[int(dpa2_time[batch_idx]) - 1, :].reshape(
                                  (1, -1)),
                              color='c',
                              marker='+',
                              markersize=8)
                plot_123d(ax,
                          z_idx[-1, :].reshape((1, -1)),
                          color='m',
                          marker='x',
                          markersize=8)
                if gng_time != 0:
                    plot_123d(ax,
                              z_idx[gng_time, :].reshape((1, -1)),
                              color='b',
                              marker='+',
                              markersize=8)

        if plot_fixed_points:
            for init_idx in range(n_inits):
                plot_fixed_point(
                    ax,
                    xstar[init_idx:(init_idx + 1)],
                    J_xstar[init_idx],
                    pca,
                    scale=mode_scale)

        plt.ion()
        plt.title(title)
        plt.show(block=block)
        return fig