def save_movie(save_file, ani, frame_rate=15):
    """Write a matplotlib ArtistAnimation to disk.

    File names whose last three characters are 'gif' are written with the imagemagick
    writer; anything else is written with ffmpeg, and '.mp4' is appended when the name
    does not already end in 'mp4'. Nothing is written when :obj:`save_file` is
    :obj:`NoneType`.

    Parameters
    ----------
    save_file : :obj:`str` or :obj:`NoneType`
        full save file (path and filename)
    ani : :obj:`matplotlib.animation.ArtistAnimation` object
        animation to save
    frame_rate : :obj:`int`, optional
        frame rate of saved movie

    """
    if save_file is None:
        return

    make_dir_if_not_exists(save_file)

    if save_file[-3:] == 'gif':
        writer = 'imagemagick'
        extra_kwargs = {'fps': frame_rate}
    else:
        if save_file[-3:] != 'mp4':
            save_file += '.mp4'
        writer = FFMpegWriter(fps=frame_rate, bitrate=-1)
        extra_kwargs = {}

    print('saving video to %s...' % save_file, end='')
    ani.save(save_file, writer=writer, **extra_kwargs)
    print('done')
def plot_real_vs_sampled(
        latents, latents_samp, states, states_samp, save_file=None, xtick_locs=None,
        frame_rate=None, format='png'):
    """Plot real and sampled latents overlaying real and (potentially sampled) states.

    The top panel shows the inferred latents over the real states. The bottom panel shows
    the sampled latents over the sampled states when those are provided, and over the real
    states otherwise.

    Parameters
    ----------
    latents : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    latents_samp : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    states : :obj:`np.ndarray`
        shape (n_frames,)
    states_samp : :obj:`np.ndarray`
        shape (n_frames,) if :obj:`latents_samp` are not conditioned on :obj:`states`,
        otherwise shape (0,)
    save_file : :obj:`str`, optional
        full save file (path and filename); the extension is taken from :obj:`format`
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle

    """
    fig, axes = plt.subplots(2, 1, figsize=(10, 8))
    tick_kwargs = {'xtick_locs': xtick_locs, 'frame_rate': frame_rate}

    # top panel: inferred latents over the real states (x-axis hidden; shared with bottom)
    top_ax = plot_states_overlaid_with_latents(latents, states, ax=axes[0], **tick_kwargs)
    top_ax.set_xticks([])
    top_ax.set_xlabel('')
    top_ax.set_title('Inferred latents')

    # bottom panel: an empty `states_samp` means the samples were conditioned on the real
    # states, so those are used as the background instead
    use_real_states = len(states_samp) == 0
    bottom_states = states if use_real_states else states_samp
    bottom_title = 'Sampled latents' if use_real_states else 'Sampled states and latents'
    bottom_ax = plot_states_overlaid_with_latents(
        latents_samp, bottom_states, ax=axes[1], **tick_kwargs)
    bottom_ax.set_title(bottom_title)

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        plt.savefig(save_file + '.' + format, dpi=300, format=format)

    return fig
def plot_states_overlaid_with_latents(
        latents, states, save_file=None, ax=None, xtick_locs=None, frame_rate=None,
        format='png'):
    """Plot states for a single trial overlaid with latents.

    Each latent is drawn as a black trace, vertically offset from its neighbors. The
    discrete states are drawn as a colored background image spanning the full y-range.

    Parameters
    ----------
    latents : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    states : :obj:`np.ndarray`
        shape (n_frames,)
    save_file : :obj:`str`, optional
        full save file (path and filename); the extension is taken from :obj:`format`
    ax : :obj:`matplotlib.Axes` object or :obj:`NoneType`, optional
        axes to plot into; if :obj:`NoneType`, a new figure is created
    xtick_locs : :obj:`array-like`, optional
        tick locations in bin values for plot
    frame_rate : :obj:`float`, optional
        behavioral video framerate; to properly relabel xticks (only used together with
        :obj:`xtick_locs`)
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle if :obj:`ax=None`, otherwise updated axis

    """
    if ax is None:
        fig = plt.figure(figsize=(8, 4))
        ax = fig.gca()
    else:
        fig = None

    # vertical offset between consecutive latent traces
    # NOTE(review): based on the maximum value only; latents with large negative
    # excursions may overlap neighboring traces
    spc = 1.1 * abs(latents.max())
    n_latents = latents.shape[1]
    plotting_latents = latents + spc * np.arange(n_latents)
    ymin = min(-spc - 1, np.min(plotting_latents))
    ymax = max(spc * n_latents, np.max(plotting_latents))

    # states as a single-row image stretched over the whole plotting area
    ax.imshow(
        states[None, :], aspect='auto', extent=(0, len(latents), ymin, ymax),
        cmap='tab20b', alpha=1.0)
    ax.plot(plotting_latents, '-k', lw=3)
    ax.set_ylim([ymin, ymax])
    ax.set_yticks([])
    ax.set_xlabel('Time (bins)')
    if xtick_locs is not None:
        ax.set_xticks(xtick_locs)
        if frame_rate is not None:
            ax.set_xticklabels((np.asarray(xtick_locs) / frame_rate).astype('int'))
            ax.set_xlabel('Time (sec)')

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        # save the figure that owns `ax`; plt.savefig would save whichever figure is
        # current, which need not be this one when the caller supplied `ax`
        ax.figure.savefig(save_file + '.' + format, dpi=300, format=format)

    return ax if fig is None else fig
def make_real_vs_sampled_movies(
        ims_recon, ims_recon_samp, conditional, save_file=None, frame_rate=15):
    """Produce movie with (AE) reconstructed video and sampled video side by side.

    Parameters
    ----------
    ims_recon : :obj:`np.ndarray`
        shape (n_frames, y_pix, x_pix)
    ims_recon_samp : :obj:`np.ndarray`
        shape (n_frames, y_pix, x_pix)
    conditional : :obj:`bool`
        conditional vs unconditional samples; for creating reconstruction title
    save_file : :obj:`str`, optional
        full save file (path and filename); '.mp4' is appended unless the name already
        ends in 'mp4'
    frame_rate : :obj:`float`, optional
        frame rate of saved movie

    """
    n_plots = 2
    y_pix, x_pix = ims_recon[0].shape

    # scale the figure so its width is 10 while keeping the pixel aspect ratio
    fig_dim_div = x_pix * n_plots / 10
    fig, axes = plt.subplots(
        1, n_plots, figsize=(x_pix * n_plots / fig_dim_div, y_pix / fig_dim_div))
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    axes[0].set_title('Real Reconstructions\n', fontsize=16)
    axes[1].set_title(
        'Generative Reconstructions\n(Conditional)' if conditional
        else 'Generative Reconstructions\n(Unconditional)',
        fontsize=16)
    fig.tight_layout(pad=0)

    im_kwargs = {'cmap': 'gray', 'vmin': 0, 'vmax': 1, 'animated': True}
    # one list of artists per animation frame: [real reconstruction, sampled reconstruction]
    ims = [
        [axes[0].imshow(ims_recon[i], **im_kwargs),
         axes[1].imshow(ims_recon_samp[i], **im_kwargs)]
        for i in range(ims_recon.shape[0])]

    ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
    writer = FFMpegWriter(fps=frame_rate, bitrate=-1)

    if save_file is None:
        return
    make_dir_if_not_exists(save_file)
    if save_file[-3:] != 'mp4':
        save_file += '.mp4'
    print('saving video to %s...' % save_file, end='')
    ani.save(save_file, writer=writer)
    print('done')
def make_syllable_movies(
        ims_orig, state_list, trial_idxs, save_file=None, max_frames=400, frame_rate=10,
        n_buffer=5, n_pre_frames=3, n_rows=None, single_syllable=None):
    """Present video clips of each individual syllable in separate panels.

    Parameters
    ----------
    ims_orig : :obj:`list` of :obj:`np.ndarray` (or equivalent indexable)
        one entry per trial, each of shape (n_frames, n_channels, y_pix, x_pix); channels
        are stacked vertically in the movie
    state_list : :obj:`list`
        each entry (one per state) is an array whose rows describe the occurrences of that
        discrete state as :obj:`[chunk number, starting index, ending index]`
    trial_idxs : :obj:`array-like`
        maps the chunk number stored in :obj:`state_list` to an index into
        :obj:`ims_orig`
    save_file : :obj:`str`, optional
        full save file (path and filename); a trailing '.mp4' is optional, and
        '_syllable-XX' is appended when :obj:`single_syllable` is set
    max_frames : :obj:`int`, optional
        maximum number of frames to animate; a new syllable instance is only started
        while fewer frames than this have been produced
    frame_rate : :obj:`float`, optional
        frame rate of saved movie (at least 10 fps is used)
    n_buffer : :obj:`int`
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`
        number of rows in output movie; ignored when :obj:`single_syllable` is set
    single_syllable : :obj:`int` or :obj:`NoneType`
        choose only a single state for movie

    """
    K = len(state_list)

    # Initialize syllable movie frames
    plt.clf()
    if single_syllable is not None:
        K = 1
        fig_width = 5
        n_rows = 1
    else:
        fig_width = 10  # aiming for dim 1 being 10

    # get video dims; channels are stacked vertically in each panel
    bs, n_channels, y_dim, x_dim = ims_orig[0].shape
    movie_dim1 = n_channels * y_dim
    movie_dim2 = x_dim
    if n_rows is None:
        n_rows = int(np.floor(np.sqrt(K)))
    n_cols = int(np.ceil(K / n_rows))

    fig_dim_div = movie_dim2 * n_cols / fig_width
    fig_width = (movie_dim2 * n_cols) / fig_dim_div
    fig_height = (movie_dim1 * n_rows) / fig_dim_div
    fig, _ = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height))

    for i, ax in enumerate(fig.axes):
        ax.set_yticks([])
        ax.set_xticks([])
        if i >= K:
            ax.set_axis_off()
        elif single_syllable is not None:
            ax.set_title('Syllable %i' % single_syllable, fontsize=16)
        else:
            ax.set_title('Syllable %i' % i, fontsize=16)
    fig.tight_layout(pad=0, h_pad=1.005)

    imshow_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1}

    # ims[t] collects every artist (from all panels) shown in output frame t
    ims = []

    def add_artist(t, artist):
        # grow on demand: a syllable instance started just below max_frames can run past
        # it by its own length plus n_buffer frames
        while len(ims) <= t:
            ims.append([])
        ims[t].append(artist)

    # Loop through syllables
    for i_k, ax in enumerate(fig.axes):

        # skip if no syllable in this axis
        if i_k >= K:
            continue
        print('processing syllable %i/%i' % (i_k + 1, K))

        # resolve which syllable this panel shows *before* inspecting its occurrences
        if single_syllable is not None:
            i_k = single_syllable

        # skip if no syllables are longer than threshold
        if len(state_list[i_k]) == 0:
            continue

        i_chunk = 0
        i_frame = 0
        while i_frame < max_frames:

            if i_chunk >= len(state_list[i_k]):
                # show blank if out of syllable examples
                im = ax.imshow(np.zeros((movie_dim1, movie_dim2)), **imshow_kwargs)
                add_artist(i_frame, im)
                i_frame += 1
                continue

            # Get movie chunk for this syllable instance (plus preceding frames)
            chunk_idx = state_list[i_k][i_chunk, 0]
            which_trial = trial_idxs[chunk_idx]
            tr_beg = state_list[i_k][i_chunk, 1]
            tr_end = state_list[i_k][i_chunk, 2]
            batch = ims_orig[which_trial]
            movie_chunk = batch[max(tr_beg - n_pre_frames, 0):tr_end]
            # (T, C, y, x) -> (T, C * y, x)
            movie_chunk = np.concatenate(
                [movie_chunk[:, j] for j in range(movie_chunk.shape[1])], axis=1)

            # position of the syllable's first frame within movie_chunk
            syllable_start = n_pre_frames if tr_beg >= n_pre_frames else tr_beg

            # Loop over this chunk
            for i in range(movie_chunk.shape[0]):
                im = ax.imshow(movie_chunk[i], **imshow_kwargs)
                add_artist(i_frame, im)
                # Add red box on the first two frames of the syllable
                if syllable_start <= i < (syllable_start + 2):
                    rect = matplotlib.patches.Rectangle(
                        (5, 5), 10, 10, linewidth=1, edgecolor='r', facecolor='r')
                    im = ax.add_patch(rect)
                    add_artist(i_frame, im)
                i_frame += 1

            # Add buffer black frames
            for _ in range(n_buffer):
                im = ax.imshow(np.zeros((movie_dim1, movie_dim2)), **imshow_kwargs)
                add_artist(i_frame, im)
                i_frame += 1

            i_chunk += 1

    print('creating animation...', end='')
    ani = animation.ArtistAnimation(
        fig, [frame for frame in ims if frame], interval=20, blit=True, repeat=False)
    writer = FFMpegWriter(fps=max(frame_rate, 10), bitrate=-1)
    print('done')

    if save_file is not None:
        # put together file name; drop the whole '.mp4' suffix (dot included) so it is not
        # doubled when re-added below
        if save_file.endswith('.mp4'):
            save_file = save_file[:-4]
        if single_syllable is not None:
            save_file += '_syllable-%02i' % single_syllable
        save_file += '.mp4'
        make_dir_if_not_exists(save_file)
        print('saving video to %s...' % save_file, end='')
        ani.save(save_file, writer=writer)
        print('done')
def plot_neural_reconstruction_traces(
        traces_ae, traces_neural, save_file=None, xtick_locs=None, frame_rate=None,
        format='png', scale=0.5, max_traces=8, add_r2=True, add_legend=True,
        colored_predictions=True):
    """Plot ae latents and their neural reconstructions.

    NOTE(review): this module defines ``plot_neural_reconstruction_traces`` a second time
    further down; at import time that later definition replaces this one.

    Parameters
    ----------
    traces_ae : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    traces_neural : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    save_file : :obj:`str`, optional
        full save file (path and filename); the extension is taken from :obj:`format`
    xtick_locs : :obj:`array-like`, optional
        tick locations in units of bins
    frame_rate : :obj:`float`, optional
        frame rate of behavorial video; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'
    scale : :obj:`float`, optional
        scale magnitude of traces; traces are divided by (overall std / scale)
    max_traces : :obj:`int`, optional
        maximum number of traces to plot, for easier visualization
    add_r2 : :obj:`bool`, optional
        print R2 value (variance weighted, computed on all latents) on plot
    add_legend : :obj:`bool`, optional
        print legend on plot
    colored_predictions : :obj:`bool`, optional
        color predictions using default seaborn colormap; else predictions are black

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle

    """
    import seaborn as sns

    sns.set_style('white')
    sns.set_context('poster')

    # standardize both sets of traces with the ae statistics so they share an axis
    means = np.nanmean(traces_ae, axis=0)
    std = np.nanstd(traces_ae) / scale  # scale for better visualization
    traces_ae_sc = ((traces_ae - means) / std)[:, :max_traces]
    traces_neural_sc = ((traces_neural - means) / std)[:, :max_traces]

    fig = plt.figure(figsize=(12, 8))
    pred_kwargs = {'linewidth': 3}
    if not colored_predictions:
        pred_kwargs['color'] = 'k'
    # each trace is offset by its column index
    plt.plot(traces_neural_sc + np.arange(traces_neural_sc.shape[1]), **pred_kwargs)
    plt.plot(
        traces_ae_sc + np.arange(traces_ae_sc.shape[1]), color=[0.2, 0.2, 0.2],
        linewidth=3, alpha=0.7)

    # add legend if desired
    if add_legend:
        # original latents - gray
        orig_line = mlines.Line2D([], [], color=[0.2, 0.2, 0.2], linewidth=3, alpha=0.7)
        # predicted latents - cycle through some colors (black if predictions are black);
        # wrap the index so a property cycle with fewer than 5 colors does not fail
        if colored_predictions:
            colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        else:
            colors = ['k']
        dls = [
            mlines.Line2D(
                [], [], linewidth=3, linestyle='--', dashes=(0, 3 * c, 20, 1),
                color=colors[c % len(colors)])
            for c in range(5)]
        plt.legend(
            [orig_line, tuple(dls)], ['Original latents', 'Predicted latents'],
            loc='lower right', frameon=True, framealpha=0.7, edgecolor=[1, 1, 1])

    # add r2 info if desired
    if add_r2:
        from sklearn.metrics import r2_score
        r2 = r2_score(traces_ae, traces_neural, multioutput='variance_weighted')
        plt.text(
            0.05, 0.06, '$R^2$=%1.3f' % r2, horizontalalignment='left',
            verticalalignment='bottom', transform=plt.gca().transAxes,
            bbox=dict(facecolor='white', alpha=0.7, edgecolor=[1, 1, 1]))

    if xtick_locs is not None and frame_rate is not None:
        plt.xticks(xtick_locs, (np.asarray(xtick_locs) / frame_rate).astype('int'))
        plt.xlabel('Time (s)')
    else:
        plt.xlabel('Time (bins)')
    plt.ylabel('Latent state')
    plt.yticks([])

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        plt.savefig(save_file + '.' + format, dpi=300, format=format)

    plt.show()
    return fig
def make_ae_reconstruction_movie(
        ims_orig, ims_recon_ae, ims_recon_lin=None, save_file=None, frame_rate=15):
    """Produce movie with original video, reconstructed video, and residual.

    The first column shows the original frames, the conv AE reconstructions and their
    residuals. When :obj:`ims_recon_lin` is given, a second column shows the linear AE
    reconstructions and residuals.

    Parameters
    ----------
    ims_orig : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    ims_recon_ae : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    ims_recon_lin : :obj:`np.ndarray`, optional
        shape (n_frames, n_channels, y_pix, x_pix)
    save_file : :obj:`str`, optional
        full save file (path and filename); '.mp4' is appended unless the name already
        ends in 'mp4'
    frame_rate : :obj:`float`, optional
        frame rate of saved movie

    """
    n_frames, n_channels, y_pix, x_pix = ims_orig.shape
    show_lin = ims_recon_lin is not None

    n_cols = 2 if show_lin else 1
    n_rows = 3
    offset = 1  # extra figure height, added to the computed height
    scale_ = 5
    fig_width = scale_ * n_cols * n_channels / 2
    fig_height = y_pix / x_pix * scale_ * n_rows / 2
    fig = plt.figure(figsize=(fig_width, fig_height + offset), dpi=100)
    gs = GridSpec(n_rows, n_cols, figure=fig)

    ims_res_ae = ims_orig - ims_recon_ae
    # one entry per panel: (grid slot, title, movie, is_residual)
    panels = [
        ((0, 0), 'Original', ims_orig, False),
        ((1, 0), 'Conv AE reconstructed', ims_recon_ae, False),
        ((2, 0) if not show_lin else (1, 1), 'Conv AE residual', ims_res_ae, True),
    ]
    if show_lin:
        ims_res_lin = ims_orig - ims_recon_lin
        panels += [
            ((2, 0), 'Linear AE reconstructed', ims_recon_lin, False),
            ((2, 1), 'Linear AE residual', ims_res_lin, True),
        ]

    fontsize = 12
    axs = []
    for (row, col), title, _, _ in panels:
        ax = fig.add_subplot(gs[row, col])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontsize=fontsize)
        for spine in ax.spines.values():
            spine.set_visible(False)
        axs.append(ax)

    default_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1}

    # ims is a list of lists; each inner list holds the artists drawn in one frame
    ims = []
    for i in range(n_frames):
        ims_curr = []
        for ax, (_, _, movie, is_residual) in zip(axs, panels):
            # single-channel frames are shown directly; multi-channel ones via concat
            frame = movie[i, 0] if n_channels == 1 else concat(movie[i])
            if is_residual:
                # shift so that zero residual maps to mid-gray in the [0, 1] color range
                frame = 0.5 + frame
            ims_curr.append(ax.imshow(frame, **default_kwargs))
        ims.append(ims_curr)

    plt.tight_layout(pad=0)

    ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
    writer = FFMpegWriter(fps=frame_rate, bitrate=-1)

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        if save_file[-3:] != 'mp4':
            save_file += '.mp4'
        print('saving video to %s...' % save_file, end='')
        ani.save(save_file, writer=writer)
        print('done')
def plot_neural_reconstruction_traces(
        traces_ae, traces_neural, save_file=None, xtick_locs=None, frame_rate=None,
        format='png', scale=0.5, max_traces=8, add_r2=False, add_legend=True,
        colored_predictions=True):
    """Plot ae latents and their neural reconstructions.

    NOTE: this is the second definition of ``plot_neural_reconstruction_traces`` in this
    module, so it is the one bound to the name at import time. It accepts the same keyword
    arguments as the earlier definition (``scale``, ``max_traces``, ``add_r2``,
    ``add_legend``, ``colored_predictions``). The defaults reproduce this function's
    historical output: gray/colored traces, a legend, no R2 annotation.

    Parameters
    ----------
    traces_ae : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    traces_neural : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    save_file : :obj:`str`, optional
        full save file (path and filename); the extension is taken from :obj:`format`
    xtick_locs : :obj:`array-like`, optional
        tick locations in units of bins
    frame_rate : :obj:`float`, optional
        frame rate of behavorial video; to properly relabel xticks
    format : :obj:`str`, optional
        any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg'
    scale : :obj:`float`, optional
        scale magnitude of traces; traces are divided by (overall std / scale)
    max_traces : :obj:`int`, optional
        maximum number of traces to plot, for easier visualization
    add_r2 : :obj:`bool`, optional
        print R2 value (variance weighted, computed on all latents) on plot
    add_legend : :obj:`bool`, optional
        print legend on plot
    colored_predictions : :obj:`bool`, optional
        color predictions using default seaborn colormap; else predictions are black

    Returns
    -------
    :obj:`matplotlib.figure.Figure`
        matplotlib figure handle

    """
    import matplotlib.pyplot as plt
    import matplotlib.lines as mlines
    import seaborn as sns

    sns.set_style('white')
    sns.set_context('poster')

    # standardize both sets of traces with the ae statistics so they share an axis
    means = np.nanmean(traces_ae, axis=0)
    std = np.nanstd(traces_ae) / scale  # scale for better visualization
    traces_ae_sc = ((traces_ae - means) / std)[:, :max_traces]
    traces_neural_sc = ((traces_neural - means) / std)[:, :max_traces]

    fig = plt.figure(figsize=(12, 8))
    pred_kwargs = {'linewidth': 3}
    if not colored_predictions:
        pred_kwargs['color'] = 'k'
    # each trace is offset by its column index
    plt.plot(traces_neural_sc + np.arange(traces_neural_sc.shape[1]), **pred_kwargs)
    plt.plot(
        traces_ae_sc + np.arange(traces_ae_sc.shape[1]), color=[0.2, 0.2, 0.2],
        linewidth=3, alpha=0.7)

    if add_legend:
        # original latents - gray
        orig_line = mlines.Line2D([], [], color=[0.2, 0.2, 0.2], linewidth=3, alpha=0.7)
        # predicted latents - cycle through some colors (black if predictions are black);
        # wrap the index so a property cycle with fewer than 5 colors does not fail
        if colored_predictions:
            colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        else:
            colors = ['k']
        dls = [
            mlines.Line2D(
                [], [], linewidth=3, linestyle='--', dashes=(0, 3 * c, 20, 1),
                color=colors[c % len(colors)])
            for c in range(5)]
        plt.legend(
            [orig_line, tuple(dls)], ['Original latents', 'Predicted latents'],
            loc='lower right', frameon=True, framealpha=0.7, edgecolor=[1, 1, 1])

    if add_r2:
        from sklearn.metrics import r2_score
        r2 = r2_score(traces_ae, traces_neural, multioutput='variance_weighted')
        plt.text(
            0.05, 0.06, '$R^2$=%1.3f' % r2, horizontalalignment='left',
            verticalalignment='bottom', transform=plt.gca().transAxes,
            bbox=dict(facecolor='white', alpha=0.7, edgecolor=[1, 1, 1]))

    if xtick_locs is not None and frame_rate is not None:
        plt.xticks(xtick_locs, (np.asarray(xtick_locs) / frame_rate).astype('int'))
        plt.xlabel('Time (s)')
    else:
        plt.xlabel('Time (bins)')
    plt.ylabel('Latent state')
    plt.yticks([])

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        plt.savefig(save_file + '.' + format, dpi=300, format=format)

    plt.show()
    return fig
def make_neural_reconstruction_movie(
        ims_orig, ims_recon_ae, ims_recon_neural, latents_ae, latents_neural,
        save_file=None, frame_rate=15):
    """Produce movie with original video, ae reconstructed video, and neural reconstructed
    video.

    Latent traces are additionally plotted, as well as the residual between the ae
    reconstruction and the neural reconstruction.

    Parameters
    ----------
    ims_orig : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    ims_recon_ae : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    ims_recon_neural : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    latents_ae : :obj:`np.ndarray`
        shape (n_frames, n_latents)
    latents_neural : :obj:`np.ndarray`
        shape (n_frames, n_latents); predicted ae latents
    save_file : :obj:`str`, optional
        full save file (path and filename); '.mp4' is appended unless the name already
        ends in 'mp4'
    frame_rate : :obj:`float`, optional
        frame rate of saved movie

    """
    # standardize both sets of latents with the ae statistics so they share an axis; the
    # extra factor of 2 shrinks the plotted traces
    means = np.mean(latents_ae, axis=0)
    std = np.std(latents_ae) * 2
    latents_ae_sc = (latents_ae - means) / std
    latents_dec_sc = (latents_neural - means) / std

    n_channels, y_pix, x_pix = ims_orig.shape[1:]
    n_time, n_ae_latents = latents_ae.shape

    n_cols = 3
    n_rows = 2
    offset = 2  # extra figure height, added to the computed height
    scale_ = 5
    fig_width = scale_ * n_cols * n_channels / 2
    fig_height = y_pix / x_pix * scale_ * n_rows / 2
    fig = plt.figure(figsize=(fig_width, fig_height + offset))

    gs = GridSpec(n_rows, n_cols, figure=fig)
    axs = [
        fig.add_subplot(gs[0, 0]),    # 0: original frames
        fig.add_subplot(gs[0, 1]),    # 1: ae reconstructed frames
        fig.add_subplot(gs[0, 2]),    # 2: neural reconstructed frames
        fig.add_subplot(gs[1, 0]),    # 3: residual
        fig.add_subplot(gs[1, 1:3]),  # 4: ae and predicted ae latents
    ]
    trace_ax = axs[4]

    for i, ax in enumerate(fig.axes):
        ax.set_yticks([])
        if i > 2:
            ax.get_xaxis().set_tick_params(labelsize=12, direction='in')
    for ax in axs[:4]:
        ax.set_xticks([])

    fontsize = 12
    titles = [
        'Original', 'AE reconstructed', 'Neural reconstructed', 'Reconstructions residual',
        'AE latent predictions']
    for ax, title in zip(axs, titles):
        ax.set_title(title, fontsize=fontsize)
    trace_ax.set_xlabel('Time (bins)', fontsize=fontsize)
    # only the bottom spine of the trace panel stays visible
    for side in ('top', 'right', 'left'):
        trace_ax.spines[side].set_visible(False)

    time = np.arange(n_time)
    ims_res = ims_recon_ae - ims_recon_neural

    im_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1}
    tr_kwargs = {'animated': True, 'linewidth': 2}
    latents_ae_color = [0.2, 0.2, 0.2]
    latents_dec_color = [0, 0, 0]

    def get_frame(movie, i):
        # single-channel frames are shown directly; multi-channel ones via concat
        return movie[i, 0] if n_channels == 1 else concat(movie[i])

    # ims is a list of lists; each inner list holds the artists drawn in one frame
    ims = []
    for i in range(n_time):

        if i % 100 == 0:
            print('processing frame %03i/%03i' % (i, n_time))

        # behavioral videos; the residual is shifted so zero maps to mid-gray
        ims_curr = [
            axs[0].imshow(get_frame(ims_orig, i), **im_kwargs),
            axs[1].imshow(get_frame(ims_recon_ae, i), **im_kwargs),
            axs[2].imshow(get_frame(ims_recon_neural, i), **im_kwargs),
            axs[3].imshow(0.5 + get_frame(ims_res, i), **im_kwargs),
        ]

        # latents over time, each offset by its index
        for latent in range(n_ae_latents):
            # just put labels on last lvs of the first frame (one legend entry each)
            if latent == n_ae_latents - 1 and i == 0:
                label_ae = 'AE latents'
                label_dec = 'Predicted AE latents'
            else:
                label_ae = None
                label_dec = None
            ims_curr.append(trace_ax.plot(
                time[0:i + 1], latent + latents_ae_sc[0:i + 1, latent],
                color=latents_ae_color, alpha=0.7, label=label_ae, **tr_kwargs)[0])
            ims_curr.append(trace_ax.plot(
                time[0:i + 1], latent + latents_dec_sc[0:i + 1, latent],
                color=latents_dec_color, label=label_dec, **tr_kwargs)[0])

        ims.append(ims_curr)

    # build the legend once, from the labeled traces drawn for the first frame
    if n_time > 0 and n_ae_latents > 0:
        trace_ax.legend(
            loc='lower right', fontsize=fontsize, frameon=True, framealpha=0.7,
            edgecolor=[1, 1, 1])

    plt.tight_layout(pad=0)

    ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
    writer = FFMpegWriter(fps=frame_rate, bitrate=-1)

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        if save_file[-3:] != 'mp4':
            save_file += '.mp4'
        print('saving video to %s...' % save_file, end='')
        ani.save(save_file, writer=writer)
        print('done')