Ejemplo n.º 1
0
    def add_all_components(self, cnmf_data_dict: dict, input_params_dict: dict):
        """
        Add all components from a CNMF(E) output as ``VolCNMF`` ROIs.

        The components in the output file are treated as already filtered, so every
        row of ``estimates['C']`` becomes one ROI.

        :param cnmf_data_dict:      CNMF results data directly from the HDF5 file.
                                    Reads ``dims`` and the ``estimates`` entries
                                    ``A``, ``b``, ``f``, ``C``, ``YrA``, ``S`` and
                                    ``F_dff`` (``F_dff`` may be ``None``).
        :param input_params_dict:   dict of input params, from the batch manager
        :return: None
        """
        if not hasattr(self, 'roi_list'):
            self.create_roi_list()

        self.cnmf_data_dict = cnmf_data_dict
        estimates = self.cnmf_data_dict['estimates']

        self.cnmA = estimates['A']
        self.cnmb = estimates['b']
        self.cnm_f = estimates['f']
        self.cnmC = estimates['C']
        self.cnmYrA = estimates['YrA']
        self.dims = self.cnmf_data_dict['dims']
        self.cnmS = estimates['S']
        self.cnm_dfof = estimates['F_dff']

        # components are already filtered from the output file, so every row of C is used
        self.idx_components = np.arange(self.cnmC.shape[0])
        self.orig_idx_components = deepcopy(self.idx_components)
        self.input_params_dict = input_params_dict

        # spatial components
        contours = caiman_get_contours(self.cnmA[:, self.idx_components], self.dims, thr=0.9)

        num_components = len(self.cnmC)
        has_dfof = self.cnm_dfof is not None

        self.ui.radioButton_curve_data.setChecked(True)

        for ix in range(num_components):
            self.vi.viewer.status_bar_label.showMessage('Please wait, adding component #: '
                                                        + str(ix) + ' / ' + str(num_components))

            curve_data = self.cnmC[ix]
            contour = contours[ix]

            roi = VolCNMF(curve_plot_item=self.get_plot_item(),
                          view_box=self.vi.viewer.getView(),
                          cnmf_idx=self.idx_components[ix],
                          curve_data=curve_data,
                          contour=contour,
                          dfof_data=self.cnm_dfof[ix] if has_dfof else None,
                          spike_data=self.cnmS[ix])

            self.roi_list.append(roi)

        self.vi.workEnv_changed("ROIs imported")
        self.roi_list.reindex_colormap(random_shuffle=True)
        self.vi.viewer.status_bar_label.showMessage('Finished adding all components!')
Ejemplo n.º 2
0
    def add_all_components(self,
                           cnmf_data_dict,
                           input_params_dict,
                           calc_raw_min_max=False):
        """
        Add all components from a CNMF(E) output as ``CNMFROI`` ROIs.

        One ROI is created per entry of ``idx_components``. Each entry is used as the
        column index into ``A`` and as the row index into ``C``, ``S`` and ``F_dff``.
        The contour, curve, spikes, dF/F and raw min/max mask of an ROI therefore
        all come from the same component. When the file stores ``None`` for
        ``idx_components``, every component (row of ``C``) is used.

        :param cnmf_data_dict:      CNMF results data directly from the HDF5 file
        :param input_params_dict:   dict of input params, from the batch manager
        :param calc_raw_min_max:    Calculate raw min & max for each ROI
        :return:
        """
        if not hasattr(self, 'roi_list'):
            self.create_roi_list()

        self.cnmf_data_dict = cnmf_data_dict

        self.cnmA = self.cnmf_data_dict['estimates']['A']
        self.cnmb = self.cnmf_data_dict['estimates']['b']
        self.cnm_f = self.cnmf_data_dict['estimates']['f']
        self.cnmC = self.cnmf_data_dict['estimates']['C']
        self.cnmS = self.cnmf_data_dict['estimates']['S']
        self.cnm_dfof = self.cnmf_data_dict['estimates']['F_dff']
        self.cnmYrA = self.cnmf_data_dict['estimates']['YrA']
        self.dims = self.cnmf_data_dict['dims']
        self.idx_components = cnmf_data_dict['estimates']['idx_components']

        if self.idx_components is None:
            self.idx_components = np.arange(self.cnmC.shape[0])

        self.orig_idx_components = deepcopy(self.idx_components)
        self.input_params_dict = input_params_dict

        # spatial components, one contour per entry of idx_components
        contours = caiman_get_contours(self.cnmA[:, self.idx_components],
                                       self.dims)

        # iterate over the selected components (not over every row of C) so that
        # contours[ix] always exists and matches the traces indexed below
        num_components = len(self.idx_components)

        if calc_raw_min_max:
            img = self.vi.viewer.workEnv.imgdata.seq.T

        self.ui.radioButton_curve_data.setChecked(True)

        for ix, cnmf_idx in enumerate(self.idx_components):
            self.vi.viewer.status_bar_label.showMessage(
                'Please wait, adding component #: ' + str(ix) + ' / ' +
                str(num_components))

            curve_data = self.cnmC[cnmf_idx]
            contour = contours[ix]

            if calc_raw_min_max:
                # Get a binary mask of the component's spatial footprint
                mask = self.cnmA[:, cnmf_idx].toarray().reshape(
                    self.dims, order='F') > 0

                max_ix = curve_data.argmax()
                min_ix = curve_data.argmin()

                array_at_max = img[max_ix, :, :].copy()
                array_at_max = array_at_max[mask]

                array_at_min = img[min_ix, :, :].copy()
                array_at_min = array_at_min[mask]

                raw_min_max = self.get_raw_min_max(array_at_max=array_at_max,
                                                   array_at_min=array_at_min)

            else:
                raw_min_max = None

            roi = CNMFROI(curve_plot_item=self.get_plot_item(),
                          view_box=self.vi.viewer.getView(),
                          cnmf_idx=cnmf_idx,
                          curve_data=curve_data,
                          contour=contour,
                          raw_min_max=raw_min_max,
                          dfof_data=self.cnm_dfof[cnmf_idx] if
                          (self.cnm_dfof is not None) else None,
                          spike_data=self.cnmS[cnmf_idx])

            self.roi_list.append(roi)

        if calc_raw_min_max:
            del img

        self.roi_list.reindex_colormap()
        self.vi.viewer.status_bar_label.showMessage(
            'Finished adding all components!')
Ejemplo n.º 3
0
    def restore_from_states(self, states: dict):
        """
        Restore multi-z-level CNMF ROIs from a previously saved ``states`` dict.

        For each z-level this does three things:

        - Rebuilds the contours of the components selected by ``idx_components``.
        - Creates one hidden scatter plot holding every contour point of that
          z-level.
        - Recreates one ``VolMultiCNMFROI`` per component, restoring its tags,
          dF/F and spike data from the saved ROI state with the same ``cnmf_idx``.

        :param states:  dict with the keys ``cnmf_data_dicts``, ``input_params_cnmf``,
                        ``num_zlevels``, ``cnmf_output`` and ``states``, plus an
                        optional ``metadata``. ``states['states']`` holds one list of
                        ROI state dicts per z-level; each state has ``cnmf_idx``,
                        ``tags``, ``dfof_data`` and ``spike_data``.
        :raises KeyError: if a component has no saved ROI state with a matching ``cnmf_idx``
        """
        if 'metadata' in states.keys():
            self.metadata = states['metadata']

        if not hasattr(self, 'roi_list'):
            self.create_roi_list()

        self.cnmf_data_dicts = states['cnmf_data_dicts']

        self.input_params_dict = states['input_params_cnmf']
        self.num_zlevels = states['num_zlevels']

        cnmf_output = states['cnmf_output']
        self.cnmA = cnmf_output['cnmA']
        self.cnmb = cnmf_output['cnmb']
        self.cnmC = cnmf_output['cnmC']
        self.cnm_f = cnmf_output['cnm_f']
        self.cnmYrA = cnmf_output['cnmYrA']
        self.idx_components = cnmf_output['idx_components']
        self.orig_idx_components = cnmf_output['orig_idx_components']
        self.dims = cnmf_output['dims']

        # The colormap lookup table does not depend on the z-level, so build it once.
        # NOTE: ``_init``/``_lut`` are private matplotlib attributes.
        cm = matplotlib_color_map.get_cmap('hsv')
        cm._init()
        lut = (cm._lut * 255).view(np.ndarray)

        self.ui.radioButton_curve_data.setChecked(True)

        for zcenter in range(self.num_zlevels):
            print(f"Loading z-level {zcenter}")
            idx_components = self.idx_components[zcenter]

            contours = caiman_get_contours(
                self.cnmA[zcenter][:, idx_components],
                self.dims[zcenter],
                # swap_dim=True
            )

            num_components = len(idx_components)

            # collect the non-NaN contour points of every ROI and, for each point,
            # the index of the ROI it belongs to
            roi_ixs = []
            roi_xy = []

            for ix in range(len(contours)):
                coors = contours[ix]['coordinates']
                coors = coors[~np.isnan(coors).any(axis=1)]
                roi_xy += [coors]
                roi_ixs += [ix] * coors.shape[0]

            roi_xy = np.vstack(roi_xy)
            roi_ixs = np.vstack(roi_ixs)

            self.roi_xys.append(roi_xy)
            self.roi_ixs.append(roi_ixs)

            # spread the ROIs evenly over LUT entries 0..210
            cm_ixs = np.linspace(0, 210, np.unique(roi_ixs).size + 1, dtype=int)

            roi_crs = []

            for roi_ix, cm_ix in zip(np.unique(roi_ixs), cm_ixs):
                c = lut[cm_ix]
                roi_crs.append(
                    np.array([c] * roi_ixs[roi_ixs == roi_ix].size)  # color for each spot
                )

            roi_crs = np.vstack(roi_crs)
            self.roi_crs.append(roi_crs)

            xy_coors = self.roi_xys[-1]

            brushes = list(map(pg.mkBrush, roi_crs))
            pens = list(map(pg.mkPen, roi_crs))

            sp = pg.ScatterPlotItem(
                xy_coors[:, 0],
                xy_coors[:, 1],
                symbol='s',
                size=1,
                pxMode=True,
                brush=brushes,
                pen=pens
            )

            # hidden here; only the current z-level's plot is shown at the end
            self.vi.viewer.getView().addItem(sp)
            sp.hide()
            self.roi_sps.append(sp)

            # Index the saved ROI states of this z-level by cnmf_idx once, instead of
            # scanning the whole list for every component. Keep the first match.
            saved_states = {}
            for saved_state in states['states'][zcenter]:
                saved_states.setdefault(saved_state['cnmf_idx'], saved_state)

            for ix in tqdm(range(num_components)):
                self.vi.viewer.status_bar_label.showMessage(
                    f"Please wait, adding component {ix} / {num_components} "
                    f"on zlevel {zcenter} / {self.num_zlevels - 1}"
                )

                cnmf_idx = idx_components[ix]
                curve_data = self.cnmC[zcenter][cnmf_idx]
                contour = contours[ix]

                roi = VolMultiCNMFROI(
                    curve_plot_item=self.get_plot_item(),
                    view_box=self.vi.viewer.getView(),
                    cnmf_idx=cnmf_idx,
                    curve_data=curve_data,
                    contour=contour,
                    zcenter=zcenter,
                    zlevel=self.vi.viewer.current_zlevel,
                    roi_ix=ix,
                    scatter_plot=sp,
                    parent_manager=self,
                )

                roi_state = saved_states[cnmf_idx]

                for tag_name, tag_value in roi_state['tags'].items():
                    roi.set_tag(tag_name, tag_value)

                roi.dfof_data = roi_state['dfof_data']
                roi.spike_data = roi_state['spike_data']

                # the list widget is filled in one batch after the loop, so don't
                # also add each ROI to it here (that would list every ROI twice)
                self.roi_list.append(roi, add_to_list_widget=False)

        self.roi_list.list_widget.addItems(
            list(map(str, range(len(self.roi_list))))
        )

        self.vi.workEnv_changed("ROIs imported")

        self.roi_sps[self.vi.viewer.current_zlevel].show()

        self.vi.viewer.status_bar_label.showMessage('Finished adding all components!')
Ejemplo n.º 4
0
    def add_all_components(
            self,
            cnmf_data_dicts: List[dict],
            input_params_dict: dict,
    ):
        """
        Add all components from one CNMF(E) output per z-level as ``VolMultiCNMFROI`` ROIs.

        For each z-level this does three things:

        - Appends the CNMF estimates to the per-z-level attribute lists
          (``cnmA``, ``cnmC``, ``cnmS``, ...).
        - Creates one hidden scatter plot containing every contour point.
        - Creates one ROI per entry of ``idx_components``.

        Each ``idx_components`` entry is used as the column index into ``A`` and as
        the row index into ``C``, ``S`` and ``F_dff``, so an ROI's contour and
        traces come from the same component. When the file stores ``None`` for
        ``idx_components``, every component of that z-level is used.

        :param cnmf_data_dicts:     one CNMF results dict per z-level, directly from the HDF5 files
        :param input_params_dict:   dict of input params, from the batch manager
        """
        self.input_params_dict = input_params_dict

        if not hasattr(self, 'roi_list'):
            self.create_roi_list()

        self.cnmf_data_dicts = cnmf_data_dicts

        self.num_zlevels = len(self.cnmf_data_dicts)

        # The colormap lookup table does not depend on the z-level, so build it once.
        # NOTE: ``_init``/``_lut`` are private matplotlib attributes.
        cm = matplotlib_color_map.get_cmap('hsv')
        cm._init()
        lut = (cm._lut * 255).view(np.ndarray)

        self.ui.radioButton_curve_data.setChecked(True)

        for zcenter, cnmf_data_dict in enumerate(self.cnmf_data_dicts):
            estimates = cnmf_data_dict['estimates']

            self.cnmA.append(estimates['A'])
            self.cnmb.append(estimates['b'])
            self.cnm_f.append(estimates['f'])
            self.cnmC.append(estimates['C'])
            self.cnmS.append(estimates['S'])
            self.cnm_dfof.append(estimates['F_dff'])
            self.cnmYrA.append(estimates['YrA'])
            self.dims.append(cnmf_data_dict['dims'])
            self.idx_components.append(estimates['idx_components'])

            if self.idx_components[-1] is None:
                self.idx_components[-1] = np.arange(self.cnmC[-1].shape[0])

            self.orig_idx_components.append(deepcopy(self.idx_components[-1]))

            idx_components = self.idx_components[-1]
            cnmC = self.cnmC[-1]
            cnmS = self.cnmS[-1]
            cnm_dfof = self.cnm_dfof[-1]

            contours = caiman_get_contours(
                self.cnmA[-1][:, idx_components],
                self.dims[-1],
                # swap_dim=True
            )

            num_components = len(idx_components)

            # collect the non-NaN contour points of every ROI and, for each point,
            # the index of the ROI it belongs to
            roi_ixs = []
            roi_xy = []

            for ix in range(len(contours)):
                coors = contours[ix]['coordinates']
                coors = coors[~np.isnan(coors).any(axis=1)]
                roi_xy += [coors]
                roi_ixs += [ix] * coors.shape[0]

            roi_xy = np.vstack(roi_xy)
            roi_ixs = np.vstack(roi_ixs)

            self.roi_xys.append(roi_xy)
            self.roi_ixs.append(roi_ixs)

            # spread the ROIs evenly over LUT entries 0..210
            cm_ixs = np.linspace(0, 210, np.unique(roi_ixs).size + 1, dtype=int)

            roi_crs = []

            for roi_ix, cm_ix in zip(np.unique(roi_ixs), cm_ixs):
                c = lut[cm_ix]
                roi_crs.append(
                    np.array([c] * roi_ixs[roi_ixs == roi_ix].size)  # color for each spot
                )

            roi_crs = np.vstack(roi_crs)
            self.roi_crs.append(roi_crs)

            xy_coors = self.roi_xys[-1]

            brushes = list(map(pg.mkBrush, roi_crs))
            pens = list(map(pg.mkPen, roi_crs))

            sp = pg.ScatterPlotItem(
                xy_coors[:, 0],
                xy_coors[:, 1],
                symbol='s',
                size=1,
                pxMode=True,
                brush=brushes,
                pen=pens
            )

            # hidden here; only the current z-level's plot is shown at the end
            self.vi.viewer.getView().addItem(sp)
            sp.hide()
            self.roi_sps.append(sp)

            for ix in range(num_components):
                self.vi.viewer.status_bar_label.showMessage(
                    f"Please wait, adding component {ix} / {num_components} "
                    f"on zlevel {zcenter} / {self.num_zlevels - 1}"
                )

                # index the traces by the component index so they match contours[ix]
                cnmf_idx = idx_components[ix]
                curve_data = cnmC[cnmf_idx]
                contour = contours[ix]

                if cnm_dfof is not None:
                    dfof_data = cnm_dfof[cnmf_idx]
                else:
                    dfof_data = None

                roi = VolMultiCNMFROI(
                    curve_plot_item=self.get_plot_item(),
                    view_box=self.vi.viewer.getView(),
                    cnmf_idx=cnmf_idx,
                    curve_data=curve_data,
                    contour=contour,
                    spike_data=cnmS[cnmf_idx],
                    dfof_data=dfof_data,
                    zcenter=zcenter,
                    zlevel=self.vi.viewer.current_zlevel,
                    roi_ix=ix,
                    scatter_plot=sp,
                    parent_manager=self,
                )

                # the list widget is filled in one batch after the loop
                self.roi_list.append(roi, add_to_list_widget=False)

        self.roi_list.list_widget.addItems(
            list(map(str, range(len(self.roi_list))))
        )

        self.vi.workEnv_changed("ROIs imported")

        self.roi_sps[self.vi.viewer.current_zlevel].show()

        self.vi.viewer.status_bar_label.showMessage('Finished adding all components!')
Ejemplo n.º 5
0
    def add_all_components(self,
                           cnmA,
                           cnmb,
                           cnmC,
                           cnm_f,
                           cnmYrA,
                           idx_components,
                           dims,
                           input_params_dict,
                           dfof=False,
                           calc_raw_min_max=False):
        """
        Add all components from a CNMF(E) output as ``CNMFROI`` ROIs.

        Arguments correspond to CNMF(E) outputs.

        :param cnmA:                spatial components; contours are computed from the
                                    columns selected by ``idx_components``
        :param cnmb:                stored as ``self.cnmb``; not otherwise used here
        :param cnmC:                temporal components
        :param cnm_f:               stored as ``self.cnm_f``; not otherwise used here
        :param cnmYrA:              stored as ``self.cnmYrA``; not otherwise used here
        :param idx_components:      indices of the components to add
        :param dims:                field-of-view dimensions, passed to the contour
                                    calculation and used to reshape the spatial masks
        :param input_params_dict:   dict of input params, from the batch manager
        :param dfof:                if True, ``cnmC`` is used as-is. It is presumably
                                    already reduced to ``idx_components`` (e.g. dF/F
                                    traces); TODO confirm against callers. Otherwise
                                    its rows are selected with ``idx_components``.
        :param calc_raw_min_max:    calculate raw min & max for each ROI from the image data
        """
        if not hasattr(self, 'roi_list'):
            self.create_roi_list()
        self.cnmA = cnmA
        self.cnmb = cnmb
        self.cnmC = cnmC
        self.cnm_f = cnm_f
        self.cnmYrA = cnmYrA
        # keep the spatial dims alongside the other CNMF outputs stored on self
        self.dims = dims
        self.idx_components = idx_components
        self.orig_idx_components = deepcopy(idx_components)
        self.input_params_dict = input_params_dict

        # spatial components
        contours = caiman_get_contours(cnmA[:, idx_components], dims)
        if dfof:
            temporal_components = cnmC
        else:
            temporal_components = cnmC[idx_components]
        num_components = len(temporal_components)

        if calc_raw_min_max:
            img = self.vi.viewer.workEnv.imgdata.seq.T

        for ix in range(num_components):
            self.vi.viewer.status_bar_label.showMessage(
                'Please wait, adding component #: ' + str(ix) + ' / ' +
                str(num_components))

            curve_data = temporal_components[ix]
            contour = contours[ix]

            if calc_raw_min_max:
                # Get a binary mask of the component's spatial footprint
                mask = self.cnmA[:, idx_components[ix]].toarray().reshape(
                    dims, order='F') > 0

                max_ix = curve_data.argmax()
                min_ix = curve_data.argmin()

                array_at_max = img[max_ix, :, :].copy()
                array_at_max = array_at_max[mask]

                array_at_min = img[min_ix, :, :].copy()
                array_at_min = array_at_min[mask]

                raw_min_max = self.get_raw_min_max(array_at_max=array_at_max,
                                                   array_at_min=array_at_min)

            else:
                raw_min_max = None

            roi = CNMFROI(self.get_plot_item(),
                          self.vi.viewer.getView(),
                          idx_components[ix],
                          curve_data,
                          contour,
                          raw_min_max=raw_min_max)

            self.roi_list.append(roi)

        if calc_raw_min_max:
            del img

        self.roi_list.reindex_colormap()
        self.vi.viewer.status_bar_label.showMessage(
            'Finished adding all components!')