def __init__(self, model, contrast, corrected=True, stat_thresh=2.3,
             stat_range=(2, 5), stat_cmap="OrRd_r", stat_alpha=.85,
             sharp=False, label_slices=True):
    """Load the images for one contrast and remember the display options.

    No drawing happens here. The anatomical image is read from
    ``data_dir``; the statistic and mask images are read from the
    ``group/mni/<contrast>`` directory of ``<exp>-<model>`` beneath
    ``analysis_dir`` (``data_dir``, ``analysis_dir`` and ``exp`` are
    resolved from the enclosing scope, not passed in).

    Parameters
    ----------
    model, contrast : str
        Used to build the path of the group-level contrast directory.
    corrected : bool
        Selects ``zstat1_threshold.nii.gz`` when true, ``zstat1.nii.gz``
        otherwise.
    stat_thresh : float
        Stored as ``self.stat_thresh``.
    stat_range : pair
        Unpacked into ``self.stat_vmin`` and ``self.stat_vmax``.
    stat_cmap, stat_alpha, sharp, label_slices
        Stored unchanged on the instance under the same names.

    """
    def as_volume(img, **kws):
        # Every image is wrapped the same way, tagged with the "mni" space.
        return VolumeImg(img.get_data(), img.get_affine(), "mni", **kws)

    anat = nib.load(op.join(data_dir, "average_anat.nii.gz"))

    group_dir = op.join(analysis_dir, exp + "-" + model,
                        "group", "mni", contrast)
    stat_name = "zstat1_threshold.nii.gz" if corrected else "zstat1.nii.gz"
    stat = nib.load(op.join(group_dir, stat_name))
    mask = nib.load(op.join(group_dir, "mask.nii.gz"))

    self.anat_img = as_volume(anat)
    self.stat_img = as_volume(stat)
    # The mask is the only image resampled with nearest-neighbour lookup.
    self.mask_img = as_volume(mask, interpolation="nearest")

    self.stat_thresh = stat_thresh
    self.stat_cmap = stat_cmap
    self.stat_alpha = stat_alpha
    vmin, vmax = stat_range
    self.stat_vmin = vmin
    self.stat_vmax = vmax
    self.label_slices = label_slices
    self.sharp = sharp
class SlicePlotter(object):
    """Object to generate single slice images or mosaics of volume results."""

    def __init__(self, model, contrast, corrected=True, stat_thresh=2.3,
                 stat_range=(2, 5), stat_cmap="OrRd_r", stat_alpha=.85,
                 sharp=False, label_slices=True):
        """Initialize the object but do not plot anything yet.

        Parameters
        ----------
        model, contrast : str
            Used to locate the contrast directory
            ``<analysis_dir>/<exp>-<model>/group/mni/<contrast>``
            (``analysis_dir``, ``exp`` and ``data_dir`` come from the
            enclosing scope).
        corrected : bool
            Load ``zstat1_threshold.nii.gz`` when true, otherwise
            ``zstat1.nii.gz``.
        stat_thresh : float
            Statistic values below this are hidden in ``plot_slice``; it is
            also the lower end of the bar drawn by ``plot_cmap``.
        stat_range : pair
            ``(vmin, vmax)`` color limits for the statistic overlay.
        stat_cmap : str
            Colormap name for the statistic overlay.
        stat_alpha : float
            Alpha of the statistic overlay.
        sharp : bool
            Selects "spline16" (true) or "bilinear" (false) image
            interpolation in ``plot_slice``.
        label_slices : bool
            Whether ``plot_slice`` writes the slice coordinate under the axes.

        """
        anat_img = nib.load(op.join(data_dir, "average_anat.nii.gz"))

        contrast_dir = op.join(analysis_dir, exp + "-" + model,
                               "group", "mni", contrast)
        if corrected:
            stat_file = "zstat1_threshold.nii.gz"
        else:
            stat_file = "zstat1.nii.gz"
        stat_img = nib.load(op.join(contrast_dir, stat_file))
        mask_img = nib.load(op.join(contrast_dir, "mask.nii.gz"))

        self.anat_img = VolumeImg(anat_img.get_data(),
                                  anat_img.get_affine(),
                                  "mni")
        self.stat_img = VolumeImg(stat_img.get_data(),
                                  stat_img.get_affine(),
                                  "mni")
        # Nearest-neighbour lookup is requested only for the mask image.
        self.mask_img = VolumeImg(mask_img.get_data(),
                                  mask_img.get_affine(),
                                  "mni",
                                  interpolation="nearest")

        self.stat_thresh = stat_thresh
        self.stat_cmap = stat_cmap
        self.stat_alpha = stat_alpha
        self.stat_vmin, self.stat_vmax = stat_range
        self.label_slices = label_slices
        self.sharp = sharp

    def plot_slice(self, ax, y=None, z=None, stat_only=False, contour=None):
        """Draw a single slice image onto a matplotlib Axes object.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on.
        y, z : number, optional
            Coordinate of the slice. Exactly one must be given: ``z`` draws
            an x-by-y slice, ``y`` draws an x-by-z slice.
        stat_only : bool
            When true, skip the anatomical and mask layers.
        contour : number, optional
            When given, draw an outline of the region where the thresholded
            statistic exceeds this value instead of the colored overlay.

        Raises
        ------
        ValueError
            If neither or both of ``y`` and ``z`` are provided.

        """
        # Validate before doing any work so that a missing coordinate fails
        # with a clear message instead of a "%d" formatting TypeError.
        if (y is None) == (z is None):
            raise ValueError("Specify exactly one of `y` or `z`.")

        # Fixed sampling grid for each axis (step of 2; presumably mm in the
        # "mni" space the images are tagged with -- confirm if it matters).
        x_vals = np.arange(-70, 74, 2)
        y_vals = np.arange(-108, 76, 2)
        z_vals = np.arange(-50, 80, 2)

        if y is None:
            label = "z = %d" % z
            x, y = np.meshgrid(x_vals, y_vals)
            z = np.ones_like(x) * z
        else:
            label = "y = %d" % y
            x, z = np.meshgrid(x_vals, z_vals)
            y = np.ones_like(x) * y

        # Hide anatomical samples with intensity under 30.
        anat_slice = self.anat_img.values_in_world(x, y, z)
        anat_slice = np.ma.masked_array(anat_slice, anat_slice < 30)

        mask_slice = self.mask_img.values_in_world(x, y, z)

        # Hide statistic samples under the display threshold.
        stat_slice = self.stat_img.values_in_world(x, y, z)
        stat_mask = stat_slice < self.stat_thresh
        stat_slice = np.ma.masked_array(stat_slice, stat_mask)

        # The mask layer is hidden where the mask equals 1 or the anatomy is
        # under 30, so it only shades the remaining samples.
        mask_mask = (mask_slice == 1) | (anat_slice < 30)
        mask_slice = np.ma.masked_array(mask_slice, mask_mask)

        interp = "spline16" if self.sharp else "bilinear"
        im_kws = dict(origin="lower", interpolation=interp, rasterized=True)

        if not stat_only:
            ax.imshow(anat_slice, cmap="Greys_r", vmin=20, vmax=120, **im_kws)
            ax.imshow(mask_slice, cmap="bone", vmin=-.25, vmax=1,
                      alpha=.5, **im_kws)

        if contour is None:
            ax.imshow(stat_slice, cmap=self.stat_cmap, alpha=self.stat_alpha,
                      vmin=self.stat_vmin, vmax=self.stat_vmax, **im_kws)
        else:
            outline = stat_slice > contour
            if outline.any():
                # `linewidths` is the contour keyword for line width; the
                # `lw` alias is not recognised by contour and was ignored.
                ax.contour(outline, 1, cmap="Greys_r", vmin=0, vmax=5,
                           linewidths=.3)

        if self.label_slices:
            ax.set_xlabel(label, size=7, labelpad=1.2)
        sns.despine(ax=ax, left=True, bottom=True)
        ax.set(xticks=[], yticks=[])

    def plot_cmap(self, ax, vert=True):
        """Draw a colorbar to show the extent of the statistical colormap.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw the bar on; its ticks are removed.
        vert : bool
            Draw the bar as a column when true, as a row otherwise.

        """
        # 100 steps from the display threshold up to the top of the color
        # range, colored with the same cmap and limits as the stat overlay.
        bar = np.linspace(self.stat_thresh, self.stat_vmax, 100)
        bar = np.atleast_2d(bar)
        if vert:
            bar = bar.T
        ax.pcolormesh(bar, cmap=self.stat_cmap,
                      vmin=self.stat_vmin, vmax=self.stat_vmax,
                      rasterized=True)
        ax.set(xticks=[], yticks=[])