def novel_pattern_replay(CONFIG):
    """
    Show how a network that has weak connections in addition to its strong
    connections can learn to replay novel patterns that would not normally
    arise spontaneously.

    Two figures are drawn.

    Figure 1 (4 rows):

    - Rows 1-3 each hold three trials. Every trial uses a "stitched" path: the
      first two elements of node 0's first path, followed by node 1's first
      path from element 2 on.
    - Row 1 strongly drives the stitched path through the loaded network, then
      re-drives its first element.
    - Row 2 applies the same drive to a rebuilt network that has one extra
      weak connection.
    - Row 3 drives only the first element of the path.
    - Row 4 is one RUN_LENGTH-step run that drives the stitched path once and
      then lets the network run freely.

    Figure 2 (2 panels) shows a weakly driven probe path. The left panel has
    no preceding drive. The right panel is preceded by the first four time
    steps of the stitched-path drive.

    CONFIG: dict with keys 'SEED', 'LOAD_FILE_NAME', 'W_WEAK', 'GAIN',
        'REFRACTORY_STRENGTH', 'LINGERING_INPUT_VALUE',
        'LINGERING_INPUT_TIMESCALE', 'STRONG_DRIVE_AMPLITUDE',
        'WEAK_DRIVE_AMPLITUDE', 'TRIAL_LENGTH_TRIGGERED_REPLAY', 'RUN_LENGTH',
        'FIG_SIZE_0', 'FIG_SIZE_1', 'FONT_SIZE'. A missing key raises KeyError.
    """
    SEED = CONFIG['SEED']
    LOAD_FILE_NAME = CONFIG['LOAD_FILE_NAME']
    W_WEAK = CONFIG['W_WEAK']
    GAIN = CONFIG['GAIN']
    REFRACTORY_STRENGTH = CONFIG['REFRACTORY_STRENGTH']
    LINGERING_INPUT_VALUE = CONFIG['LINGERING_INPUT_VALUE']
    LINGERING_INPUT_TIMESCALE = CONFIG['LINGERING_INPUT_TIMESCALE']
    STRONG_DRIVE_AMPLITUDE = CONFIG['STRONG_DRIVE_AMPLITUDE']
    WEAK_DRIVE_AMPLITUDE = CONFIG['WEAK_DRIVE_AMPLITUDE']
    TRIAL_LENGTH_TRIGGERED_REPLAY = CONFIG['TRIAL_LENGTH_TRIGGERED_REPLAY']
    RUN_LENGTH = CONFIG['RUN_LENGTH']
    FIG_SIZE_0 = CONFIG['FIG_SIZE_0']
    FIG_SIZE_1 = CONFIG['FIG_SIZE_1']
    FONT_SIZE = CONFIG['FONT_SIZE']

    np.random.seed(SEED)

    def _simulate(ntwk_template, drives):
        """Step a fresh copy of ntwk_template through each row of drives; return its rs_history as an array."""
        ntwk = deepcopy(ntwk_template)
        ntwk.store_voltages = True
        for drive in drives:
            ntwk.step(drive)
        return np.array(ntwk.rs_history)

    def _plot(ax, spikes, drives, y_max, y_label, title, x_max=None):
        """Draw the raster on ax and set limits/labels (x_max defaults to the number of drive steps)."""
        fancy_raster.by_row_circles(ax, spikes, drives)
        ax.set_xlim(-1, len(drives) if x_max is None else x_max)
        ax.set_ylim(-1, y_max)
        ax.set_xlabel('time step')
        ax.set_ylabel(y_label)
        ax.set_title(title)

    def _stitched_path(ntwk):
        """Return node 0's first path with elements from index 2 on replaced by node 1's first path."""
        path = list(ntwk.node_0_path_tree[0][:])
        path[2:] = ntwk.node_1_path_tree[0][2:]
        return path

    def _path_drives(path, n_steps, n_nodes, retrigger):
        """Strongly drive path one node per step; if retrigger, re-drive path[0] right after the path ends."""
        drives = np.zeros((n_steps, n_nodes), dtype=float)
        for t, node in enumerate(path):
            drives[t, node] = STRONG_DRIVE_AMPLITUDE
        if retrigger:
            drives[len(path), path[0]] = STRONG_DRIVE_AMPLITUDE
        return drives

    fig = plt.figure(figsize=FIG_SIZE_0, tight_layout=True)
    # three rows of three trial panels on a 4x3 grid, plus one full-width panel in the 4th row
    axs = [[fig.add_subplot(4, 3, 3 * row + col) for col in range(1, 4)] for row in range(3)]
    axs.append(fig.add_subplot(4, 1, 4))

    # load old network; the file holds a pickled object array, so only load trusted files
    ntwk_old = np.load(LOAD_FILE_NAME, allow_pickle=True)[0]

    # demonstrate how one cannot use intrinsic plasticity to learn sequence that is made of disjoint paths
    drives = _path_drives(
        _stitched_path(ntwk_old), TRIAL_LENGTH_TRIGGERED_REPLAY, ntwk_old.w.shape[0], retrigger=True)
    for ctr, ax in enumerate(axs[0]):
        spikes = _simulate(ntwk_old, drives)
        _plot(ax, spikes, drives, 20, 'active \n ensemble',
              'Strongly driving nonexisting path (trial {})'.format(ctr + 1))

    w = ntwk_old.w.copy()
    # add weak connection to element 2 of node_1 path tree from element 1 of node_0 path tree
    w[ntwk_old.node_1_path_tree[0][2], ntwk_old.node_0_path_tree[0][1]] = W_WEAK

    # make new base network
    ntwk_base = network.RecurrentSoftMaxLingeringModel(
        w, GAIN, REFRACTORY_STRENGTH, LINGERING_INPUT_VALUE, LINGERING_INPUT_TIMESCALE
    )
    ntwk_base.node_0 = ntwk_old.node_0
    # previously assigned to node_0 by mistake, clobbering it and leaving node_1 unset
    ntwk_base.node_1 = ntwk_old.node_1
    ntwk_base.node_0_path_tree = ntwk_old.node_0_path_tree
    ntwk_base.node_1_path_tree = ntwk_old.node_1_path_tree

    path = _stitched_path(ntwk_base)

    # demonstrate how weak connections allow linking of paths into short term memory
    drives = _path_drives(path, TRIAL_LENGTH_TRIGGERED_REPLAY, ntwk_base.w.shape[0], retrigger=True)
    for ctr, ax in enumerate(axs[1]):
        spikes = _simulate(ntwk_base, drives)
        _plot(ax, spikes, drives, 20, 'active \n ensemble',
              'Activity from forced initial condition after \n driving path with weak connection (trial {})'.format(ctr + 1))

    # demonstrate how weak connections do not substantially affect path probabilities
    drives = np.zeros((TRIAL_LENGTH_TRIGGERED_REPLAY, ntwk_old.w.shape[0]), dtype=float)
    drives[0, path[0]] = STRONG_DRIVE_AMPLITUDE
    for ctr, ax in enumerate(axs[2]):
        spikes = _simulate(ntwk_base, drives)
        _plot(ax, spikes, drives, 20, 'active ensemble',
              'Free activity after only forcing \n initial condition (trial {})'.format(ctr + 1))

    # drive the stitched path once, then let the network run freely
    run_drives = _path_drives(path, RUN_LENGTH, ntwk_base.w.shape[0], retrigger=False)
    spikes = _simulate(ntwk_base, run_drives)
    _plot(axs[3], spikes, run_drives, 40, 'active \n ensemble',
          'Free activity after driving path with weak connection')

    for ax_row in axs[:-1]:
        for ax in ax_row:
            axis_tools.set_fontsize(ax, FONT_SIZE)
    axis_tools.set_fontsize(axs[-1], FONT_SIZE)

    # now demonstrate how pattern-matching computation changes with respect to short-term memory
    _, axs = plt.subplots(1, 2, figsize=FIG_SIZE_1, tight_layout=True)
    probe_path = list(ntwk_base.node_0_path_tree[1][:])
    # NOTE(review): nodes 22 and 17 are hard-coded; confirm they are valid for the network in LOAD_FILE_NAME
    probe_path[0] = 22
    probe_path[2] = 17
    probe_path[3] = ntwk_base.node_1_path_tree[0][3]
    drives_new = np.zeros((len(probe_path), ntwk_base.w.shape[0]), dtype=float)
    drives_new[0, probe_path[0]] = STRONG_DRIVE_AMPLITUDE
    for t, node in enumerate(probe_path[1:], start=1):
        drives_new[t, node] = WEAK_DRIVE_AMPLITUDE

    # drive a network with just the new drive to see how it completes the pattern
    spikes = _simulate(ntwk_base, drives_new)
    _plot(axs[0], spikes, drives_new, ntwk_base.w.shape[0], 'active ensemble',
          'Weakly driving nonexistent path', x_max=8)

    # same probe, preceded by the first four steps of the stitched-path drive
    drives = np.concatenate([run_drives[:4, :], drives_new])
    spikes = _simulate(ntwk_base, drives)
    _plot(axs[1], spikes, drives, ntwk_base.w.shape[0], 'active ensemble',
          'Weakly driving nonexistent path after \n strongly driving path with weak connection', x_max=8)

    for ax in axs:
        axis_tools.set_fontsize(ax, FONT_SIZE)
def novel_pattern_learning(CONFIG):
    """
    Show that a network has an increased probability of embedding a novel
    pattern into its connectivity network if we allow nonassociative priming
    to act.

    Setup:

    - One weak connection (value W_WEAK) is added to the loaded network's
      weight matrix.
    - The network is rebuilt as a
      ``network.RecurrentSoftMaxLingeringSTDPModelBasic``.
    - The novel path is the first two elements of node 0's first path followed
      by node 1's first path from element 2 on. It is strongly driven at the
      start of each of N_REPEATS runs of RUN_LENGTH steps.
    - The tracked weak weight is recorded after every step and plotted on a
      twin axis.

    Two figures are drawn. The first uses the configured lingering input
    ("With nonassociative priming"). The second sets ``lingering_input_value``
    to 0 ("Without nonassociative priming").

    CONFIG: dict with keys 'SEED', 'LOAD_FILE_NAME', 'W_WEAK', 'GAIN',
        'REFRACTORY_STRENGTH', 'LINGERING_INPUT_VALUE',
        'LINGERING_INPUT_TIMESCALE', 'W_MAX', 'ALPHA',
        'STRONG_DRIVE_AMPLITUDE', 'RUN_LENGTH', 'N_REPEATS', 'FIG_SIZE_0',
        'FONT_SIZE'. A missing key raises KeyError.
    """
    SEED = CONFIG["SEED"]
    LOAD_FILE_NAME = CONFIG["LOAD_FILE_NAME"]
    W_WEAK = CONFIG["W_WEAK"]
    GAIN = CONFIG["GAIN"]
    REFRACTORY_STRENGTH = CONFIG["REFRACTORY_STRENGTH"]
    LINGERING_INPUT_VALUE = CONFIG["LINGERING_INPUT_VALUE"]
    LINGERING_INPUT_TIMESCALE = CONFIG["LINGERING_INPUT_TIMESCALE"]
    W_MAX = CONFIG["W_MAX"]
    ALPHA = CONFIG["ALPHA"]
    STRONG_DRIVE_AMPLITUDE = CONFIG["STRONG_DRIVE_AMPLITUDE"]
    RUN_LENGTH = CONFIG["RUN_LENGTH"]
    N_REPEATS = CONFIG["N_REPEATS"]
    FIG_SIZE_0 = CONFIG["FIG_SIZE_0"]
    FONT_SIZE = CONFIG["FONT_SIZE"]

    np.random.seed(SEED)

    # create new network with STDP learning rule
    # the file holds a pickled object array, so only load trusted files
    ntwk_old = np.load(LOAD_FILE_NAME, allow_pickle=True)[0]
    path_0 = ntwk_old.node_0_path_tree[0]
    path_1 = ntwk_old.node_1_path_tree[0]

    # add weak connection to element 2 of node_1 path tree from element 1 of node_0 path tree
    w_to_track = (path_1[2], path_0[1])
    w = ntwk_old.w.copy()
    w[w_to_track] = W_WEAK

    # novel path: node 0's first two nodes, then node 1's path from element 2 on.
    # Converting to lists makes `+` concatenate even if the path-tree entries are
    # numpy arrays (where `+` would silently add element-wise).
    path_novel = list(path_0[:2]) + list(path_1[2:])

    # make new base network
    ntwk_base = network.RecurrentSoftMaxLingeringSTDPModelBasic(
        w, GAIN, REFRACTORY_STRENGTH, LINGERING_INPUT_VALUE, LINGERING_INPUT_TIMESCALE, W_MAX, ALPHA
    )
    ntwk_base.node_0 = ntwk_old.node_0
    # previously assigned to node_0 by mistake, clobbering it and leaving node_1 unset
    ntwk_base.node_1 = ntwk_old.node_1
    ntwk_base.node_0_path_tree = ntwk_old.node_0_path_tree
    ntwk_base.node_1_path_tree = ntwk_old.node_1_path_tree

    def _plot_repeats(title, lingering_input_value=None):
        """Draw one figure of N_REPEATS long runs driven by path_novel, tracking W[w_to_track]."""
        # squeeze=False keeps axs iterable even when N_REPEATS == 1
        _, axs = plt.subplots(
            N_REPEATS, 1, figsize=FIG_SIZE_0, sharex=True, tight_layout=True, squeeze=False)
        axs = axs[:, 0]
        axs_twin = [ax.twinx() for ax in axs]

        drives = np.zeros((RUN_LENGTH, ntwk_base.w.shape[0]), dtype=float)
        for t, node in enumerate(path_novel):
            drives[t, node] = STRONG_DRIVE_AMPLITUDE

        for ax, ax_twin in zip(axs, axs_twin):
            ntwk = deepcopy(ntwk_base)
            if lingering_input_value is not None:
                ntwk.lingering_input_value = lingering_input_value
            ntwk.store_voltages = True
            ws = []
            for drive in drives:
                ntwk.step(drive)
                # record the tracked weight after every step (the STDP rule may change it)
                ws.append(ntwk.w[w_to_track])
            spikes = np.array(ntwk.rs_history)
            fancy_raster.by_row_circles(ax, spikes, drives)
            ax_twin.plot(ws, color="b", lw=2, alpha=0.7)
            ax.set_ylabel("active \n ensemble")
            ax_twin.set_ylabel("W({}, {})".format(*w_to_track), color="b")

        axs[-1].set_xlabel("time step")
        for ax_twin in axs_twin:
            ax_twin.set_ylim(0, 2)
        axs[0].set_xlim(-5, RUN_LENGTH)
        axs[0].set_title(title)
        for ax in list(axs) + axs_twin:
            axis_tools.set_fontsize(ax, FONT_SIZE)

    # show long trials of novel sequence drive and subsequent spontaneous activity
    _plot_repeats("With nonassociative priming")
    _plot_repeats("Without nonassociative priming", lingering_input_value=0)
def basic_replay_ex(CONFIG):
    """
    Run a simulation demonstrating the basic capability of a network with
    nonassociative priming to demonstrate replay, both triggered and
    spontaneous.

    One figure with eight panels is drawn.

    Panels 1-4 use a triggered-replay drive of TRIAL_LENGTH_TRIGGERED_REPLAY
    steps, all at GAIN_HIGH:

    - step 0 gives a strong cue to the path's first node;
    - each remaining path node gets a weak drive, one per step;
    - the step right after the path strongly re-cues the first node.

    The four paths are node 0's first path, node 1's first path, and each of
    those with element 2 swapped in from the other.

    Panels 5-8 are RUN_LENGTH-step free runs:

    - panels 5-6 first strongly drive node 0's first path (GAIN_HIGH, then
      GAIN_LOW);
    - panels 7-8 have no drive at all (GAIN_HIGH, then GAIN_LOW).

    CONFIG: dict with keys 'SEED', 'LOAD_FILE_NAME', 'GAIN_HIGH', 'GAIN_LOW',
        'LINGERING_INPUT_VALUE', 'LINGERING_INPUT_TIMESCALE',
        'STRONG_DRIVE_AMPLITUDE', 'WEAK_DRIVE_AMPLITUDE',
        'TRIAL_LENGTH_TRIGGERED_REPLAY', 'RUN_LENGTH', 'FIG_SIZE',
        'FONT_SIZE'. A missing key raises KeyError.
    """
    SEED = CONFIG['SEED']
    LOAD_FILE_NAME = CONFIG['LOAD_FILE_NAME']
    GAIN_HIGH = CONFIG['GAIN_HIGH']
    GAIN_LOW = CONFIG['GAIN_LOW']
    LINGERING_INPUT_VALUE = CONFIG['LINGERING_INPUT_VALUE']
    LINGERING_INPUT_TIMESCALE = CONFIG['LINGERING_INPUT_TIMESCALE']
    STRONG_DRIVE_AMPLITUDE = CONFIG['STRONG_DRIVE_AMPLITUDE']
    WEAK_DRIVE_AMPLITUDE = CONFIG['WEAK_DRIVE_AMPLITUDE']
    TRIAL_LENGTH_TRIGGERED_REPLAY = CONFIG['TRIAL_LENGTH_TRIGGERED_REPLAY']
    RUN_LENGTH = CONFIG['RUN_LENGTH']
    FIG_SIZE = CONFIG['FIG_SIZE']
    FONT_SIZE = CONFIG['FONT_SIZE']

    np.random.seed(SEED)

    # the file holds a pickled object array, so only load trusted files
    ntwk_base = np.load(LOAD_FILE_NAME, allow_pickle=True)[0]
    ntwk_base.lingering_input_value = LINGERING_INPUT_VALUE
    ntwk_base.lingering_input_timescale = LINGERING_INPUT_TIMESCALE
    n_nodes = ntwk_base.w.shape[1]

    fig = plt.figure(figsize=FIG_SIZE, tight_layout=True)
    # four half-width panels (2x2 in the top third) followed by four full-width rows
    axs = [fig.add_subplot(6, 2, idx) for idx in range(1, 5)]
    axs += [fig.add_subplot(6, 1, idx) for idx in range(3, 7)]

    def _simulate(gain, drives=None, n_steps=None):
        """Run a fresh copy of ntwk_base at the given gain; step through drives, or n_steps undriven steps."""
        ntwk = deepcopy(ntwk_base)
        ntwk.gain = gain
        ntwk.store_voltages = True
        if drives is None:
            for _ in range(n_steps):
                ntwk.step()
        else:
            for drive in drives:
                ntwk.step(drive)
        return np.array(ntwk.rs_history)

    def _label(ax, x_max, y_max, title):
        """Apply the shared axis limits and labels to one panel."""
        ax.set_xlim(-1, x_max)
        ax.set_ylim(-1, y_max)
        ax.set_xlabel('time step')
        ax.set_ylabel('active ensemble')
        ax.set_title(title)

    def _triggered_drives(path):
        """Strong cue to path[0], weak drive along the rest of path, then a strong re-cue of path[0]."""
        drives = np.zeros((TRIAL_LENGTH_TRIGGERED_REPLAY, n_nodes), dtype=float)
        drives[0, path[0]] = STRONG_DRIVE_AMPLITUDE
        for t, node in enumerate(path[1:], start=1):
            drives[t, node] = WEAK_DRIVE_AMPLITUDE
        drives[len(path), path[0]] = STRONG_DRIVE_AMPLITUDE
        return drives

    # play sequences aligned to the network's intrinsic path structure
    path_00 = ntwk_base.node_0_path_tree[0]
    path_10 = ntwk_base.node_1_path_tree[0]

    # trials 3 and 4: each path with element 2 taken from the other path (a nonexisting path)
    path_00_mixed = list(path_00[:])
    path_00_mixed[2] = path_10[2]
    path_10_mixed = list(path_10[:])
    path_10_mixed[2] = path_00[2]

    triggered_trials = [
        (path_00, 'Aligning external drive with \n strongly connected paths'),
        (path_10, 'Aligning external drive with \n strongly connected paths'),
        (path_00_mixed, 'Aligning external drive with \n nonexisting path'),
        (path_10_mixed, 'Aligning external drive with \n nonexisting path'),
    ]
    for ax, (path, title) in zip(axs[:4], triggered_trials):
        drives = _triggered_drives(path)
        spikes = _simulate(GAIN_HIGH, drives)
        fancy_raster.by_row_circles(ax, spikes, drives)
        _label(ax, len(drives), 20, title)

    # play sequence and then let network run spontaneously for a while, at high then low gain
    free_after_drive = [
        (axs[4], GAIN_HIGH, 'Letting network run freely after driving strongly connected path (high gain)'),
        (axs[5], GAIN_LOW, 'Letting network run freely after driving strongly connected path (low gain)'),
    ]
    for ax, gain, title in free_after_drive:
        drives = np.zeros((RUN_LENGTH, n_nodes), dtype=float)
        for t, node in enumerate(path_00):
            drives[t, node] = STRONG_DRIVE_AMPLITUDE
        spikes = _simulate(gain, drives)
        fancy_raster.by_row_circles(ax, spikes, drives)
        _label(ax, len(drives), n_nodes, title)

    # let network run spontaneously for a while with no initial drive, at high then low gain
    free_no_drive = [
        (axs[6], GAIN_HIGH, 'Letting network run freely with no drive (high gain)'),
        (axs[7], GAIN_LOW, 'Letting network run freely with no drive (low gain)'),
    ]
    for ax, gain, title in free_no_drive:
        spikes = _simulate(gain, n_steps=RUN_LENGTH)
        fancy_raster.by_row_circles(ax, spikes, drives=None)
        # x-range is the run length itself, not a leftover drive array from an earlier panel
        _label(ax, RUN_LENGTH, n_nodes, title)

    for ax in axs:
        axis_tools.set_fontsize(ax, FONT_SIZE)