def test_despine_trim_noticks(self):
    """Trimming spines must not bring back y ticks that were explicitly cleared."""
    fig, axis = plt.subplots()
    axis.plot([1, 2, 3], [1, 2, 3])
    axis.set_yticks([])

    utils.despine(trim=True)

    remaining_ticks = axis.get_yticks()
    assert remaining_ticks.size == 0
def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
                     row_linkage, col_linkage):
    """Draw the row and column dendrograms, or blank their axes.

    Parameters
    ----------
    row_cluster, col_cluster : bool
        Whether to draw the row (axis 0) / column (axis 1) dendrogram.
        When false, the corresponding dendrogram axes get empty ticks.
    metric, method :
        Forwarded unchanged to ``dendrogram``.
    row_linkage, col_linkage :
        Precomputed linkages forwarded to ``dendrogram`` as ``linkage``.

    Side effects: sets ``self.dendrogram_row`` / ``self.dendrogram_col``
    when the respective dendrogram is drawn, and despines both
    dendrogram axes (bottom and left).
    """
    shared_kws = dict(metric=metric, method=method, label=False)

    # Row dendrogram: drawn rotated next to the heatmap rows.
    if not row_cluster:
        self.ax_row_dendrogram.set_xticks([])
        self.ax_row_dendrogram.set_yticks([])
    else:
        self.dendrogram_row = dendrogram(
            self.data2d, axis=0, ax=self.ax_row_dendrogram, rotate=True,
            linkage=row_linkage, **shared_kws)

    # Column dendrogram: drawn above the heatmap columns.
    if not col_cluster:
        self.ax_col_dendrogram.set_xticks([])
        self.ax_col_dendrogram.set_yticks([])
    else:
        self.dendrogram_col = dendrogram(
            self.data2d, axis=1, ax=self.ax_col_dendrogram,
            linkage=col_linkage, **shared_kws)

    for dendro_ax in (self.ax_row_dendrogram, self.ax_col_dendrogram):
        despine(ax=dendro_ax, bottom=True, left=True)
def pairplot(data, target_col, columns=None, scatter_alpha='auto',
             scatter_size='auto'):
    """Draw a bare-bones scatter matrix for classification data.

    Yet another pairplot ("scattermatrix"), kept deliberately minimal.
    Off-diagonal cells are class-colored scatter plots; diagonal cells show
    per-class histograms on a twin y axis.

    Parameters
    ----------
    data : pandas DataFrame
        Input data.
    target_col : column specifier
        Column of ``data`` holding the class labels.
    columns : sequence of column specifiers, default=None
        Feature columns to include. ``None`` means every column except
        ``target_col``.
    scatter_alpha : float, default='auto'
        Alpha forwarded to ``discrete_scatter``; 'auto' is a heuristic.
    scatter_size : float, default='auto'
        Marker size forwarded to ``discrete_scatter``; 'auto' is a heuristic.

    Returns
    -------
    axes : 2D array of matplotlib Axes
        The ``n_features x n_features`` grid of axes.
    """
    if columns is None:
        columns = data.columns.drop(target_col)
    n_features = len(columns)
    last = n_features - 1
    side = n_features * 3

    fig, axes = plt.subplots(n_features, n_features, figsize=(side, side))
    axes = np.atleast_2d(axes)

    cells = itertools.product(range(n_features), repeat=2)
    for ax, (i, j) in zip(axes.ravel(), cells):
        if i == j:
            class_hists(data, columns[i], target_col, ax=ax.twinx())
        else:
            # Only the top-right scatter carries the legend.
            discrete_scatter(data[columns[j]], data[columns[i]],
                             c=data[target_col],
                             legend=(i == 0 and j == last), ax=ax,
                             alpha=scatter_alpha, s=scatter_size)

        # Only the left-most column keeps y labels and tick labels.
        if j == 0:
            ax.set_ylabel(columns[i])
        else:
            ax.set_ylabel("")
            ax.set_yticklabels(())

        # Only the bottom row keeps x labels and tick labels.
        if i == last:
            ax.set_xlabel(_shortname(columns[j]))
        else:
            ax.set_xlabel("")
            ax.set_xticklabels(())

    despine(fig)

    # The top-left histogram lives on a twin axis, so copy the y ticks and
    # limits of its right-hand neighbour onto the (empty) host axis.
    if n_features > 1:
        corner, neighbour = axes[0, 0], axes[0, 1]
        corner.set_yticks(neighbour.get_yticks())
        corner.set_ylim(neighbour.get_ylim())
    return axes
def plot(self, ax):
    """Draw the dendrogram line segments and tick labels on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes object upon which the dendrogram is plotted.

    Returns
    -------
    self
    """
    # A rotated row dendrogram grows horizontally, so swap the roles of
    # the two coordinate arrays when building segments.
    if self.rotate and self.axis == 0:
        first, second = self.dependent_coord, self.independent_coord
    else:
        first, second = self.independent_coord, self.dependent_coord
    segments = [list(zip(xs, ys)) for xs, ys in zip(first, second)]
    ax.add_collection(LineCollection(segments, linewidths=.5, colors='k'))

    # Constants 10 and 1.05 come from
    # `scipy.cluster.hierarchy._plot_dendrogram`
    leaf_extent = len(self.reordered_ind) * 10
    height_extent = max(map(max, self.dependent_coord)) * 1.05

    if not self.rotate:
        ax.set_xlim(0, leaf_extent)
        ax.set_ylim(0, height_extent)
    else:
        ax.yaxis.set_ticks_position('right')
        ax.set_ylim(0, leaf_extent)
        ax.set_xlim(0, height_extent)
        ax.invert_xaxis()
        ax.invert_yaxis()

    despine(ax=ax, bottom=True, left=True)

    ax.set(xticks=self.xticks, yticks=self.yticks,
           xlabel=self.xlabel, ylabel=self.ylabel)
    xtl = ax.set_xticklabels(self.xticklabels)
    ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')

    # Force a draw of the plot to avoid matplotlib window error
    plt.draw()

    # Flip label orientation when labels collide (y first, then x).
    for labels, rotation in ((ytl, "horizontal"), (xtl, "vertical")):
        if len(labels) > 0 and axis_ticklabels_overlap(labels):
            plt.setp(labels, rotation=rotation)
    return self
def test_despine_trim_spines(self):
    """With trim=True the inner spines span only the data tick range."""
    fig, axis = plt.subplots()
    axis.plot([1, 2, 3], [1, 2, 3])
    axis.set_xlim(.75, 3.25)

    utils.despine(trim=True)

    expected = (1, 3)
    for spine_name in self.inner_sides:
        assert axis.spines[spine_name].get_bounds() == expected
def test_despine_trim_inverted(self):
    """Trimmed spine bounds stay (1, 3) even with an inverted y axis."""
    fig, axis = plt.subplots()
    axis.plot([1, 2, 3], [1, 2, 3])
    axis.set_ylim(.85, 3.15)
    axis.invert_yaxis()

    utils.despine(trim=True)

    expected = (1, 3)
    for spine_name in self.inner_sides:
        assert axis.spines[spine_name].get_bounds() == expected
def test_despine_trim_categorical(self):
    """Trimming works when the x axis is categorical (string) data."""
    fig, axis = plt.subplots()
    axis.plot(["a", "b", "c"], [1, 2, 3])

    utils.despine(trim=True)

    expected_bounds = {"left": (1, 3), "bottom": (0, 2)}
    for spine_name, expected in expected_bounds.items():
        assert axis.spines[spine_name].get_bounds() == expected
def test_despine_specific_axes(self):
    """despine(ax=...) only touches the given Axes.

    The outer spines of ``ax2`` must be hidden, and every spine of ``ax1``
    plus the inner spines of ``ax2`` must stay visible.
    """
    f, (ax1, ax2) = plt.subplots(2, 1)

    utils.despine(ax=ax2)

    for side in self.sides:
        assert ax1.spines[side].get_visible()

    # ``~bool`` is -1 or -2 (always truthy); ``not`` is the real negation.
    for side in self.outer_sides:
        assert not ax2.spines[side].get_visible()
    for side in self.inner_sides:
        assert ax2.spines[side].get_visible()
def test_despine_side_specific_offset(self):
    """A dict offset moves only the named (visible) spine."""
    fig, axis = plt.subplots()
    utils.despine(ax=axis, offset=dict(left=self.offset))

    for spine_name in self.sides:
        spine = axis.spines[spine_name]
        shifted = spine.get_visible() and spine_name == "left"
        expected = self.offset_position if shifted else self.original_position
        assert spine.get_position() == expected
def test_despine_with_offset_specific_axes(self):
    """Offsetting spines on one Axes leaves the other Axes untouched."""
    fig, (upper, lower) = plt.subplots(2, 1)
    utils.despine(offset=self.offset, ax=lower)

    for spine_name in self.sides:
        upper_pos = upper.spines[spine_name].get_position()
        lower_pos = lower.spines[spine_name].get_position()
        assert upper_pos == self.original_position

        expected = (self.offset_position
                    if lower.spines[spine_name].get_visible()
                    else self.original_position)
        assert lower_pos == expected
def test_despine(self):
    """Default despine hides outer spines; all-True hides every spine."""
    f, ax = plt.subplots()
    for side in self.sides:
        assert ax.spines[side].get_visible()

    utils.despine()
    # ``~bool`` is -1 or -2 (always truthy); ``not`` is the real negation.
    for side in self.outer_sides:
        assert not ax.spines[side].get_visible()
    for side in self.inner_sides:
        assert ax.spines[side].get_visible()

    utils.despine(**dict(zip(self.sides, [True] * 4)))
    for side in self.sides:
        assert not ax.spines[side].get_visible()
def test_despine_moved_ticks(self):
    """Removing a spine copies tick1 visibility onto the opposite tick2.

    Checks both axes, each with tick1 lines initially visible and hidden.
    """
    y_kws = dict(left=True, right=False)
    x_kws = dict(bottom=True, top=False)
    cases = [
        ("yaxis", True, y_kws),
        ("yaxis", False, y_kws),
        ("xaxis", True, x_kws),
        ("xaxis", False, x_kws),
    ]
    for axis_name, tick1_visible, despine_kws in cases:
        fig, ax = plt.subplots()
        for tick in getattr(ax, axis_name).majorTicks:
            tick.tick1line.set_visible(tick1_visible)

        utils.despine(ax=ax, **despine_kws)

        for tick in getattr(ax, axis_name).majorTicks:
            assert tick.tick2line.get_visible() == tick1_visible
        plt.close(fig)
def test_despine_with_offset(self):
    """A scalar offset moves every spine that remains visible."""
    fig, axis = plt.subplots()
    for spine_name in self.sides:
        assert axis.spines[spine_name].get_position() == self.original_position

    utils.despine(ax=axis, offset=self.offset)

    for spine_name in self.sides:
        spine = axis.spines[spine_name]
        expected = (self.offset_position if spine.get_visible()
                    else self.original_position)
        assert spine.get_position() == expected
def plot(self, ax, cax, kws):
    """Draw the heatmap on the provided Axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the heatmap mesh is drawn.
    cax : matplotlib Axes or None
        Forwarded as ``cax`` to ``Figure.colorbar`` when ``self.cbar`` is set.
    kws : dict
        Extra keyword arguments forwarded to ``ax.pcolormesh``; a truthy
        ``rasterized`` entry also rasterizes the colorbar solids.

    Side effects: stores the mesh on ``self.mesh``.
    """
    # Remove all the Axes spines
    despine(ax=ax, left=True, bottom=True)

    # Draw the heatmap
    self.mesh = ax.pcolormesh(self.plot_data, vmin=self.vmin,
                              vmax=self.vmax, cmap=self.cmap, **kws)

    # Set the axis limits
    ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

    # Add row and column labels
    ax.set(xticks=self.xticks, yticks=self.yticks)
    xtl = ax.set_xticklabels(self.xticklabels)
    ytl = ax.set_yticklabels(self.yticklabels, rotation="vertical")

    # Possibly rotate them if they overlap (needs a draw to measure them)
    plt.draw()
    if axis_ticklabels_overlap(xtl):
        plt.setp(xtl, rotation="vertical")
    if axis_ticklabels_overlap(ytl):
        plt.setp(ytl, rotation="horizontal")

    # Add the axis labels
    ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

    # Annotate the cells with the formatted values. The mesh lives on the
    # instance; there is no local ``mesh`` name in this scope.
    if self.annot:
        self._annotate_heatmap(ax, self.mesh)

    # Possibly add a colorbar
    if self.cbar:
        cb = ax.figure.colorbar(self.mesh, cax, ax, **self.cbar_kws)
        cb.outline.set_linewidth(0)
        # If rasterized is passed to pcolormesh, also rasterize the
        # colorbar to avoid white lines on the PDF rendering
        if kws.get('rasterized', False):
            cb.solids.set_rasterized(True)
def joint_plot(ratio=1, height=3):
    """Build a joint Axes with marginal Axes above it and to its right.

    Adapted from seaborn's ``JointGrid`` layout.

    Parameters
    ----------
    ratio : int, default=1
        The grid is ``(ratio + 1) x (ratio + 1)``; the joint Axes takes all
        but the first row and the last column.
    height : float, default=3
        Width and height of the square figure.

    Returns
    -------
    fig, ax_joint, ax_marg_x, ax_marg_y
    """
    fig = plt.figure(figsize=(height, height))
    grid = plt.GridSpec(ratio + 1, ratio + 1)

    ax_joint = fig.add_subplot(grid[1:, :-1])
    ax_marg_x = fig.add_subplot(grid[0, :-1], sharex=ax_joint)
    ax_marg_y = fig.add_subplot(grid[1:, -1], sharey=ax_joint)

    # Hide the tick labels of the axis each marginal shares with the joint.
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Hide tick lines and labels on each marginal's density axis.
    for density_axis in (ax_marg_x.yaxis, ax_marg_y.xaxis):
        plt.setp(density_axis.get_majorticklines(), visible=False)
        plt.setp(density_axis.get_minorticklines(), visible=False)
    plt.setp(ax_marg_x.get_yticklabels(), visible=False)
    plt.setp(ax_marg_y.get_xticklabels(), visible=False)
    ax_marg_x.yaxis.grid(False)
    ax_marg_y.xaxis.grid(False)

    # Make the grid look nice
    from seaborn import utils
    utils.despine(ax=ax_marg_x, left=True)
    utils.despine(ax=ax_marg_y, bottom=True)
    fig.tight_layout(h_pad=0, w_pad=0)

    for which in ('major', 'minor'):
        ax_marg_y.tick_params(axis='y', which=which, direction='out')
        ax_marg_x.tick_params(axis='x', which=which, direction='out')
    ax_marg_y.margins(x=0.1, y=0.)

    fig.subplots_adjust(hspace=0, wspace=0)
    return fig, ax_joint, ax_marg_x, ax_marg_y
def plot(self, ax, cax, kws):
    """Draw the scattermap on the provided Axes.

    One marker is placed at the centre of every cell of ``self.plot_data``,
    colored by the cell value.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the scatter.
    cax : matplotlib Axes or None
        Forwarded as ``cax`` to ``Figure.colorbar`` when ``self.cbar`` is set.
    kws : dict
        Extra keyword arguments forwarded to ``ax.scatter``.
    """
    # Remove all the Axes spines
    despine(ax=ax, left=True, bottom=True)

    # Marker positions: cell centres at half-integer coordinates.
    values = self.plot_data
    row_centres = np.arange(values.shape[0], dtype=int) + 0.5
    col_centres = np.arange(values.shape[1], dtype=int) + 0.5
    grid_x, grid_y = np.meshgrid(col_centres, row_centres)
    collection = ax.scatter(grid_x, grid_y, c=values, marker=self.marker,
                            cmap=self.cmap, vmin=self.vmin, vmax=self.vmax,
                            s=self.marker_size, **kws)

    ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

    if self.cbar:
        cb = ax.figure.colorbar(collection, cax, ax, **self.cbar_kws)
        cb.outline.set_linewidth(0)
        # If rasterized is passed, also rasterize the colorbar to avoid
        # white lines on the PDF rendering
        if kws.get('rasterized', False):
            cb.solids.set_rasterized(True)

    def resolve_ticks(ticks, labels, axis):
        # "auto" delegates tick placement to ``_auto_ticks``.
        if isinstance(ticks, string_types) and ticks == "auto":
            return self._auto_ticks(ax, labels, axis)
        return ticks, labels

    xticks, xticklabels = resolve_ticks(self.xticks, self.xticklabels, 0)
    yticks, yticklabels = resolve_ticks(self.yticks, self.yticklabels, 1)
    ax.set(xticks=xticks, yticks=yticks)
    xtl = ax.set_xticklabels(xticklabels)
    ytl = ax.set_yticklabels(yticklabels, rotation="vertical")

    # Render once so label extents are known, then rotate on overlap.
    fig = ax.figure
    fig.draw(fig.canvas.get_renderer())
    if axis_ticklabels_overlap(xtl):
        plt.setp(xtl, rotation="vertical")
    if axis_ticklabels_overlap(ytl):
        plt.setp(ytl, rotation="horizontal")

    ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

    if self.annot:
        self._annotate_heatmap(ax, collection)

    # Invert the y axis to show the plot in matrix form
    ax.invert_yaxis()
def plot_colors(self, xind, yind, **kws):
    """Plot the row/column color strips between the dendrograms and heatmap.

    Parameters
    ----------
    xind, yind :
        Reordering indices for columns / rows, forwarded to
        ``color_list_to_matrix_and_cmap``.
    **kws :
        Heatmap keyword arguments. Color-scaling and tick-label options are
        stripped before they are forwarded to ``heatmap``.
    """
    # Remove any custom colormap, centering and tick-label options.
    kws = kws.copy()
    for key in ('cmap', 'center', 'vmin', 'vmax',
                'xticklabels', 'yticklabels'):
        kws.pop(key, None)

    if self.row_colors is None:
        despine(self.ax_row_colors, left=True, bottom=True)
    else:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.row_colors, yind, axis=0)
        row_color_labels = (self.row_color_labels
                            if self.row_color_labels is not None else False)
        # NOTE(review): ``self`` is passed as the first positional argument
        # to ``heatmap``; seaborn's module-level ``heatmap`` takes ``data``
        # first. Confirm which ``heatmap`` is in scope here.
        heatmap(self, matrix, cmap=cmap, cbar=False, ax=self.ax_row_colors,
                xticklabels=row_color_labels, yticklabels=False, **kws)
        # Row-color labels sit under the strip, so stand them upright.
        if row_color_labels is not False:
            plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)

    if self.col_colors is None:
        despine(self.ax_col_colors, left=True, bottom=True)
    else:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.col_colors, xind, axis=1)
        col_color_labels = (self.col_color_labels
                            if self.col_color_labels is not None else False)
        heatmap(self, matrix, cmap=cmap, cbar=False, ax=self.ax_col_colors,
                xticklabels=False, yticklabels=col_color_labels, **kws)
        # Column-color labels go on the right side, unrotated.
        if col_color_labels is not False:
            self.ax_col_colors.yaxis.tick_right()
            plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
def plot(self, ax, cax):
    """Draw the heatmap on the provided Axes, one sized square per cell.

    Each unmasked cell becomes a ``plt.Rectangle`` centred on the cell. Its
    side length is ``self.cellsize / self.cellsize_vmax`` clipped to
    [0.1, 1.0], and its color is ``self.cmap`` of the value normalized
    by ``self.vmin``/``self.vmax``. When ``self.annot`` is set, each drawn
    cell also gets a text annotation with an outline for contrast.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the rectangles.
    cax : matplotlib Axes or None
        Forwarded as ``cax`` to ``Figure.colorbar`` when ``self.cbar`` is set.
    """
    # Remove all the Axes spines
    despine(ax=ax, left=True, bottom=True)

    # Cell centres at half-integer coordinates.
    height, width = self.plot_data.shape
    xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)

    data = self.plot_data.data
    cellsize = self.cellsize

    mask = self.plot_data.mask
    if not isinstance(mask, np.ndarray) and not mask:
        # Scalar "no mask": expand to an all-False array. ``np.bool`` was
        # removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
        mask = np.zeros(self.plot_data.shape, bool)

    annot_data = self.annot_data
    if not self.annot:
        # Placeholder so the zip below always has an annotation column.
        annot_data = np.zeros(self.plot_data.shape)

    # Draw rectangles instead of using pcolormesh
    # Might be slower than original heatmap
    for x, y, m, val, s, an_val in zip(xpos.flat, ypos.flat, mask.flat,
                                        data.flat, cellsize.flat,
                                        annot_data.flat):
        if m:
            continue
        vv = (val - self.vmin) / (self.vmax - self.vmin)
        size = np.clip(s / self.cellsize_vmax, 0.1, 1.0)
        color = self.cmap(vv)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, **self.rect_kws)
        ax.add_patch(rect)

        if self.annot:
            annotation = ("{:" + self.fmt + "}").format(an_val)
            text = ax.text(x, y, annotation, **self.annot_kws)
            # Add an edge to the text so it stays readable on any fill
            text_luminance = relative_luminance(text.get_color())
            text_edge_color = ".15" if text_luminance > .408 else "w"
            text.set_path_effects([
                mpl.patheffects.withStroke(linewidth=1,
                                           foreground=text_edge_color)
            ])

    # Set the axis limits
    ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))

    # Set other attributes
    ax.set(**self.ax_kws)

    if self.cbar:
        # Rectangles are not a mappable, so build one for the colorbar.
        norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
        scalar_mappable = mpl.cm.ScalarMappable(cmap=self.cmap, norm=norm)
        scalar_mappable.set_array(self.plot_data.data)
        cb = ax.figure.colorbar(scalar_mappable, cax, ax, **self.cbar_kws)
        cb.outline.set_linewidth(0)

    # Add row and column labels
    if isinstance(self.xticks, string_types) and self.xticks == "auto":
        xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
    else:
        xticks, xticklabels = self.xticks, self.xticklabels

    if isinstance(self.yticks, string_types) and self.yticks == "auto":
        yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
    else:
        yticks, yticklabels = self.yticks, self.yticklabels

    ax.set(xticks=xticks, yticks=yticks)
    xtl = ax.set_xticklabels(xticklabels)
    ytl = ax.set_yticklabels(yticklabels, rotation="vertical")

    # Possibly rotate them if they overlap
    ax.figure.draw(ax.figure.canvas.get_renderer())
    if axis_ticklabels_overlap(xtl):
        plt.setp(xtl, rotation="vertical")
    if axis_ticklabels_overlap(ytl):
        plt.setp(ytl, rotation="horizontal")

    # Add the axis labels
    ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

    # Invert the y axis to show the plot in matrix form
    ax.invert_yaxis()
def samples_nd(samples, points=None, **kwargs):
    """Plot samples and points as a corner/pair plot.

    See ``opts`` below for the available keyword arguments. Keyword
    arguments are merged into those defaults via ``_update`` (defined
    elsewhere in this module).

    Parameters
    ----------
    samples : array or list of arrays
        Sample sets indexed as ``sample[:, d]``; the number of dimensions is
        taken from ``samples[0].shape[1]``. A single array is wrapped in a list.
    points : array or list of arrays, optional
        Reference points, each promoted with ``np.atleast_2d``. ``None``
        (the default) means no points.

    Returns
    -------
    fig : matplotlib Figure
    axes : 2D array of Axes, shape ``(rows, cols)``
        Always 2D, even for a single panel.

    Side effects: sets ``samples_nd.defaults`` to a shallow copy of the
    default options.
    """
    if points is None:
        points = []

    opts = {
        # what to plot on triagonal and diagonal subplots
        'upper': 'hist',  # hist/scatter/None
        'diag': 'hist',  # hist/None
        #'lower': None,     # hist/scatter/None  # TODO: implement

        # title and legend
        'title': None,
        'legend': False,

        # labels
        'labels': [],  # for dimensions
        'labels_points': [],  # for points
        'labels_samples': [],  # for samples

        # colors
        'samples_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],
        'points_colors': plt.rcParams['axes.prop_cycle'].by_key()['color'],

        # subset
        'subset': None,

        # axes limits
        'limits': [],

        # ticks
        'ticks': [],
        'tickformatter': mpl.ticker.FormatStrFormatter('%g'),
        'tick_labels': None,

        # options for hist
        'hist_diag': {
            'alpha': 1.,
            'bins': 25,
            'density': False,
            'histtype': 'step'
        },
        'hist_offdiag': {
            #'edgecolor': 'none',
            #'linewidth': 0.0,
            'bins': 25,
        },

        # options for kde
        'kde_diag': {
            'bw_method': 'scott',
            'bins': 100,
            'color': 'black'
        },
        'kde_offdiag': {
            'bw_method': 'scott',
            'bins': 25
        },

        # options for contour
        'contour_offdiag': {
            'levels': [0.68]
        },

        # options for scatter
        'scatter_offdiag': {
            'alpha': 0.5,
            'edgecolor': 'none',
            'rasterized': False,
        },

        # options for plot
        'plot_offdiag': {},

        # formatting points (scale, markers)
        'points_diag': {},
        'points_offdiag': {
            'marker': '.',
            'markersize': 20,
        },

        # matplotlib style
        'style': os.path.join(os.path.dirname(__file__), 'matplotlibrc'),

        # other options
        'fig_size': (10, 10),
        'fig_bg_colors': {
            'upper': None,
            'diag': None,
            'lower': None
        },
        'fig_subplots_adjust': {
            'top': 0.9,
        },
        'subplots': {},
        'despine': {
            'offset': 5,
        },
        'title_format': {
            'fontsize': 16
        },
    }
    # TODO: add color map support
    # TODO: automatically determine good bin sizes for histograms
    # TODO: get rid of seaborn dependency for despine
    # TODO: add legend (if legend is True)

    samples_nd.defaults = opts.copy()
    opts = _update(opts, kwargs)

    # Prepare samples
    if not isinstance(samples, list):
        samples = [samples]

    # Prepare points
    if not isinstance(points, list):
        points = [points]
    points = [np.atleast_2d(p) for p in points]

    # Dimensions
    dim = samples[0].shape[1]

    # TODO: add asserts checking compatiblity of dimensions

    # Prepare labels
    if opts['labels'] == [] or opts['labels'] is None:
        labels_dim = ['dim {}'.format(i + 1) for i in range(dim)]
    else:
        labels_dim = opts['labels']

    # Prepare limits: default to the min/max over all sample sets per dim.
    # (``lo``/``hi`` instead of ``min``/``max`` so the builtins stay usable.)
    if opts['limits'] == [] or opts['limits'] is None:
        limits = []
        for d in range(dim):
            lo = +np.inf
            hi = -np.inf
            for sample in samples:
                lo = min(lo, sample[:, d].min())
                hi = max(hi, sample[:, d].max())
            limits.append([lo, hi])
    else:
        if len(opts['limits']) == 1:
            limits = [opts['limits'][0] for _ in range(dim)]
        else:
            limits = opts['limits']

    # Prepare ticks
    if opts['ticks'] == [] or opts['ticks'] is None:
        ticks = None
    else:
        if len(opts['ticks']) == 1:
            ticks = [opts['ticks'][0] for _ in range(dim)]
        else:
            ticks = opts['ticks']

    # Prepare diag/upper/lower: one plot kind per sample set
    if not isinstance(opts['diag'], list):
        opts['diag'] = [opts['diag'] for _ in range(len(samples))]
    if not isinstance(opts['upper'], list):
        opts['upper'] = [opts['upper'] for _ in range(len(samples))]
    # The lower triangle is not implemented yet, so it is always disabled.
    opts['lower'] = None

    # Style
    if opts['style'] in ['dark', 'light']:
        style = os.path.join(
            os.path.dirname(__file__),
            'matplotlib_{}.style'.format(opts['style']))
    else:
        style = opts['style']

    # Apply custom style as context
    with mpl.rc_context(fname=style):

        # Figure out if we subset the plot
        subset = opts['subset']
        if subset is None:
            rows = cols = dim
            subset = [i for i in range(dim)]
        else:
            if isinstance(subset, int):
                subset = [subset]
            elif isinstance(subset, list):
                pass
            else:
                raise NotImplementedError
            rows = cols = len(subset)

        fig, axes = plt.subplots(rows, cols, figsize=opts['fig_size'],
                                 **opts['subplots'])
        # plt.subplots(1, 1) returns a bare Axes (no ``reshape``); wrap it
        # so ``axes`` is always a (rows, cols) array.
        axes = np.asarray(axes).reshape(rows, cols)

        # Style figure
        fig.subplots_adjust(**opts['fig_subplots_adjust'])
        fig.suptitle(opts['title'], **opts['title_format'])

        # Style axes
        row_idx = -1
        for row in range(dim):
            if row not in subset:
                continue
            row_idx += 1

            col_idx = -1
            for col in range(dim):
                if col not in subset:
                    continue
                col_idx += 1

                if row == col:
                    current = 'diag'
                elif row < col:
                    current = 'upper'
                else:
                    current = 'lower'

                ax = axes[row_idx, col_idx]
                plt.sca(ax)

                # Background color
                if current in opts['fig_bg_colors'] and \
                        opts['fig_bg_colors'][current] is not None:
                    ax.set_facecolor(opts['fig_bg_colors'][current])

                # Axes
                if opts[current] is None:
                    ax.axis('off')
                    continue

                # Limits (diagonals keep an auto-scaled density y axis).
                # Compare with ``!=``: ``is not`` on str literals is an
                # identity test and triggers a SyntaxWarning.
                if limits is not None:
                    ax.set_xlim((limits[col][0], limits[col][1]))
                    if current != 'diag':
                        ax.set_ylim((limits[row][0], limits[row][1]))
                xmin, xmax = ax.get_xlim()

                # Ticks
                if ticks is not None:
                    ax.set_xticks((ticks[col][0], ticks[col][1]))
                    if current != 'diag':
                        ax.set_yticks((ticks[row][0], ticks[row][1]))

                # Despine
                despine(ax=ax, **opts['despine'])

                # Formatting axes
                if current == 'diag':  # diagonals
                    if opts['lower'] is None or col == dim - 1:
                        _format_axis(ax, xhide=False, xlabel=labels_dim[col],
                                     yhide=True,
                                     tickformatter=opts['tickformatter'])
                    else:
                        _format_axis(ax, xhide=True, yhide=True)
                else:  # off-diagonals
                    if row == dim - 1:
                        _format_axis(ax, xhide=False, xlabel=labels_dim[col],
                                     yhide=True,
                                     tickformatter=opts['tickformatter'])
                    else:
                        _format_axis(ax, xhide=True, yhide=True)
                if opts['tick_labels'] is not None:
                    ax.set_xticklabels(
                        (str(opts['tick_labels'][col][0]),
                         str(opts['tick_labels'][col][1])))

                # Diagonals
                if current == 'diag':
                    if len(samples) > 0:
                        for n, v in enumerate(samples):
                            if opts['diag'][n] == 'hist':
                                plt.hist(v[:, row],
                                         color=opts['samples_colors'][n],
                                         **opts['hist_diag'])
                            elif opts['diag'][n] == 'kde':
                                density = gaussian_kde(
                                    v[:, row],
                                    bw_method=opts['kde_diag']['bw_method'])
                                xs = np.linspace(xmin, xmax,
                                                 opts['kde_diag']['bins'])
                                ys = density(xs)
                                plt.plot(xs, ys,
                                         color=opts['samples_colors'][n])

                    # Points are drawn as vertical lines spanning the y range.
                    if len(points) > 0:
                        extent = ax.get_ylim()
                        for n, v in enumerate(points):
                            plt.plot([v[:, row], v[:, row]], extent,
                                     color=opts['points_colors'][n],
                                     **opts['points_diag'])

                # Off-diagonals
                else:
                    if len(samples) > 0:
                        for n, v in enumerate(samples):
                            kind = opts['upper'][n]
                            if kind == 'hist' or kind == 'hist2d':
                                hist, xedges, yedges = np.histogram2d(
                                    v[:, col], v[:, row],
                                    range=[[limits[col][0], limits[col][1]],
                                           [limits[row][0], limits[row][1]]],
                                    **opts['hist_offdiag'])
                                plt.imshow(hist.T, origin='lower',
                                           extent=[xedges[0], xedges[-1],
                                                   yedges[0], yedges[-1]],
                                           aspect='auto')

                            elif kind in ['kde', 'kde2d', 'contour',
                                          'contourf']:
                                density = gaussian_kde(
                                    v[:, [col, row]].T,
                                    bw_method=opts['kde_offdiag']['bw_method'])
                                X, Y = np.meshgrid(
                                    np.linspace(limits[col][0], limits[col][1],
                                                opts['kde_offdiag']['bins']),
                                    np.linspace(limits[row][0], limits[row][1],
                                                opts['kde_offdiag']['bins']))
                                positions = np.vstack([X.ravel(), Y.ravel()])
                                Z = np.reshape(density(positions).T, X.shape)

                                if kind == 'kde' or kind == 'kde2d':
                                    plt.imshow(
                                        Z,
                                        extent=[limits[col][0], limits[col][1],
                                                limits[row][0], limits[row][1]],
                                        origin='lower',
                                        aspect='auto',
                                    )
                                elif kind == 'contour':
                                    # Normalize so contour levels are in [0, 1].
                                    Z = (Z - Z.min()) / (Z.max() - Z.min())
                                    plt.contour(
                                        X, Y, Z,
                                        origin='lower',
                                        extent=[limits[col][0], limits[col][1],
                                                limits[row][0], limits[row][1]],
                                        colors=opts['samples_colors'][n],
                                        **opts['contour_offdiag'])
                                # 'contourf' is accepted but not drawn yet.

                            elif kind == 'scatter':
                                plt.scatter(v[:, col], v[:, row],
                                            color=opts['samples_colors'][n],
                                            **opts['scatter_offdiag'])
                            elif kind == 'plot':
                                plt.plot(v[:, col], v[:, row],
                                         color=opts['samples_colors'][n],
                                         **opts['plot_offdiag'])

                    if len(points) > 0:
                        for n, v in enumerate(points):
                            plt.plot(v[:, col], v[:, row],
                                     color=opts['points_colors'][n],
                                     **opts['points_offdiag'])

        # When only a subset is shown, add ellipses to hint at omitted dims.
        if len(subset) < dim:
            for row in range(len(subset)):
                ax = axes[row, len(subset) - 1]
                x0, x1 = ax.get_xlim()
                y0, y1 = ax.get_ylim()
                text_kwargs = {'fontsize': plt.rcParams['font.size'] * 2.}
                ax.text(x1 + (x1 - x0) / 8., (y0 + y1) / 2., '...',
                        **text_kwargs)
                if row == len(subset) - 1:
                    ax.text(x1 + (x1 - x0) / 12., y0 - (y1 - y0) / 1.5, '...',
                            rotation=-45, **text_kwargs)

    return fig, axes