Ejemplo n.º 1
0
    def test_despine_trim_noticks(self):
        """Trimming spines on an axis with no y ticks keeps the ticks empty."""
        _, axes = plt.subplots()
        axes.plot([1, 2, 3], [1, 2, 3])

        # Remove every y tick before asking despine to trim the spines.
        axes.set_yticks([])
        utils.despine(trim=True)

        remaining_ticks = axes.get_yticks()
        assert remaining_ticks.size == 0
Ejemplo n.º 2
0
 def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
                      row_linkage, col_linkage):
     """Draw the row and column dendrograms, or blank out their axes.

     Parameters
     ----------
     row_cluster, col_cluster
         When truthy, a dendrogram is drawn for that dimension and stored
         on ``self.dendrogram_row`` / ``self.dendrogram_col``; otherwise
         the ticks of the matching dendrogram axes are removed.
     metric, method
         Forwarded unchanged to ``dendrogram``.
     row_linkage, col_linkage
         Forwarded to ``dendrogram`` as ``linkage`` for the row and the
         column dendrogram respectively.
     """
     # Arguments common to both dendrogram calls.
     shared = dict(metric=metric, method=method, label=False)

     # Row dendrogram: axis 0, drawn rotated.
     if row_cluster:
         self.dendrogram_row = dendrogram(
             self.data2d,
             axis=0,
             ax=self.ax_row_dendrogram,
             rotate=True,
             linkage=row_linkage,
             **shared)
     else:
         for clear_ticks in (self.ax_row_dendrogram.set_xticks,
                             self.ax_row_dendrogram.set_yticks):
             clear_ticks([])

     # Column dendrogram: axis 1, default orientation.
     if col_cluster:
         self.dendrogram_col = dendrogram(
             self.data2d,
             axis=1,
             ax=self.ax_col_dendrogram,
             linkage=col_linkage,
             **shared)
     else:
         for clear_ticks in (self.ax_col_dendrogram.set_xticks,
                             self.ax_col_dendrogram.set_yticks):
             clear_ticks([])

     # Strip the bottom and left spines from both dendrogram axes.
     for dendrogram_ax in (self.ax_row_dendrogram, self.ax_col_dendrogram):
         despine(ax=dendrogram_ax, bottom=True, left=True)
Ejemplo n.º 3
0
def pairplot(data,
             target_col,
             columns=None,
             scatter_alpha='auto',
             scatter_size='auto'):
    """Draw a scatter matrix of features, coloured by a target column.

    Diagonal cells show per-class histograms (via ``class_hists``) and
    off-diagonal cells show scatter plots (via ``discrete_scatter``).
    Intended for classification targets only; deliberately minimal.

    Parameters
    ----------
    data : pandas dataframe
        Input data.
    target_col : column specifier
        Target column in ``data``.
    columns : column specifiers, default=None
        Columns of ``data`` to plot. ``None`` selects every column except
        ``target_col``.
    scatter_alpha : float, default='auto'
        Alpha passed to the scatter plots.
    scatter_size : float, default='auto'
        Marker size passed to the scatter plots.

    Returns
    -------
    axes : 2-D array of matplotlib axes
        The grid of axes, with shape ``(n_features, n_features)``.
    """
    if columns is None:
        columns = data.columns.drop(target_col)
    n_features = len(columns)
    side = n_features * 3
    fig, axes = plt.subplots(n_features, n_features, figsize=(side, side))
    # A single feature yields a bare Axes; normalise to a 2-D grid.
    axes = np.atleast_2d(axes)

    last = n_features - 1
    for row, col in itertools.product(range(n_features), repeat=2):
        ax = axes[row, col]

        if row == col:
            # Diagonal: class histograms on a twin y axis.
            class_hists(data, columns[row], target_col, ax=ax.twinx())
        else:
            # Only the top-right cell carries the legend.
            discrete_scatter(data[columns[col]],
                             data[columns[row]],
                             c=data[target_col],
                             legend=(row == 0 and col == last),
                             ax=ax,
                             alpha=scatter_alpha,
                             s=scatter_size)

        # y labels and tick labels only on the first column.
        if col != 0:
            ax.set_ylabel("")
            ax.set_yticklabels(())
        else:
            ax.set_ylabel(columns[row])

        # x labels and tick labels only on the last row.
        if row != last:
            ax.set_xlabel("")
            ax.set_xticklabels(())
        else:
            ax.set_xlabel(_shortname(columns[col]))

    despine(fig)

    if n_features > 1:
        # Give the top-left cell the y ticks and limits of its neighbour.
        corner, neighbour = axes[0, 0], axes[0, 1]
        corner.set_yticks(neighbour.get_yticks())
        corner.set_ylim(neighbour.get_ylim())
    return axes
Ejemplo n.º 4
0
    def plot(self, ax):
        """Draw the dendrogram on ``ax`` and return ``self``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes object upon which the dendrogram is plotted

        Returns
        -------
        self
            The object itself, so calls can be chained.
        """
        # A rotated row dendrogram swaps the coordinate roles.
        if self.rotate and self.axis == 0:
            coord_pairs = zip(self.dependent_coord, self.independent_coord)
        else:
            coord_pairs = zip(self.independent_coord, self.dependent_coord)
        segments = [list(zip(xs, ys)) for xs, ys in coord_pairs]
        ax.add_collection(LineCollection(segments, linewidths=.5, colors='k'))

        number_of_leaves = len(self.reordered_ind)
        max_dependent_coord = max(map(max, self.dependent_coord))

        # Constants 10 and 1.05 come from
        # `scipy.cluster.hierarchy._plot_dendrogram`
        leaf_span = number_of_leaves * 10
        dependent_span = max_dependent_coord * 1.05

        if not self.rotate:
            ax.set_xlim(0, leaf_span)
            ax.set_ylim(0, dependent_span)
        else:
            ax.yaxis.set_ticks_position('right')
            ax.set_ylim(0, leaf_span)
            ax.set_xlim(0, dependent_span)
            ax.invert_xaxis()
            ax.invert_yaxis()

        despine(ax=ax, bottom=True, left=True)

        ax.set(xticks=self.xticks,
               yticks=self.yticks,
               xlabel=self.xlabel,
               ylabel=self.ylabel)
        xtl = ax.set_xticklabels(self.xticklabels)
        ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')

        # Force a draw of the plot to avoid matplotlib window error
        plt.draw()

        # Flip the rotation of any non-empty label set that overlaps.
        for labels, fallback in ((ytl, "horizontal"), (xtl, "vertical")):
            if len(labels) > 0 and axis_ticklabels_overlap(labels):
                plt.setp(labels, rotation=fallback)
        return self
Ejemplo n.º 5
0
    def test_despine_trim_spines(self):
        """Inner spines are trimmed to (1, 3) despite wider x limits."""
        _, axes = plt.subplots()
        axes.plot([1, 2, 3], [1, 2, 3])
        axes.set_xlim(.75, 3.25)
        utils.despine(trim=True)

        expected_bounds = (1, 3)
        for spine_name in self.inner_sides:
            spine = axes.spines[spine_name]
            assert spine.get_bounds() == expected_bounds
Ejemplo n.º 6
0
    def test_despine_trim_inverted(self):
        """Inner spines are trimmed to (1, 3) when the y axis is inverted."""
        _, axes = plt.subplots()
        axes.plot([1, 2, 3], [1, 2, 3])
        axes.set_ylim(.85, 3.15)
        axes.invert_yaxis()
        utils.despine(trim=True)

        expected_bounds = (1, 3)
        for spine_name in self.inner_sides:
            spine = axes.spines[spine_name]
            assert spine.get_bounds() == expected_bounds
Ejemplo n.º 7
0
    def test_despine_trim_categorical(self):
        """Check trimmed spine bounds for a plot with string x values."""
        _, axes = plt.subplots()
        axes.plot(["a", "b", "c"], [1, 2, 3])
        utils.despine(trim=True)

        expected = (("left", (1, 3)), ("bottom", (0, 2)))
        for spine_name, wanted_bounds in expected:
            assert axes.spines[spine_name].get_bounds() == wanted_bounds
Ejemplo n.º 8
0
    def test_despine_specific_axes(self):
        """Despining one Axes must leave the sibling Axes untouched.

        ``utils.despine(ax=ax2)`` is called with its default sides. After
        that every spine of ``ax1`` must still be visible, the
        ``outer_sides`` spines of ``ax2`` must be hidden, and the
        ``inner_sides`` spines of ``ax2`` must remain visible.
        """
        f, (ax1, ax2) = plt.subplots(2, 1)

        utils.despine(ax=ax2)

        for side in self.sides:
            assert ax1.spines[side].get_visible()

        # Use ``not`` rather than ``~``: bitwise inversion of a bool yields
        # -2 or -1, both truthy, so ``assert ~x`` could never fail.
        for side in self.outer_sides:
            assert not ax2.spines[side].get_visible()
        for side in self.inner_sides:
            assert ax2.spines[side].get_visible()
Ejemplo n.º 9
0
    def test_despine_side_specific_offset(self):
        """A dict offset keyed by "left" moves only the visible left spine."""
        _, axes = plt.subplots()
        utils.despine(ax=axes, offset={"left": self.offset})

        for name in self.sides:
            spine = axes.spines[name]
            moved = spine.get_visible() and name == "left"
            position = spine.get_position()
            if moved:
                assert position == self.offset_position
            else:
                assert position == self.original_position
Ejemplo n.º 10
0
    def test_despine_with_offset_specific_axes(self):
        """An offset applied to one Axes leaves the other Axes in place."""
        _, (untouched, target) = plt.subplots(2, 1)
        utils.despine(offset=self.offset, ax=target)

        for name in self.sides:
            untouched_pos = untouched.spines[name].get_position()
            target_spine = target.spines[name]
            target_pos = target_spine.get_position()

            # The Axes that was not passed to despine never moves.
            assert untouched_pos == self.original_position

            # On the target Axes only the still-visible spines are offset.
            if not target_spine.get_visible():
                assert target_pos == self.original_position
            else:
                assert target_pos == self.offset_position
Ejemplo n.º 11
0
    def test_despine(self):
        """Check default despining and despining of all four sides.

        All spines start visible. A bare ``utils.despine()`` must hide the
        ``outer_sides`` spines and keep the ``inner_sides`` spines; passing
        ``True`` for every side must hide all of them.
        """
        f, ax = plt.subplots()
        for side in self.sides:
            assert ax.spines[side].get_visible()

        # Use ``not`` rather than ``~``: bitwise inversion of a bool yields
        # -2 or -1, both truthy, so ``assert ~x`` could never fail.
        utils.despine()
        for side in self.outer_sides:
            assert not ax.spines[side].get_visible()
        for side in self.inner_sides:
            assert ax.spines[side].get_visible()

        utils.despine(**dict(zip(self.sides, [True] * 4)))
        for side in self.sides:
            assert not ax.spines[side].get_visible()
Ejemplo n.º 12
0
    def test_despine_moved_ticks(self):
        """The secondary tick lines take over the primary ticks' visibility.

        For each axis, the primary tick lines are made visible or hidden,
        the primary spine is removed while the opposite one is kept, and
        the secondary tick lines must then have that same visibility.
        """
        y_sides = dict(left=True, right=False)
        x_sides = dict(bottom=True, top=False)
        scenarios = (
            ("yaxis", True, y_sides),
            ("yaxis", False, y_sides),
            ("xaxis", True, x_sides),
            ("xaxis", False, x_sides),
        )

        for axis_name, shown, despine_sides in scenarios:
            fig, axes = plt.subplots()

            for tick in getattr(axes, axis_name).majorTicks:
                tick.tick1line.set_visible(shown)

            utils.despine(ax=axes, **despine_sides)

            for tick in getattr(axes, axis_name).majorTicks:
                if shown:
                    assert tick.tick2line.get_visible()
                else:
                    assert not tick.tick2line.get_visible()
            plt.close(fig)
Ejemplo n.º 13
0
    def test_despine_with_offset(self):
        """A scalar offset moves every visible spine and no hidden one."""
        _, axes = plt.subplots()

        # Before despining, every spine sits at the original position.
        for name in self.sides:
            assert axes.spines[name].get_position() == self.original_position

        utils.despine(ax=axes, offset=self.offset)

        for name in self.sides:
            spine = axes.spines[name]
            still_visible = spine.get_visible()
            position = spine.get_position()
            if not still_visible:
                assert position == self.original_position
            else:
                assert position == self.offset_position
Ejemplo n.º 14
0
    def plot(self, ax, cax, kws):
        """Draw the heatmap on the provided Axes.

        The mesh returned by ``pcolormesh`` is stored on ``self.mesh``.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that receives the mesh, the ticks and the labels.
        cax
            Forwarded to ``ax.figure.colorbar`` when ``self.cbar`` is set.
        kws : dict
            Extra keyword arguments forwarded to ``ax.pcolormesh``. Its
            ``'rasterized'`` entry also decides whether the colorbar is
            rasterized.
        """
        # Remove all the Axes spines
        despine(ax=ax, left=True, bottom=True)

        # Draw the heatmap
        self.mesh = ax.pcolormesh(self.plot_data,
                                  vmin=self.vmin,
                                  vmax=self.vmax,
                                  cmap=self.cmap,
                                  **kws)

        # Set the axis limits
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

        # Add row and column labels
        ax.set(xticks=self.xticks, yticks=self.yticks)
        xtl = ax.set_xticklabels(self.xticklabels)
        ytl = ax.set_yticklabels(self.yticklabels, rotation="vertical")

        # Possibly rotate them if they overlap
        plt.draw()
        if axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        if axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")

        # Add the axis labels
        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

        # Annotate the cells with the formatted values. The mesh lives on
        # ``self.mesh``; there is no local ``mesh`` in this method.
        if self.annot:
            self._annotate_heatmap(ax, self.mesh)

        # Possibly add a colorbar
        if self.cbar:
            cb = ax.figure.colorbar(self.mesh, cax, ax, **self.cbar_kws)
            cb.outline.set_linewidth(0)
            # If rasterized is passed to pcolormesh, also rasterize the
            # colorbar to avoid white lines on the PDF rendering
            if kws.get('rasterized', False):
                cb.solids.set_rasterized(True)
Ejemplo n.º 15
0
def joint_plot(ratio=1, height=3):
    """Build a joint-plot layout: one main Axes plus two marginal Axes.

    Adapted from seaborn's ``JointGrid``.

    Parameters
    ----------
    ratio : int, default=1
        The grid has ``ratio + 1`` rows and columns. The joint Axes takes
        all but the first row and the last column.
    height : float, default=3
        Width and height of the (square) figure.

    Returns
    -------
    fig, ax_joint, ax_marg_x, ax_marg_y
        The figure, the joint Axes, the top marginal Axes (shares x with
        the joint Axes) and the right marginal Axes (shares y with it).
    """
    fig = plt.figure(figsize=(height, height))
    grid = plt.GridSpec(ratio + 1, ratio + 1)

    ax_joint = fig.add_subplot(grid[1:, :-1])
    ax_marg_x = fig.add_subplot(grid[0, :-1], sharex=ax_joint)
    ax_marg_y = fig.add_subplot(grid[1:, -1], sharey=ax_joint)

    # Hide, in order: the tick labels on the measure axis of each marginal
    # plot, then the tick lines and tick labels on its density axis.
    hidden_artist_getters = (
        ax_marg_x.get_xticklabels,
        ax_marg_y.get_yticklabels,
        ax_marg_x.yaxis.get_majorticklines,
        ax_marg_x.yaxis.get_minorticklines,
        ax_marg_y.xaxis.get_majorticklines,
        ax_marg_y.xaxis.get_minorticklines,
        ax_marg_x.get_yticklabels,
        ax_marg_y.get_xticklabels,
    )
    for fetch_artists in hidden_artist_getters:
        plt.setp(fetch_artists(), visible=False)

    # No grid lines along the density axes.
    ax_marg_x.yaxis.grid(False)
    ax_marg_y.xaxis.grid(False)

    # Make the grid look nice
    from seaborn import utils
    utils.despine(ax=ax_marg_x, left=True)
    utils.despine(ax=ax_marg_y, bottom=True)
    fig.tight_layout(h_pad=0, w_pad=0)

    # Point the measure-axis ticks of both marginal plots outwards.
    for which in ('major', 'minor'):
        ax_marg_y.tick_params(axis='y', which=which, direction='out')
        ax_marg_x.tick_params(axis='x', which=which, direction='out')
    ax_marg_y.margins(x=0.1, y=0.)

    fig.subplots_adjust(hspace=0, wspace=0)

    return fig, ax_joint, ax_marg_x, ax_marg_y
    def plot(self, ax, cax, kws):
        """Draw the scattermap on the provided Axes.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that receives the markers, the ticks and the labels.
        cax
            Forwarded to ``ax.figure.colorbar`` when ``self.cbar`` is set.
        kws : dict
            Extra keyword arguments forwarded to ``ax.scatter``. Its
            ``'rasterized'`` entry also decides whether the colorbar is
            rasterized.
        """
        # Strip every spine from the Axes.
        despine(ax=ax, left=True, bottom=True)

        # One marker per cell, shifted by 0.5 along both axes.
        values = self.plot_data
        centres_y = np.arange(values.shape[0], dtype=int) + 0.5
        centres_x = np.arange(values.shape[1], dtype=int) + 0.5
        grid_x, grid_y = np.meshgrid(centres_x, centres_y)

        points = ax.scatter(grid_x,
                            grid_y,
                            c=values,
                            marker=self.marker,
                            cmap=self.cmap,
                            vmin=self.vmin,
                            vmax=self.vmax,
                            s=self.marker_size,
                            **kws)

        # Axis limits span the whole data matrix.
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

        # Optional colorbar.
        if self.cbar:
            colorbar = ax.figure.colorbar(points, cax, ax, **self.cbar_kws)
            colorbar.outline.set_linewidth(0)
            # A rasterized plot gets a rasterized colorbar too, to avoid
            # white lines on the PDF rendering.
            if kws.get('rasterized', False):
                colorbar.solids.set_rasterized(True)

        def wants_auto(spec):
            # True only for the literal string "auto".
            return isinstance(spec, string_types) and spec == "auto"

        # Resolve the ticks and labels for columns, then for rows.
        if wants_auto(self.xticks):
            xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
        else:
            xticks, xticklabels = self.xticks, self.xticklabels

        if wants_auto(self.yticks):
            yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
        else:
            yticks, yticklabels = self.yticks, self.yticklabels

        ax.set(xticks=xticks, yticks=yticks)
        xtl = ax.set_xticklabels(xticklabels)
        ytl = ax.set_yticklabels(yticklabels, rotation="vertical")

        # Render once, then flip the rotation of any overlapping labels.
        ax.figure.draw(ax.figure.canvas.get_renderer())
        for labels, fallback in ((xtl, "vertical"), (ytl, "horizontal")):
            if axis_ticklabels_overlap(labels):
                plt.setp(labels, rotation=fallback)

        # Axis labels.
        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

        # Annotate the cells with the formatted values.
        if self.annot:
            self._annotate_heatmap(ax, points)

        # Invert the y axis to show the plot in matrix form.
        ax.invert_yaxis()
Ejemplo n.º 17
0
    def plot_colors(self, xind, yind, **kws):
        """Plot the color labels between the dendrograms and the heatmap.

        Parameters
        ----------
        xind, yind
            Passed to ``self.color_list_to_matrix_and_cmap`` together with
            ``self.col_colors`` (``axis=1``) and ``self.row_colors``
            (``axis=0``) respectively.
        **kws
            Keyword arguments forwarded to ``heatmap``. The entries
            ``cmap``, ``center``, ``vmin``, ``vmax``, ``xticklabels`` and
            ``yticklabels`` are dropped first.
        """
        # Work on a copy and drop any custom colormap, centering, limits
        # and tick labels.
        kws = kws.copy()
        for dropped in ('cmap', 'center', 'vmin', 'vmax',
                        'xticklabels', 'yticklabels'):
            kws.pop(dropped, None)

        if self.row_colors is None:
            despine(self.ax_row_colors, left=True, bottom=True)
        else:
            matrix, cmap = self.color_list_to_matrix_and_cmap(
                self.row_colors, yind, axis=0)

            # ``False`` stands in for missing row color labels.
            row_color_labels = self.row_color_labels
            if row_color_labels is None:
                row_color_labels = False

            heatmap(self,
                    matrix,
                    cmap=cmap,
                    cbar=False,
                    ax=self.ax_row_colors,
                    xticklabels=row_color_labels,
                    yticklabels=False,
                    **kws)

            # Rotate the labels when there are any.
            if row_color_labels is not False:
                plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)

        if self.col_colors is None:
            despine(self.ax_col_colors, left=True, bottom=True)
        else:
            matrix, cmap = self.color_list_to_matrix_and_cmap(
                self.col_colors, xind, axis=1)

            # ``False`` stands in for missing column color labels.
            col_color_labels = self.col_color_labels
            if col_color_labels is None:
                col_color_labels = False

            heatmap(self,
                    matrix,
                    cmap=cmap,
                    cbar=False,
                    ax=self.ax_col_colors,
                    xticklabels=False,
                    yticklabels=col_color_labels,
                    **kws)

            # Put the labels on the right side, unrotated.
            if col_color_labels is not False:
                self.ax_col_colors.yaxis.tick_right()
                plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
Ejemplo n.º 18
0
    def plot(self, ax, cax):
        """Draw the heatmap on the provided Axes.

        Every unmasked cell is drawn as a square ``Rectangle`` patch. The
        colormap value is ``(val - vmin) / (vmax - vmin)`` and the side
        length is ``cellsize / cellsize_vmax`` clipped to ``[0.1, 1.0]``.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that receives the patches, annotations, ticks and labels.
        cax
            Forwarded to ``ax.figure.colorbar`` when ``self.cbar`` is set.
        """

        # Remove all the Axes spines
        despine(ax=ax, left=True, bottom=True)

        # Draw the heatmap and annotate
        height, width = self.plot_data.shape
        xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)

        data = self.plot_data.data
        cellsize = self.cellsize

        # A falsy non-array mask is expanded to an all-False array. The
        # builtin ``bool`` is used because the ``np.bool`` alias was
        # removed in NumPy 1.24.
        mask = self.plot_data.mask
        if not isinstance(mask, np.ndarray) and not mask:
            mask = np.zeros(self.plot_data.shape, bool)

        # Placeholder values keep the zip below aligned when annotations
        # are disabled; they are never formatted in that case.
        annot_data = self.annot_data
        if not self.annot:
            annot_data = np.zeros(self.plot_data.shape)

        # Draw rectangles instead of using pcolormesh
        # Might be slower than original heatmap
        for x, y, m, val, s, an_val in zip(xpos.flat, ypos.flat, mask.flat,
                                           data.flat, cellsize.flat,
                                           annot_data.flat):
            if not m:
                vv = (val - self.vmin) / (self.vmax - self.vmin)
                size = np.clip(s / self.cellsize_vmax, 0.1, 1.0)
                color = self.cmap(vv)
                rect = plt.Rectangle([x - size / 2, y - size / 2],
                                     size,
                                     size,
                                     facecolor=color,
                                     **self.rect_kws)
                ax.add_patch(rect)

                if self.annot:
                    annotation = ("{:" + self.fmt + "}").format(an_val)
                    text = ax.text(x, y, annotation, **self.annot_kws)
                    # Outline the text: dark edge for light text, white
                    # edge otherwise.
                    text_luminance = relative_luminance(text.get_color())
                    text_edge_color = ".15" if text_luminance > .408 else "w"
                    text.set_path_effects([
                        mpl.patheffects.withStroke(linewidth=1,
                                                   foreground=text_edge_color)
                    ])

        # Set the axis limits
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

        # Set other attributes
        ax.set(**self.ax_kws)

        if self.cbar:
            # No mesh exists here, so the colorbar is built from a
            # standalone ScalarMappable.
            norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
            scalar_mappable = mpl.cm.ScalarMappable(cmap=self.cmap, norm=norm)
            scalar_mappable.set_array(self.plot_data.data)
            cb = ax.figure.colorbar(scalar_mappable, cax, ax, **self.cbar_kws)
            cb.outline.set_linewidth(0)
            # NOTE: this method takes no ``kws``, so the colorbar is never
            # rasterized here.

        # Add row and column labels
        if isinstance(self.xticks, string_types) and self.xticks == "auto":
            xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
        else:
            xticks, xticklabels = self.xticks, self.xticklabels

        if isinstance(self.yticks, string_types) and self.yticks == "auto":
            yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
        else:
            yticks, yticklabels = self.yticks, self.yticklabels

        ax.set(xticks=xticks, yticks=yticks)
        xtl = ax.set_xticklabels(xticklabels)
        ytl = ax.set_yticklabels(yticklabels, rotation="vertical")

        # Possibly rotate them if they overlap
        ax.figure.draw(ax.figure.canvas.get_renderer())
        if axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        if axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")

        # Add the axis labels
        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

        # Invert the y axis to show the plot in matrix form
        ax.invert_yaxis()
Ejemplo n.º 19
0
Archivo: viz.py Proyecto: vgurev/delfi
def samples_nd(samples, points=[], **kwargs):
    """Plot samples and points

    See `opts` below for available keyword arguments.
    """
    opts = {
        # what to plot on triagonal and diagonal subplots
        'upper': 'hist',  # hist/scatter/None
        'diag': 'hist',  # hist/None
        #'lower': None,     # hist/scatter/None  # TODO: implement

        # title and legend
        'title': None,
        'legend': False,

        # labels
        'labels': [],  # for dimensions
        'labels_points': [],  # for points
        'labels_samples': [],  # for samples

        # colors
        'samples_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],
        'points_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],

        # subset
        'subset': None,

        # axes limits
        'limits': [],

        # ticks
        'ticks': [],
        'tickformatter': mpl.ticker.FormatStrFormatter('%g'),
        'tick_labels': None,

        # options for hist
        'hist_diag': {
            'alpha': 1.,
            'bins': 25,
            'density': False,
            'histtype': 'step'
        },
        'hist_offdiag': {
            #'edgecolor': 'none',
            #'linewidth': 0.0,
            'bins': 25,
        },

        # options for kde
        'kde_diag': {
            'bw_method': 'scott',
            'bins': 100,
            'color': 'black'
        },
        'kde_offdiag': {
            'bw_method': 'scott',
            'bins': 25
        },

        # options for contour
        'contour_offdiag': {
            'levels': [0.68]
        },

        # options for scatter
        'scatter_offdiag': {
            'alpha': 0.5,
            'edgecolor': 'none',
            'rasterized': False,
        },

        # options for plot
        'plot_offdiag': {},

        # formatting points (scale, markers)
        'points_diag': {},
        'points_offdiag': {
            'marker': '.',
            'markersize': 20,
        },

        # matplotlib style
        'style': os.path.join(os.path.dirname(__file__), 'matplotlibrc'),

        # other options
        'fig_size': (10, 10),
        'fig_bg_colors': {
            'upper': None,
            'diag': None,
            'lower': None
        },
        'fig_subplots_adjust': {
            'top': 0.9,
        },
        'subplots': {},
        'despine': {
            'offset': 5,
        },
        'title_format': {
            'fontsize': 16
        },
    }
    # TODO: add color map support
    # TODO: automatically determine good bin sizes for histograms
    # TODO: get rid of seaborn dependency for despine
    # TODO: add legend (if legend is True)

    samples_nd.defaults = opts.copy()
    opts = _update(opts, kwargs)

    # Prepare samples
    if type(samples) != list:
        samples = [samples]

    # Prepare points
    if type(points) != list:
        points = [points]
    points = [np.atleast_2d(p) for p in points]

    # Dimensions
    dim = samples[0].shape[1]
    num_samples = samples[0].shape[0]

    # TODO: add asserts checking compatiblity of dimensions

    # Prepare labels
    if opts['labels'] == [] or opts['labels'] is None:
        labels_dim = ['dim {}'.format(i + 1) for i in range(dim)]
    else:
        labels_dim = opts['labels']

    # Prepare limits
    if opts['limits'] == [] or opts['limits'] is None:
        limits = []
        for d in range(dim):
            min = +np.inf
            max = -np.inf
            for sample in samples:
                min_ = sample[:, d].min()
                min = min_ if min_ < min else min
                max_ = sample[:, d].max()
                max = max_ if max_ > max else max
            limits.append([min, max])
    else:
        if len(opts['limits']) == 1:
            limits = [opts['limits'][0] for _ in range(dim)]
        else:
            limits = opts['limits']

    # Prepare ticks
    if opts['ticks'] == [] or opts['ticks'] is None:
        ticks = None
    else:
        if len(opts['ticks']) == 1:
            ticks = [opts['ticks'][0] for _ in range(dim)]
        else:
            ticks = opts['ticks']

    # Prepare diag/upper/lower
    if type(opts['diag']) is not list:
        opts['diag'] = [opts['diag'] for _ in range(len(samples))]
    if type(opts['upper']) is not list:
        opts['upper'] = [opts['upper'] for _ in range(len(samples))]
    #if type(opts['lower']) is not list:
    #    opts['lower'] = [opts['lower'] for _ in range(len(samples))]
    opts['lower'] = None

    # Style
    if opts['style'] in ['dark', 'light']:
        style = os.path.join(os.path.dirname(__file__),
                             'matplotlib_{}.style'.format(opts['style']))
    else:
        style = opts['style']

    # Apply custom style as context
    with mpl.rc_context(fname=style):

        # Figure out if we subset the plot
        subset = opts['subset']
        if subset is None:
            rows = cols = dim
            subset = [i for i in range(dim)]
        else:
            if type(subset) == int:
                subset = [subset]
            elif type(subset) == list:
                pass
            else:
                raise NotImplementedError
            rows = cols = len(subset)

        fig, axes = plt.subplots(rows,
                                 cols,
                                 figsize=opts['fig_size'],
                                 **opts['subplots'])
        axes = axes.reshape(rows, cols)

        # Style figure
        fig.subplots_adjust(**opts['fig_subplots_adjust'])
        fig.suptitle(opts['title'], **opts['title_format'])

        # Style axes
        row_idx = -1
        for row in range(dim):
            if row not in subset:
                continue
            else:
                row_idx += 1

            col_idx = -1
            for col in range(dim):
                if col not in subset:
                    continue
                else:
                    col_idx += 1

                if row == col:
                    current = 'diag'
                elif row < col:
                    current = 'upper'
                else:
                    current = 'lower'

                ax = axes[row_idx, col_idx]
                plt.sca(ax)

                # Background color
                if current in opts['fig_bg_colors'] and \
                    opts['fig_bg_colors'][current] is not None:
                    ax.set_facecolor(opts['fig_bg_colors'][current])

                # Axes
                if opts[current] is None:
                    ax.axis('off')
                    continue

                # Limits
                if limits is not None:
                    ax.set_xlim((limits[col][0], limits[col][1]))
                    if current is not 'diag':
                        ax.set_ylim((limits[row][0], limits[row][1]))
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()

                # Ticks
                if ticks is not None:
                    ax.set_xticks((ticks[col][0], ticks[col][1]))
                    if current is not 'diag':
                        ax.set_yticks((ticks[row][0], ticks[row][1]))

                # Despine
                despine(ax=ax, **opts['despine'])

                # Formatting axes
                if current == 'diag':  # off-diagnoals
                    if opts['lower'] is None or col == dim - 1:
                        _format_axis(ax,
                                     xhide=False,
                                     xlabel=labels_dim[col],
                                     yhide=True,
                                     tickformatter=opts['tickformatter'])
                    else:
                        _format_axis(ax, xhide=True, yhide=True)
                else:  # off-diagnoals
                    if row == dim - 1:
                        _format_axis(ax,
                                     xhide=False,
                                     xlabel=labels_dim[col],
                                     yhide=True,
                                     tickformatter=opts['tickformatter'])
                    else:
                        _format_axis(ax, xhide=True, yhide=True)
                if opts['tick_labels'] is not None:
                    ax.set_xticklabels((str(opts['tick_labels'][col][0]),
                                        str(opts['tick_labels'][col][1])))

                # Diagonals
                if current == 'diag':
                    if len(samples) > 0:
                        for n, v in enumerate(samples):
                            if opts['diag'][n] == 'hist':
                                h = plt.hist(v[:, row],
                                             color=opts['samples_colors'][n],
                                             **opts['hist_diag'])
                            elif opts['diag'][n] == 'kde':
                                density = gaussian_kde(
                                    v[:, row],
                                    bw_method=opts['kde_diag']['bw_method'])
                                xs = np.linspace(xmin, xmax,
                                                 opts['kde_diag']['bins'])
                                ys = density(xs)
                                h = plt.plot(
                                    xs,
                                    ys,
                                    color=opts['samples_colors'][n],
                                )
                            else:
                                pass

                    if len(points) > 0:
                        extent = ax.get_ylim()
                        for n, v in enumerate(points):
                            h = plt.plot([v[:, row], v[:, row]],
                                         extent,
                                         color=opts['points_colors'][n],
                                         **opts['points_diag'])

                # Off-diagonals
                else:

                    if len(samples) > 0:
                        for n, v in enumerate(samples):
                            if opts['upper'][n] == 'hist' or opts['upper'][
                                    n] == 'hist2d':
                                hist, xedges, yedges = np.histogram2d(
                                    v[:, col],
                                    v[:, row],
                                    range=[[limits[col][0], limits[col][1]],
                                           [limits[row][0], limits[row][1]]],
                                    **opts['hist_offdiag'])
                                h = plt.imshow(hist.T,
                                               origin='lower',
                                               extent=[
                                                   xedges[0], xedges[-1],
                                                   yedges[0], yedges[-1]
                                               ],
                                               aspect='auto')

                            elif opts['upper'][n] in [
                                    'kde', 'kde2d', 'contour', 'contourf'
                            ]:
                                density = gaussian_kde(
                                    v[:, [col, row]].T,
                                    bw_method=opts['kde_offdiag']['bw_method'])
                                X, Y = np.meshgrid(
                                    np.linspace(limits[col][0], limits[col][1],
                                                opts['kde_offdiag']['bins']),
                                    np.linspace(limits[row][0], limits[row][1],
                                                opts['kde_offdiag']['bins']))
                                positions = np.vstack([X.ravel(), Y.ravel()])
                                Z = np.reshape(density(positions).T, X.shape)

                                if opts['upper'][n] == 'kde' or opts['upper'][
                                        n] == 'kde2d':
                                    h = plt.imshow(
                                        Z,
                                        extent=[
                                            limits[col][0], limits[col][1],
                                            limits[row][0], limits[row][1]
                                        ],
                                        origin='lower',
                                        aspect='auto',
                                    )
                                elif opts['upper'][n] == 'contour':
                                    Z = (Z - Z.min()) / (Z.max() - Z.min())
                                    h = plt.contour(
                                        X,
                                        Y,
                                        Z,
                                        origin='lower',
                                        extent=[
                                            limits[col][0], limits[col][1],
                                            limits[row][0], limits[row][1]
                                        ],
                                        colors=opts['samples_colors'][n],
                                        **opts['contour_offdiag'])
                                else:
                                    pass
                            elif opts['upper'][n] == 'scatter':
                                h = plt.scatter(
                                    v[:, col],
                                    v[:, row],
                                    color=opts['samples_colors'][n],
                                    **opts['scatter_offdiag'])
                            elif opts['upper'][n] == 'plot':
                                h = plt.plot(v[:, col],
                                             v[:, row],
                                             color=opts['samples_colors'][n],
                                             **opts['plot_offdiag'])
                            else:
                                pass

                    if len(points) > 0:

                        for n, v in enumerate(points):
                            h = plt.plot(v[:, col],
                                         v[:, row],
                                         color=opts['points_colors'][n],
                                         **opts['points_offdiag'])

        if len(subset) < dim:
            for row in range(len(subset)):
                ax = axes[row, len(subset) - 1]
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                text_kwargs = {'fontsize': plt.rcParams['font.size'] * 2.}
                ax.text(x1 + (x1 - x0) / 8., (y0 + y1) / 2., '...',
                        **text_kwargs)
                if row == len(subset) - 1:
                    ax.text(x1 + (x1 - x0) / 12.,
                            y0 - (y1 - y0) / 1.5,
                            '...',
                            rotation=-45,
                            **text_kwargs)

    return fig, axes