def calculate_redox(self):
    """Attach derived redox wavelengths to the image and raw-profile data.

    Every dataset listed below is passed through ``utils.add_derived_wavelengths``
    with the ``"redox"`` section of ``self.config`` as keyword arguments, and the
    result replaces the attribute in place:

    * ``self.images`` and ``self.rot_fl`` (images)
    * ``self.trimmed_raw_profiles`` and ``self.untrimmed_raw_profiles`` (profiles)
    """
    logging.info("Calculating redox measurements")
    params = self.config["redox"]

    # images first, then the raw profiles (trimmed before untrimmed)
    targets = (
        "images",
        "rot_fl",
        "trimmed_raw_profiles",
        "untrimmed_raw_profiles",
    )
    for attr_name in targets:
        augmented = utils.add_derived_wavelengths(getattr(self, attr_name), **params)
        setattr(self, attr_name, augmented)
def test_add_derived_wavelengths(self, paired_imgs):
    """``add_derived_wavelengths`` must reproduce hand-computed r, OxD and E.

    The expected values are built from the 410/470 intensity ratio using the
    ``pp`` conversion functions with their default parameters, and compared
    NaN-tolerantly against the wavelengths the utility adds.
    """
    numerator = paired_imgs.sel(wavelength="410")
    denominator = paired_imgs.sel(wavelength="470")
    expected_r = numerator / denominator
    expected_oxd = pp.r_to_oxd(expected_r)
    expected_e = pp.oxd_to_redox_potential(expected_oxd)

    derived = utils.add_derived_wavelengths(paired_imgs)

    expectations = (
        ("r", expected_r),
        ("oxd", expected_oxd),
        ("e", expected_e),
    )
    for wavelength, expected in expectations:
        actual = derived.sel(wavelength=wavelength).values
        assert np.allclose(actual, expected.values, equal_nan=True)
def save_plots(self):
    """Write this experiment's diagnostic figure PDFs to ``self.fig_dir``.

    Produces:

    * individual and averaged profile plots (``save_individual_profiles`` /
      ``save_avg_profiles``) for raw, standardized and channel-registered
      profiles, each in untrimmed and trimmed form;
    * ``{experiment_id}-movement_annotation_imgs.pdf``: one page per animal
      drawn by ``plots.plot_pharynx_R_imgs`` with the segmentation mask;
    * ``{experiment_id}-ratio_images-pair={pair};timepoint={tp}.pdf`` for
      every pair/timepoint of ``self.rot_fl``: one page per animal showing the
      numerator/denominator ratio image. The colour range is the mean +/- 1.96
      std of the trimmed raw ``"r"`` profiles. The midline is overlaid, and
      each animal's mean ``"r"`` is drawn as a faint line on the colorbar, with
      the current animal's mean drawn solid.

    All warnings raised while plotting are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for data, treatment, trimmed in [
            (self.untrimmed_raw_profiles, "raw", False),
            (self.untrimmed_std_profiles, "standardized", False),
            (self.untrimmed_reg_profiles, "channel-registered", False),
            (self.trimmed_raw_profiles, "raw", True),
            (self.trimmed_std_profiles, "standardized", True),
            (self.trimmed_reg_profiles, "channel-registered", True),
        ]:
            self.save_individual_profiles(data, treatment, trimmed)
            self.save_avg_profiles(data, treatment, trimmed)

        redox_params = self.config["redox"]

        # frame-normed Ratio Images
        mvmt_annotation_img_path = self.fig_dir.joinpath(
            f"{self.experiment_id}-movement_annotation_imgs.pdf"
        )
        imgs = utils.add_derived_wavelengths(self.images, **redox_params)
        with PdfPages(mvmt_annotation_img_path) as pdf:
            for i in tqdm(range(self.images.animal.size)):
                fig = plots.plot_pharynx_R_imgs(imgs[i], mask=self.seg_images[i])
                fig.suptitle(f"animal = {i}")
                pdf.savefig(fig)
                # Close every figure once saved. Closing only every 20th
                # iteration left up to 19 figures alive after the loop.
                plt.close(fig)

        # Pop-normed ratio images: the colour range is shared by all animals
        pop_r = self.trimmed_raw_profiles.sel(wavelength="r")
        u = pop_r.mean()
        std = pop_r.std()
        r_min = u - (std * 1.96)
        r_max = u + (std * 1.96)
        strains = self.rot_fl.strain.values
        n_profiles = len(self.trimmed_raw_profiles)

        for pair in self.rot_fl.pair.values:
            for tp in self.rot_fl.timepoint.values:
                ratio_img_path = self.fig_dir.joinpath(
                    f"{self.experiment_id}-ratio_images-pair={pair};timepoint={tp}.pdf"
                )

                # Everything below depends only on (pair, tp). It is computed
                # once here, instead of once per animal (the old loop re-divided
                # the whole image stack and re-selected the profiles for every
                # animal and every colorbar line).
                num_imgs = self.rot_fl.sel(
                    wavelength=redox_params["ratio_numerator"],
                    pair=pair,
                    timepoint=tp,
                )
                denom_imgs = self.rot_fl.sel(
                    wavelength=redox_params["ratio_denominator"],
                    pair=pair,
                    timepoint=tp,
                )
                ratio_imgs = num_imgs / denom_imgs
                midlines = self.midlines.sel(pair=pair, timepoint=tp)
                r_profiles = self.trimmed_raw_profiles.sel(
                    wavelength="r", pair=pair, timepoint=tp
                )
                # mean "r" of every profile, drawn as faint lines on each colorbar
                profile_mean_rs = [r_profiles[j].mean() for j in range(n_profiles)]

                with PdfPages(ratio_img_path) as pdf:
                    logging.info(f"Saving ratio images to {ratio_img_path}")
                    for i in tqdm(range(self.rot_fl.animal.size)):
                        fig, ax = plt.subplots(dpi=300)
                        _, cbar = plots.imshow_ratio_normed(
                            ratio_imgs[i],
                            num_imgs[i],
                            r_min=r_min,
                            r_max=r_max,
                            colorbar=True,
                            i_max=5000,
                            i_min=1000,
                            ax=ax,
                        )
                        ax.plot(
                            *midlines[i].values[()].linspace(),
                            color="green",
                            alpha=0.3,
                        )
                        strain = strains[i]
                        ax.set_title(f"Animal={i} ; Pair={pair} ; Strain={strain}")

                        cax = cbar.ax
                        for mean_r in profile_mean_rs:
                            cax.axhline(mean_r, color="k", alpha=0.1)
                        # the current animal's mean ratio, drawn solid
                        cax.axhline(r_profiles[i].mean(), color="k")

                        pdf.savefig(fig)
                        plt.close(fig)
def channel_register(
    profile_data: xr.DataArray,
    redox_params: dict,
    reg_params: dict,
    eng: matlab.engine.MatlabEngine = None,
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Perform channel-registration on the given profile data

    For every (pair, timepoint), the ratio-numerator and ratio-denominator
    profiles are sent to the MATLAB ``channel_register`` function. The
    registered profiles it returns replace those two wavelengths in a copy of
    ``profile_data``. Derived wavelengths are then recomputed with
    ``utils.add_derived_wavelengths``.

    Parameters
    ----------
    profile_data
        the data to register; must have ``pair``, ``timepoint``,
        ``wavelength`` and ``position`` dimensions
    redox_params
        the redox parameters. ``ratio_numerator`` / ``ratio_denominator``
        select the wavelengths to register; the whole dict is forwarded to
        ``utils.add_derived_wavelengths``.
    reg_params
        the registration parameters: ``warp_n_basis``, ``warp_order``,
        ``warp_lambda``, ``smooth_lambda``, ``smooth_n_breaks``,
        ``smooth_order``, ``rough_lambda``, ``rough_n_breaks``,
        ``rough_order`` and ``n_deriv``
    eng
        the MATLAB engine (optional). If ``None``, a new engine is started for
        this call and shut down before returning, even on error. A
        caller-supplied engine is left running so it can be reused.

    Returns
    -------
    reg_data: xr.DataArray
        the registered data
    warp_data: xr.DataArray
        the warp functions used to register the data (``profile_data``
        without the ``wavelength`` dimension)
    """
    numerator = redox_params["ratio_numerator"]
    denominator = redox_params["ratio_denominator"]
    # Registration arguments, passed positionally after the resample resolution.
    reg_args = [
        reg_params[key]
        for key in (
            "warp_n_basis",
            "warp_order",
            "warp_lambda",
            "smooth_lambda",
            "smooth_n_breaks",
            "smooth_order",
            "rough_lambda",
            "rough_n_breaks",
            "rough_order",
            "n_deriv",
        )
    ]
    resample_resolution = float(profile_data.position.size)

    reg_profile_data = profile_data.copy()
    warp_data = profile_data.copy().isel(wavelength=0)

    # Only shut down an engine this function started itself.
    owns_engine = eng is None
    if owns_engine:
        eng = matlab.engine.start_matlab()

    try:
        for p in profile_data.pair:
            for tp in profile_data.timepoint:
                profiles = profile_data.sel(timepoint=tp, pair=p)
                i_num = matlab.double(profiles.sel(wavelength=numerator).values.tolist())
                i_denom = matlab.double(
                    profiles.sel(wavelength=denominator).values.tolist()
                )
                reg_num, reg_denom, warps = eng.channel_register(
                    i_num, i_denom, resample_resolution, *reg_args, nargout=3
                )
                # MATLAB outputs are transposed before being written back
                reg_profile_data.loc[
                    dict(timepoint=tp, pair=p, wavelength=numerator)
                ] = np.array(reg_num).T
                reg_profile_data.loc[
                    dict(timepoint=tp, pair=p, wavelength=denominator)
                ] = np.array(reg_denom).T
                warp_data.loc[dict(pair=p, timepoint=tp)] = np.array(warps).T
    finally:
        if owns_engine:
            eng.quit()

    reg_profile_data = utils.add_derived_wavelengths(reg_profile_data, **redox_params)
    return reg_profile_data, warp_data
def standardize_profiles(
    profile_data: xr.DataArray,
    redox_params,
    template: Union[xr.DataArray, np.ndarray] = None,
    eng=None,
    **reg_kwargs,
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Standardize the A-P positions of the pharyngeal intensity profiles.

    For every (timepoint, pair), the ratio-numerator and ratio-denominator
    profiles are sent, together with ``template``, to the MATLAB
    ``standardize_profiles`` function. The registered profiles replace those
    two wavelengths in a copy of ``profile_data``.

    Parameters
    ----------
    profile_data
        The data to standardize. Must have the following dimensions:
        ``["animal", "timepoint", "pair", "wavelength"]`` (plus ``position``,
        whose size is used as the resample resolution).
    redox_params
        the parameters used to map R -> OxD -> E. ``ratio_numerator`` /
        ``ratio_denominator`` select the wavelengths to register.
    template
        a 1D profile to register all intensity profiles to. Anything
        ``np.asarray`` accepts (DataArray, ndarray, list) is allowed. If None,
        intensity profiles are registered to the population mean of the ratio
        numerator. NOTE(review): that default mean is taken over
        ``animal`` and ``pair`` only, so with several timepoints the template
        keeps a ``timepoint`` dimension and is not 1D. Confirm the MATLAB side
        expects this.
    eng
        The MATLAB engine to use for registration. If ``None``, a new engine is
        started for this call and shut down before returning, even on error.
        A caller-supplied engine is left running.
    reg_kwargs
        Keyword arguments to use for registration: ``warp_n_basis``,
        ``warp_order``, ``warp_lambda``, ``smooth_lambda``,
        ``smooth_n_breaks``, ``smooth_order``, ``rough_lambda``,
        ``rough_n_breaks``, ``rough_order`` and ``n_deriv``. They are also
        stored as attributes on the returned data.

    Returns
    -------
    standardized_data: xr.DataArray
        the standardized data
    warp_functions: xr.DataArray
        the warp functions generated to standardize the data
    """
    numerator = redox_params["ratio_numerator"]
    denominator = redox_params["ratio_denominator"]
    # Registration arguments, passed positionally after the resample resolution.
    reg_args = [
        reg_kwargs[key]
        for key in (
            "warp_n_basis",
            "warp_order",
            "warp_lambda",
            "smooth_lambda",
            "smooth_n_breaks",
            "smooth_order",
            "rough_lambda",
            "rough_n_breaks",
            "rough_order",
            "n_deriv",
        )
    ]
    resample_resolution = float(profile_data.position.size)

    std_profile_data = profile_data.copy()
    std_warp_data = profile_data.copy().isel(wavelength=0)

    if template is None:
        template = profile_data.sel(wavelength=numerator).mean(dim=["animal", "pair"])
    # np.asarray unwraps DataArrays and also accepts ndarrays and plain lists
    template = matlab.double(np.asarray(template).tolist())

    # Only shut down an engine this function started itself.
    owns_engine = eng is None
    if owns_engine:
        eng = matlab.engine.start_matlab()

    try:
        for tp in profile_data.timepoint:
            for pair in profile_data.pair:
                # this (tp, pair) slice has not been overwritten yet
                data = std_profile_data.sel(timepoint=tp, pair=pair)
                i_num = matlab.double(data.sel(wavelength=numerator).values.tolist())
                i_denom = matlab.double(data.sel(wavelength=denominator).values.tolist())
                reg_num, reg_denom, warps = eng.standardize_profiles(
                    i_num, i_denom, template, resample_resolution, *reg_args, nargout=3
                )
                # MATLAB outputs are transposed before being written back
                std_profile_data.loc[
                    dict(timepoint=tp, pair=pair, wavelength=numerator)
                ] = np.array(reg_num).T
                std_profile_data.loc[
                    dict(timepoint=tp, pair=pair, wavelength=denominator)
                ] = np.array(reg_denom).T
                std_warp_data.loc[dict(timepoint=tp, pair=pair)] = np.array(warps).T
    finally:
        if owns_engine:
            eng.quit()

    std_profile_data = std_profile_data.assign_attrs(**reg_kwargs)
    std_profile_data = utils.add_derived_wavelengths(std_profile_data, **redox_params)
    return std_profile_data, std_warp_data
def summarize_over_regions(
    data: xr.DataArray,
    regions: Dict,
    rescale: bool = True,
    value_name: str = "value",
    pointwise: Union[bool, str] = False,
    **redox_params,
):
    """
    Summarize profile data over regions and return it as a DataFrame.

    Parameters
    ----------
    data
        profile data with ``position`` and ``wavelength`` dimensions. When
        ``rescale`` is True, the size of the last axis is used as the profile
        length (presumably ``position`` is the last axis; confirm).
    regions
        mapping of region name -> either a single position (int/float, where
        the data are interpolated) or a ``(start, end)`` pair (the data are
        averaged over that slice of positions, skipping NaNs)
    rescale
        if True, region boundaries are first passed through
        ``utils.scale_region_boundaries``
    value_name
        name of the value column, forwarded to ``to_dataframe``
    pointwise
        * ``False``: r / oxd / e are recomputed from the region-averaged ratio
          numerator and denominator (ratio of region means).
        * ``True``: the region averages of the pointwise r / oxd / e values are
          kept (region mean of ratios).
        * ``"both"``: both results concatenated. Rows are distinguished by the
          ``pointwise`` column.
    redox_params
        forwarded to ``utils.add_derived_wavelengths``. For ``pointwise=False``
        it must contain ``ratio_numerator``, ``ratio_denominator``, ``r_min``,
        ``r_max``, ``instrument_factor``, ``midpoint_potential``, ``z`` and
        ``temperature``.

    Returns
    -------
    pd.DataFrame
        the summarized data, with ``experiment_id`` appended to the index
        when the frame has such a column
    """
    if pointwise == "both":
        # Recursively call this function for pointwise=False/True and concat
        # the results. redox_params must be forwarded: without it the
        # recursive calls fail on the missing redox keys.
        return pd.concat(
            [
                summarize_over_regions(
                    data,
                    regions,
                    rescale=rescale,
                    value_name=value_name,
                    pointwise=flag,
                    **redox_params,
                )
                for flag in (False, True)
            ]
        )

    if rescale:
        regions = utils.scale_region_boundaries(regions, data.shape[-1])

    # Ensure that derived wavelengths are present
    data = utils.add_derived_wavelengths(data, **redox_params)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        all_region_data = []
        for bounds in regions.values():
            if isinstance(bounds, (int, float)):
                all_region_data.append(data.interp(position=bounds))
            else:
                all_region_data.append(
                    data.sel(position=slice(bounds[0], bounds[1])).mean(
                        dim="position", skipna=True
                    )
                )

    region_data = xr.concat(all_region_data, pd.Index(list(regions), name="region"))
    region_data = region_data.assign_attrs(**data.attrs)

    if not pointwise:
        # Ratio of region means: overwrite the averaged pointwise r/oxd/e with
        # values derived from the region-averaged intensities. In pointwise
        # mode the averaged pointwise values are kept. Previously this ran
        # unconditionally, so both modes returned identical numbers.
        region_data.loc[dict(wavelength="r")] = region_data.sel(
            wavelength=redox_params["ratio_numerator"]
        ) / region_data.sel(wavelength=redox_params["ratio_denominator"])
        region_data.loc[dict(wavelength="oxd")] = r_to_oxd(
            region_data.sel(wavelength="r"),
            r_min=redox_params["r_min"],
            r_max=redox_params["r_max"],
            instrument_factor=redox_params["instrument_factor"],
        )
        region_data.loc[dict(wavelength="e")] = oxd_to_redox_potential(
            region_data.sel(wavelength="oxd"),
            midpoint_potential=redox_params["midpoint_potential"],
            z=redox_params["z"],
            temperature=redox_params["temperature"],
        )

    df = to_dataframe(region_data, value_name)
    df["pointwise"] = pointwise
    try:
        df.set_index(["experiment_id"], append=True, inplace=True)
    except (KeyError, ValueError):
        # set_index raises KeyError when the column is absent; catching only
        # ValueError turned "no experiment_id" into a crash.
        pass
    return df