Ejemplo n.º 1
0
    def plot_channels(
        self,
        channels: List[str] = ["mean"],
        per_sample: bool = False,
        merged: bool = False,
        save: bool = False,
        output_dir: Optional[Path] = None,
        samples: Optional[List["IMCSample"]] = None,
        rois: Optional[List["ROI"]] = None,
        **kwargs,
    ) -> Figure:
        """
        Plot a list of channels for all Samples/ROIs.

        Parameters
        ----------
        channels
            Channel name(s) to plot. A single string is treated as a
            one-element list. A list argument (including the default) is
            copied before use so the shared default can never be mutated by
            downstream code.
        per_sample
            If True, call each sample's ``plot_channels`` and save one figure
            per sample; only the last sample's figure is returned.
        merged
            If True, each ROI gets a single merged panel; otherwise each ROI
            gets one panel per channel.
        save
            Whether to write the figure(s) as PDF into ``output_dir``.
        output_dir
            Directory for saved figures; defaults to ``self.results_dir / "qc"``.
        samples
            Optional subset of samples (iterated directly when ``per_sample``,
            otherwise forwarded to ``self._get_rois``).
        rois
            Optional subset of ROIs forwarded to ``self._get_rois``; ignored
            when ``per_sample`` is True.
        **kwargs
            Forwarded to the sample/ROI ``plot_channels`` methods.

        Returns
        -------
        The figure produced (the last per-sample figure if ``per_sample``).
        """
        if isinstance(channels, str):
            channels = [channels]
        elif isinstance(channels, list):
            # Copy so the mutable default argument is never shared/mutated.
            channels = list(channels)
        output_dir = Path(output_dir or self.results_dir / "qc")
        if save:
            output_dir.mkdir(parents=True, exist_ok=True)
            channels_str = ",".join(channels)
            fig_file = output_dir / ".".join(
                [self.name, f"all_rois.{channels_str}.pdf"])

        if per_sample:
            for sample in samples or self.samples:
                fig = sample.plot_channels(channels, **kwargs)
                if save:
                    fig_file = output_dir / ".".join([
                        self.name, sample.name, f"all_rois.{channels_str}.pdf"
                    ])
                    fig.savefig(fig_file, **FIG_KWS)
            return fig

        rois = self._get_rois(samples, rois)

        # Panels per ROI: one merged panel, or one panel per channel.
        panels_per_roi = 1 if merged else len(channels)
        n, m = get_grid_dims(len(rois) * panels_per_roi)
        # squeeze=False keeps a 2-D array even for a 1x1 grid, so flatten()
        # always works (a lone Axes object has no .flatten()).
        fig, axes = plt.subplots(n, m, figsize=(4 * m, 4 * n), squeeze=False)
        axes = axes.flatten()
        i = 0
        for roi in rois:
            roi.plot_channels(channels,
                              axes=axes[i:i + panels_per_roi],
                              merged=merged,
                              **kwargs)
            i += panels_per_roi
        # Hide grid cells left over after the last ROI.
        for _ax in axes[i:]:
            _ax.axis("off")
        if save:
            fig.savefig(fig_file, **FIG_KWS)
        return fig
Ejemplo n.º 2
0
Archivo: sample.py Proyecto: bzrry/imc
    def plot_rois(self,
                  channel: Union[str, int],
                  rois: Optional[List["ROI"]] = None) -> Figure:  # List[ROI]
        """
        Plot a single channel for all ROIs, one ROI per grid cell.

        Parameters
        ----------
        channel
            Channel name or index passed to each ROI's ``plot_channel``.
        rois
            ROIs to plot; falls back to ``self.rois`` when None or empty.

        Returns
        -------
        The created figure.
        """
        rois = rois or self.rois

        n, m = get_grid_dims(len(rois))
        fig, axis = plt.subplots(n, m, figsize=(m * 4, n * 4), squeeze=False)
        axis = axis.flatten()
        # Start at -1 so that axis[i + 1:] covers every cell if there are no
        # ROIs; after the loop it starts at the first unused cell.
        i = -1
        for i, roi in enumerate(rois):
            roi.plot_channel(channel, ax=axis[i])
        # Hide only the grid cells that did not receive an ROI.
        for _ax in axis[i + 1:]:
            _ax.axis("off")
        return fig
Ejemplo n.º 3
0
    def plot_overlayied_channels_subplots(self, n_groups: int) -> Figure:
        """
        Plot all channels of ROI in `n_groups` combinations, where each combination
        has as little overlap as possible.

        Each combination is drawn in its own subplot by overlaying every marker
        of the set with its own transparent colormap. A legend names the
        markers and is titled with the set's name.

        Parameters
        ----------
        n_groups
            Number of marker sets to request from ``get_distinct_marker_sets``.
            Each set has ``floor(channel_number / n_groups)`` markers.

        Returns
        -------
        The created figure.
        """
        stack = self.stack

        _, marker_sets = self.get_distinct_marker_sets(
            n_groups=n_groups,
            group_size=int(np.floor(self.channel_number / n_groups)),
        )

        n, m = get_grid_dims(n_groups)
        fig, axis = plt.subplots(
            n,
            m,
            figsize=(6 * m, 6 * n),
            sharex=True,
            sharey=True,
            squeeze=False,
        )
        axis = axis.flatten()
        i = -1  # so that axis[i + 1:] covers every cell if there are no sets
        for i, (marker_set, mrks) in enumerate(marker_sets.items()):
            ax = axis[i]
            patches = list()
            cmaps = get_transparent_cmaps(len(mrks))
            for label, cmap in zip(mrks, cmaps):
                x = stack[self.channel_labels == label, :, :].squeeze()
                # Clip the colour scale at mean + 2 SD of this channel.
                v = x.mean() + x.std() * 2
                ax.imshow(
                    x,
                    cmap=cmap,
                    vmin=0,
                    vmax=v,
                    label=label,
                    interpolation="bilinear",
                    rasterized=True,
                )
                # The legend entry is named after the marker (previously it
                # used `m`, the grid column count).
                patches.append(mpatches.Patch(color=cmap(256), label=label))
            ax.axis("off")
            ax.legend(
                handles=patches,
                bbox_to_anchor=(1.05, 1),
                loc=2,
                borderaxespad=0.0,
                title=marker_set,
            )
        # Hide grid cells that received no marker set.
        for _ax in axis[i + 1:]:
            _ax.axis("off")
        return fig
Ejemplo n.º 4
0
    def plot_channels(
        self,
        channels: Optional[List[str]] = None,
        merged: bool = False,
        axes: Optional[List[Axis]] = None,
        equalize: Optional[bool] = None,
        log: bool = True,
        minmax: bool = True,
        add_scale: bool = True,
        add_range: bool = True,
        share_axes: bool = True,
        **kwargs,
    ) -> Optional[Figure]:
        """
        Plot channels of this ROI, one panel per channel or merged into one.

        Parameters
        ----------
        channels
            Channels to plot; defaults to ``self.channel_labels.index``.
        merged
            If True, combine all channels with ``merge_channels`` and draw the
            result on the first axis; otherwise call ``self.plot_channel``
            once per channel.
        axes
            Existing axes to draw into. If given it must have at least
            ``len(channels)`` entries (one when ``merged``), and no figure is
            created.
        equalize
            Passed to ``self._get_channels`` / ``self.plot_channel``; None
            means True.
        log
            Passed to ``self._get_channels`` (merged) or ``self.plot_channel``.
        minmax
            Passed to ``self._get_channels`` in merged mode only.
        add_scale
            Whether to draw a scale (``_add_scale`` in merged mode).
        add_range
            Passed to ``self.plot_channel``; not used in merged mode yet.
        share_axes
            ``sharex``/``sharey`` for a newly created figure.
        **kwargs
            Forwarded to ``merge_channels`` (merged) or ``self.plot_channel``.

        Returns
        -------
        The new figure when ``axes`` is None, otherwise None.
        """
        # TODO: optimize this by avoiding reading stack for every channel
        if channels is None:
            channels = self.channel_labels.index
        # Both merged and per-channel modes default to equalization.
        if equalize is None:
            equalize = True

        fig = None  # stays None when the caller supplies `axes`
        if axes is None:
            n, m = (1, 1) if merged else get_grid_dims(len(channels))
            fig, _axes = plt.subplots(
                n,
                m,
                figsize=(m * 4, n * 4),
                squeeze=False,
                sharex=share_axes,
                sharey=share_axes,
            )
            fig.suptitle(f"{self.sample}\n{self}")
            _axes = _axes.flatten()
        else:
            _axes = axes

        if merged:
            names, arr, minmaxes = self._get_channels(list(channels),
                                                      log=log,
                                                      equalize=equalize,
                                                      minmax=minmax)
            arr2, colors = merge_channels(arr, return_colors=True, **kwargs)
            rows, cols, _ = arr2.shape
            _axes[0].imshow(arr2 / arr2.max())
            bbox = dict(
                boxstyle="round",
                ec=(0.3, 0.3, 0.3, 0.5),
                fc=(0.0, 0.0, 0.0, 0.5),
            )
            # NOTE(review): the first coordinate comes from the row count and
            # the second from the column count; if `rainbow_text` expects
            # (horizontal, vertical) data coordinates these are swapped for
            # non-square images — confirm.
            rainbow_text(
                rows * 0.05,
                cols * 0.05,
                names.split(","),
                colors,
                ax=_axes[0],
                fontsize=3,
                bbox=bbox,
            )
            if add_scale:
                _add_scale(_axes[0])
            # TODO: add minmaxes, perhaps to the channel labels?
        else:
            for i, channel in enumerate(channels):
                self.plot_channel(
                    channel,
                    ax=_axes[i],
                    equalize=equalize,
                    log=log,
                    add_scale=add_scale,
                    add_range=add_range,
                    **kwargs,
                )
        # Turn off ticks/frames on every axis (used and unused alike).
        for _ax in _axes:
            _ax.axis("off")
        return fig
Ejemplo n.º 5
0
    def measure_adjacency(
        self,
        output_prefix: Optional[Path] = None,
        samples: Optional[List["IMCSample"]] = None,
        rois: Optional[List["ROI"]] = None,
    ) -> None:
        """
        Derive cell adjacency graphs for each ROI.

        Missing adjacency graphs are computed in parallel and cached on each
        ROI (``roi._adjacency_graph``). Cell-type adjacency is then measured
        per ROI, and a grid of heatmaps is saved to
        ``<output_prefix>adjacency.all_rois.pdf``.

        Parameters
        ----------
        output_prefix
            Path prefix (str or Path) to which file suffixes are appended
            verbatim. Defaults to ``results_dir/single_cell/<name>.``.
        samples, rois
            Optional subsets forwarded to ``self._get_rois``.
        """
        # Work with the prefix as a plain string: suffixes are appended with
        # `+`, which is not defined for Path. (The previous default also
        # evaluated `(results_dir / ... / name) + "."` because `/` binds
        # tighter than `+`, raising TypeError.)
        if not output_prefix:
            output_prefix = str(self.results_dir / "single_cell" / self.name) + "."
        else:
            output_prefix = str(output_prefix)
        rois = self._get_rois(samples, rois)

        # Compute adjacency graphs only for ROIs that have none cached yet.
        missing = [r for r in rois if r._adjacency_graph is None]
        if missing:
            graphs = parmap.map(get_adjacency_graph, missing, pm_pbar=True)
            for roi, g in zip(missing, graphs):
                roi._adjacency_graph = g

        # TODO: package the stuff below into a function

        # First measure adjacency as odds against background
        freqs = parmap.map(measure_cell_type_adjacency, rois)

        melted = pd.concat([
            f.reset_index().melt(id_vars="index").assign(
                sample=roi.sample.name, roi=roi.name)
            for roi, f in zip(rois, freqs)
        ])

        # Symmetric colour range shared by all heatmaps: 95th percentile of
        # absolute values across every ROI.
        v = np.percentile(melted["value"].abs(), 95)
        n, m = get_grid_dims(len(freqs))
        # squeeze=False so flatten() also works for a single ROI (1x1 grid).
        fig, axes = plt.subplots(n,
                                 m,
                                 figsize=(m * 5, n * 5),
                                 sharex=True,
                                 sharey=True,
                                 squeeze=False)
        axes = axes.flatten()
        i = -1  # so that axes[i + 1:] covers every cell if nothing is plotted
        for i, (dfs, roi) in enumerate(zip(freqs, rois)):
            axes[i].set_title(roi.name)
            sns.heatmap(
                dfs,
                ax=axes[i],
                cmap="RdBu_r",
                center=0,
                rasterized=True,
                square=True,
                xticklabels=True,
                yticklabels=True,
                vmin=-v,
                vmax=v,
            )
        # Hide grid cells that received no heatmap.
        for axs in axes[i + 1:]:
            axs.axis("off")
        fig.savefig(output_prefix + "adjacency.all_rois.pdf", **FIG_KWS)