def save_patches(self, filename_stem):
    """Render one batch of videos and their extracted patches to PNG files.

    Draws the first batch from ``self.data_stream``, runs ``self.extractor``
    on its ``'features'`` and writes one figure per example to
    ``"<filename_stem>_example_<i>.png"``. In each figure, the top panel
    holds the video tiled along its axis 0 from left to right, and the
    bottom panel holds the patches tiled as a grid.

    Args:
        filename_stem: Prefix for the output file names.

    NOTE(review): the axis handling below treats axis 1 of ``videos`` and
    axis 2 of the extracted patches as the channel axis. The tiling implies
    ``videos`` is 5-D, presumably (batch, channel, time, height, width), and
    the patches are 6-D, presumably (batch, patch, channel, time, height,
    width). Confirm against ``self.extractor``.
    """
    # The ``next()`` builtin works on Python 2 and 3. The ``.next()``
    # method does not exist on Python 3 iterators.
    batch = next(self.data_stream.get_epoch_iterator(as_dict=True))
    videos = batch['features']
    locationss, scaless, patchess = self.extractor(videos)

    if videos.shape[1] == 1:
        # remove degenerate channel axis because pyplot rejects it
        videos = np.squeeze(videos, axis=1)
        patchess = np.squeeze(patchess, axis=2)
    else:
        # move channel axis to the end because pyplot wants this
        videos = np.rollaxis(videos, 1, videos.ndim)
        patchess = np.rollaxis(patchess, 2, patchess.ndim)

    outer_grid = gridspec.GridSpec(2, 1)
    examples = zip(videos, patchess, locationss, scaless)
    for i, (video, patches, locations, scales) in enumerate(examples):
        # Use a fresh figure per example so we never draw onto, or close,
        # a figure that was already open when this method was called.
        fig = plt.figure()

        # Tile axis 0 of the video side by side:
        # (A, B, C[, channels]) -> (B, A * C[, channels]).
        # swapaxes leaves any trailing channel axis in place, so the same
        # code handles both the squeezed grayscale and the
        # channels-last case.
        video_ax = fig.add_subplot(outer_grid[0, 0])
        video_image = video.swapaxes(0, 1).reshape(
            (video.shape[1], video.shape[0] * video.shape[2])
            + video.shape[3:])
        self.imshow(video_image, axes=video_ax)
        video_ax.axis("off")

        # TODO: maybe rectangles in video_ax (locations/scales are
        # unpacked for this but not yet used)

        # Tile the patches as a grid, one row band per patch (axis 0):
        # (P, A, B, C[, channels]) -> (P * B, A * C[, channels]).
        patch_ax = fig.add_subplot(outer_grid[1, 0])
        patch_image = patches.swapaxes(1, 2).reshape(
            (patches.shape[0] * patches.shape[2],
             patches.shape[1] * patches.shape[3])
            + patches.shape[4:])
        self.imshow(patch_image, axes=patch_ax)
        patch_ax.axis("off")

        fig.set_size_inches((20, 20))
        fig.tight_layout()
        filename = "%s_example_%i.png" % (filename_stem, i)
        fig.savefig(filename, bbox_inches="tight", facecolor="gray")
        # Close exactly the figure created above to release its memory.
        plt.close(fig)