def test_match_tips_intersect_tree_immutable(self):
    # match_tips must NOT mutate the input tree, even though tip 'c' has no
    # matching column in the table.
    counts = [[0, 0, 1], [2, 3, 4], [5, 5, 3], [0, 0, 1]]
    table = pd.DataFrame(counts,
                         index=['s1', 's2', 's3', 's4'],
                         columns=['a', 'b', 'd'])
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    match_tips(table, tree)
    unchanged_newick = u"(((a,b)f,c),d)r;\n"
    self.assertEqual(str(tree), unchanged_newick)
def test_biom_match_tips_intersect_tree_immutable(self):
    # match_tips must leave both the biom Table and the tree untouched,
    # even though tip 'c' has no matching observation in the table.
    def build_table():
        # A fresh, independent Table on every call, so that `table` and
        # `exp_table` never share underlying data.
        return Table(
            np.array([[0, 0, 1], [2, 3, 4], [5, 5, 3], [0, 0, 1]]).T,
            ['a', 'b', 'd'],
            ['s1', 's2', 's3', 's4'])

    table = build_table()
    exp_table = build_table()
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    match_tips(table, tree)
    self.assertEqual(exp_table, table)
    self.assertEqual(str(tree), u"(((a,b)f,c),d)r;\n")
def ilr_transform(table: pd.DataFrame, tree: skbio.TreeNode) -> pd.DataFrame:
    """Perform an isometric log-ratio (ilr) transformation on a feature table.

    This creates a new table of balances (log-contrasts between groups of
    features) that can be used to distinguish samples. Zeros must be removed
    from the table beforehand (e.g. with a pseudocount).

    Parameters
    ----------
    table : pd.DataFrame
        Feature table where rows correspond to samples and columns
        correspond to features. All values must be strictly positive.
    tree : skbio.TreeNode
        A tree relating the features to balances (log-contrasts). The tree
        must be bifurcating, i.e. every internal node must have exactly two
        children. Features and tips are first intersected with
        ``match_tips``, so tips absent from `table` (and columns absent from
        `tree`) are dropped.

    Returns
    -------
    balances : pd.DataFrame
        Balances calculated from the feature table, one row per sample.
        Columns are the names of the internal nodes of the matched tree, in
        level order. Each balance is the log ratio of the features below
        the corresponding internal node's two subtrees.
    """
    _table, _tree = match_tips(table, tree)
    basis, _ = balance_basis(_tree)
    balances = ilr(_table.values, basis)
    in_nodes = [n.name for n in _tree.levelorder() if not n.is_tip()]
    # Label rows with the index of the table whose values were transformed.
    return pd.DataFrame(balances, columns=in_nodes, index=_table.index)
def test_biom_match_tips_intersect_columns(self):
    # The table has fewer observations than the tree has tips ('c' is
    # missing), so 'c' is pruned from the tree and the observations are
    # reordered to follow the tree's tip order.
    table = Table(
        np.array([[0, 0, 1], [2, 3, 4], [5, 5, 3], [0, 0, 1]]).T,
        ['a', 'b', 'd'],
        ['s1', 's2', 's3', 's4'])
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    exp_table = Table(
        np.array([[1, 0, 0], [4, 2, 3], [3, 5, 5], [1, 0, 0]]).T,
        ['d', 'a', 'b'],
        ['s1', 's2', 's3', 's4'])
    exp_tree = TreeNode.read([u"(d,(a,b)f)r;"])
    res_table, res_tree = match_tips(table, tree)
    self.assertEqual(exp_table, res_table)
    self.assertEqual(str(exp_tree), str(res_tree))
def ilr_phylogenetic(table: pd.DataFrame, tree: skbio.TreeNode,
                     pseudocount: float = 0.5) -> (
                         pd.DataFrame, skbio.TreeNode):
    """Compute phylogenetic ilr balances for a feature table.

    The input tree is copied (the caller's tree is not bifurcated in place),
    bifurcated, intersected with the table and given internal node names.
    The table is then passed through ``add_pseudocount`` and
    ilr-transformed against that tree.

    Parameters
    ----------
    table : pd.DataFrame
        Feature table whose columns are matched against the tree tips.
    tree : skbio.TreeNode
        Tree relating the features.
    pseudocount : float
        Value handed to ``add_pseudocount`` before the transform.

    Returns
    -------
    pd.DataFrame
        The ilr balances.
    skbio.TreeNode
        The bifurcated, tip-matched, renamed tree used to compute them.
    """
    working_tree = tree.copy()
    working_tree.bifurcate()
    matched_table, matched_tree = match_tips(table, working_tree)
    named_tree = rename_internal_nodes(matched_tree)
    smoothed_table = add_pseudocount(matched_table, pseudocount)
    balances = ilr_transform(smoothed_table, named_tree)
    return balances, named_tree
def ilr_transform(table: pd.DataFrame, tree: skbio.TreeNode) -> pd.DataFrame:
    """Return the ilr balances of `table` with respect to `tree`.

    The table and tree are first intersected with ``match_tips``. The
    balance basis of the matched tree is then used to ilr-transform the
    matched values.

    Parameters
    ----------
    table : pd.DataFrame
        Feature table; its columns are matched against the tree tips.
    tree : skbio.TreeNode
        Tree relating the features.

    Returns
    -------
    pd.DataFrame
        Balances indexed like `table`. Columns are the matched tree's
        internal node names, in level order.
    """
    matched_table, matched_tree = match_tips(table, tree)
    basis, _unused_nodes = balance_basis(matched_tree)
    coordinates = ilr(matched_table.values, basis)
    internal_names = [node.name
                      for node in matched_tree.levelorder()
                      if not node.is_tip()]
    return pd.DataFrame(coordinates,
                        columns=internal_names,
                        index=table.index)
def test_match_tips(self):
    # Table columns and tree tips already coincide, so match_tips should
    # return an equal table and a tree with the same Newick representation.
    sample_ids = ['s1', 's2', 's3', 's4']
    feature_ids = ['a', 'b', 'c', 'd']
    rows = [[0, 0, 1, 1], [2, 2, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]]
    table = pd.DataFrame(rows, index=sample_ids, columns=feature_ids)
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    res_table, res_tree = match_tips(table, tree)
    pdt.assert_frame_equal(table, res_table)
    self.assertEqual(str(tree), str(res_tree))
def assign_ids(input_table: pd.DataFrame,
               input_tree: skbio.TreeNode) -> (pd.DataFrame, skbio.TreeNode):
    """Give every internal node a unique id, then match the table to the tree.

    The input tree is copied and bifurcated. Each internal node receives a
    name of the form ``'<position>L-<uuid4>'``, where ``<position>`` is the
    node's index in a level-order traversal that also counts tips. The
    renamed tree is then intersected with `input_table` via ``match_tips``.

    Returns
    -------
    pd.DataFrame
        The matched table.
    skbio.TreeNode
        The matched, renamed tree.
    """
    tree_copy = input_tree.copy()
    tree_copy.bifurcate()
    new_names = []
    for position, node in enumerate(tree_copy.levelorder(include_self=True)):
        if node.is_tip():
            continue
        new_names.append('%sL-%s' % (position, uuid.uuid4()))
    tree_copy = rename_internal_nodes(tree_copy, names=new_names)
    matched_table, matched_tree = match_tips(input_table, tree_copy)
    return matched_table, matched_tree
def test_match_tips(self):
    # When every table column is a tree tip and vice versa, match_tips is
    # expected to be a no-op on both objects.
    table = pd.DataFrame(
        {'a': [0, 2, 5, 0],
         'b': [0, 2, 5, 0],
         'c': [1, 4, 3, 0],
         'd': [1, 4, 3, 1]},
        index=['s1', 's2', 's3', 's4'])
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    exp_table = table
    exp_tree = tree
    res_table, res_tree = match_tips(table, tree)
    pdt.assert_frame_equal(exp_table, res_table)
    self.assertEqual(str(exp_tree), str(res_tree))
def ilr_phylogenetic_differential(
        differential: pd.DataFrame, tree: skbio.TreeNode) -> (
            pd.DataFrame, skbio.TreeNode):
    """Project feature-level differentials onto phylogenetic balances.

    Parameters
    ----------
    differential : pd.DataFrame
        Differentials whose rows are features. The frame is transposed
        before ``match_tips``, so its index is what gets matched against
        the tree tips.
    tree : skbio.TreeNode
        Tree relating the features. It is copied and bifurcated; the
        caller's tree is not modified by ``bifurcate``.

    Returns
    -------
    pd.DataFrame
        Balance-level differentials. Rows are the internal node names of
        the returned tree (index name ``'featureid'``); columns are the
        columns of `differential`.
    skbio.TreeNode
        The bifurcated, tip-matched and renamed tree whose internal node
        names label the rows of the returned balances.
    """
    t = tree.copy()
    t.bifurcate()
    diff, _tree = match_tips(differential.T, t)
    _tree = rename_internal_nodes(_tree)
    in_nodes = [n.name for n in _tree.levelorder() if not n.is_tip()]
    basis = _balance_basis(_tree)[0]
    basis = pd.DataFrame(basis.T, index=diff.columns, columns=in_nodes)
    diff_balances = (diff @ basis).T
    diff_balances.index.name = 'featureid'
    # Return the renamed tree. Its internal node names are the row labels
    # of `diff_balances`. The unrenamed, unmatched copy `t` would not let
    # callers link balances back to tree nodes. This is consistent with
    # ilr_phylogenetic / ilr_phylogenetic_ordination.
    return diff_balances, _tree
def test_match_tips_intersect_columns(self):
    # The table has fewer columns than the tree has tips: tip 'c' is absent
    # from the table, so it is pruned from the tree and the remaining
    # columns are reordered to follow the tree's tip order.
    samples = ['s1', 's2', 's3', 's4']
    table = pd.DataFrame([[0, 0, 1], [2, 3, 4], [5, 5, 3], [0, 0, 1]],
                         index=samples, columns=['a', 'b', 'd'])
    tree = TreeNode.read([u"(((a,b)f, c),d)r;"])
    expected_table = pd.DataFrame(
        [[1, 0, 0], [4, 2, 3], [3, 5, 5], [1, 0, 0]],
        index=samples, columns=['d', 'a', 'b'])
    expected_tree = TreeNode.read([u"(d,(a,b)f)r;"])
    got_table, got_tree = match_tips(table, tree)
    pdt.assert_frame_equal(expected_table, got_table)
    self.assertEqual(str(expected_tree), str(got_tree))
def test_biom_match_tips_intersect_tips(self):
    # The tree has fewer tips than the table has observations: 'd' is not a
    # tip, so match_tips drops it from the table and leaves the tree as-is.
    counts = np.array([[0, 0, 1, 1], [2, 3, 4, 4], [5, 5, 3, 3], [0, 0, 0, 1]])
    table = Table(counts.T, ['a', 'b', 'c', 'd'], ['s1', 's2', 's3', 's4'])
    tree = TreeNode.read([u"((a,b)f,c)r;"])
    expected_counts = np.array([[0, 0, 1], [2, 3, 4], [5, 5, 3], [0, 0, 0]])
    expected_table = Table(expected_counts.T, ['a', 'b', 'c'],
                           ['s1', 's2', 's3', 's4'])
    got_table, got_tree = match_tips(table, tree)
    self.assertEqual(expected_table, got_table)
    self.assertEqual(str(tree), str(got_tree))
def _intersect_of_table_metadata_tree(table, metadata, tree): """ Matches tips, features and samples between the table, metadata and tree. This module returns the features and samples that are contained in all 3 objects. Parameters ---------- table : pd.DataFrame Contingency table where samples correspond to rows and features correspond to columns. metadata: pd.DataFrame Metadata table that contains information about the samples contained in the `table` object. Samples correspond to rows and covariates correspond to columns. tree : skbio.TreeNode Tree object where the leaves correspond to the columns contained in the table. Returns ------- pd.DataFrame Subset of `table` with common row names as `metadata` and common columns as `tree.tips()` pd.DataFrame Subset of `metadata` with common row names as `table` skbio.TreeNode Subtree of `tree` with common tips as `table` """ if np.any(table <= 0): raise ValueError('Cannot handle zeros or negative values in `table`. ' 'Use pseudocounts or ``multiplicative_replacement``.') _table, _metadata = match(table, metadata) _table, _tree = match_tips(_table, tree) non_tips_no_name = [(n.name is None) for n in _tree.levelorder() if not n.is_tip()] if len(non_tips_no_name) == 0: raise ValueError('There are no internal nodes in `tree` after' 'intersection with `table`.') if len(_table.index) == 0: raise ValueError('There are no internal nodes in `table` after ' 'intersection with `metadata`.') if any(non_tips_no_name): _tree = rename_internal_nodes(_tree) return _table, _metadata, _tree
def ilr_phylogenetic_ordination(
        table: pd.DataFrame, tree: skbio.TreeNode,
        pseudocount: float = 0.5,
        top_k_var: int = 10,
        clades: list = None
) -> (OrdinationResults, skbio.TreeNode, pd.DataFrame):
    """Build an ordination from phylogenetic ilr balances.

    Parameters
    ----------
    table : pd.DataFrame
        Feature table; rows are samples and columns are features matched
        against the tree tips.
    tree : skbio.TreeNode
        Tree relating the features. It is copied and bifurcated.
    pseudocount : float
        Pseudocount applied before taking logs. It is used in both branches:
        directly via ``add_pseudocount`` when `clades` is not given, and
        forwarded to ``_fast_ilr`` otherwise.
    top_k_var : int
        When `clades` is not given, the number of highest-variance balances
        to keep.
    clades : list, optional
        If given, its first element is a comma-separated string of internal
        node names to use as balances (presumably names produced by
        ``rename_internal_nodes``; confirm against callers).

    Returns
    -------
    OrdinationResults
        Samples are the selected balances, features form an identity
        matrix over those balances, eigvals are balance variances sorted in
        descending order, and proportion_explained is each selected
        balance's share of the total variance.
    skbio.TreeNode
        The bifurcated, tip-matched and renamed tree.
    pd.DataFrame
        The ilr basis restricted to the selected balances, indexed by
        feature (index name ``'featureid'``).
    """
    t = tree.copy()
    t.bifurcate()
    _table, _tree = match_tips(table, t)
    _tree = rename_internal_nodes(_tree)
    if not clades:
        in_nodes = [n.name for n in _tree.levelorder() if not n.is_tip()]
        basis = _balance_basis(_tree)[0]
        _table = add_pseudocount(_table, pseudocount)
        basis = pd.DataFrame(basis.T, index=_table.columns, columns=in_nodes)
        balances = np.log(_table) @ basis
        var = balances.var(axis=0).sort_values(ascending=False)
        clades = var.index[:top_k_var]
        balances = balances[clades]
        basis = basis[clades]
    else:
        clades = clades[0].split(',')
        # Forward the caller's pseudocount. It was previously hard-coded
        # to 0.5, silently ignoring the argument in this branch.
        balances, basis = _fast_ilr(_tree, _table, clades,
                                    pseudocount=pseudocount)
        var = balances.var(axis=0).sort_values(ascending=False)

    balances.index.name = 'sampleid'
    # feature metadata
    # NOTE(review): in the top-k branch `var` covers every balance while
    # `samples` keeps only the top-k; confirm that this eigvals length is
    # intended.
    eigvals = var
    prop = var[clades] / var.sum()
    balances = OrdinationResults(
        short_method_name='ILR',
        long_method_name='Phylogenetic Isometric Log Ratio Transform',
        samples=balances,
        features=pd.DataFrame(np.eye(len(clades)), index=clades),
        eigvals=eigvals,
        proportion_explained=prop)
    basis.index.name = 'featureid'
    return balances, _tree, basis
def dendrogram_heatmap(output_dir: str, table: pd.DataFrame,
                       tree: TreeNode,
                       metadata: qiime2.CategoricalMetadataColumn,
                       pseudocount: float = 0.5,
                       ndim: int = 10, method: str = 'clr',
                       color_map: str = 'viridis'):
    """Render a dendrogram heatmap of `table` and write an HTML summary.

    Parameters
    ----------
    output_dir : str
        Directory that receives heatmap.svg, heatmap.pdf and index.html.
    table : pd.DataFrame
        Feature table; rows are samples and columns are features.
    tree : TreeNode
        Feature hierarchy; its tips are matched to the table columns.
    metadata : qiime2.CategoricalMetadataColumn
        Sample grouping shown alongside the heatmap.
    pseudocount : float
        Value handed to ``add_pseudocount`` before the transform.
    ndim : int
        Number of internal nodes (in level order) to highlight.
    method : str
        ``'clr'`` (clr of the centralized table) or ``'log'``.
    color_map : str
        Matplotlib colormap name for the heatmap.

    Raises
    ------
    ValueError
        If `method` is neither ``'clr'`` nor ``'log'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # `mat` further down.
    if method not in ('clr', 'log'):
        raise ValueError(
            "`method` must be either 'clr' or 'log', got %r." % (method,))

    table, tree = match_tips(add_pseudocount(table, pseudocount), tree)

    nodes = [n.name for n in tree.levelorder() if not n.is_tip()]

    nlen = min(ndim, len(nodes))
    numerator_color, denominator_color = '#fb9a99', '#e31a1c'
    highlights = pd.DataFrame([[numerator_color, denominator_color]] * nlen,
                              index=nodes[:nlen])
    if method == 'clr':
        mat = pd.DataFrame(clr(centralize(table)),
                           index=table.index,
                           columns=table.columns)
    else:  # method == 'log'
        mat = pd.DataFrame(np.log(table),
                           index=table.index,
                           columns=table.columns)

    c = metadata.to_series()
    # Align the matrix that is actually plotted with the metadata.
    # heatmap() requires the index of `mdvar` to correspond to the index of
    # its table. Matching only `table` (which is not used afterwards) left
    # `mat` unaligned.
    mat, c = match(mat, c)

    # TODO: There are a few hard-coded constants here
    # will need to have some adaptive defaults set in the future
    fig = heatmap(mat, tree, c, highlights, cmap=color_map,
                  highlight_width=0.01, figsize=(12, 8))
    fig.savefig(os.path.join(output_dir, 'heatmap.svg'))
    fig.savefig(os.path.join(output_dir, 'heatmap.pdf'))

    css = r"""
        .square {
          float: left;
          width: 100px;
          height: 20px;
          margin: 5px;
          border: 1px solid rgba(0, 0, 0, .2);
        }

        .numerator {
          background: %s;
        }

        .denominator {
          background: %s;
        }
    """ % (numerator_color, denominator_color)

    index_fp = os.path.join(output_dir, 'index.html')
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        index_f.write('<h1>Dendrogram heatmap</h1>\n')
        index_f.write('<img src="heatmap.svg" alt="heatmap">')
        index_f.write('<a href="heatmap.pdf">')
        index_f.write('Download as PDF</a><br>\n')
        index_f.write('<style>%s</style>' % css)
        index_f.write('<div class="square numerator">'
                      'Numerator<br/></div>')
        index_f.write('<div class="square denominator">'
                      'Denominator<br/></div>')
        index_f.write('</body></html>\n')
def balance_taxonomy(output_dir: str, table: pd.DataFrame, tree: TreeNode,
                     taxonomy: pd.DataFrame,
                     balance_name: str,
                     pseudocount: float = 0.5,
                     taxa_level: int = 0,
                     n_features: int = 10,
                     threshold: float = None,
                     metadata: qiime2.MetadataColumn = None) -> None:
    """Summarise the taxa on each side of one balance and write plots/HTML.

    Computes the balance of internal node `balance_name` as
    ``sqrt(r*s/(r+s)) * (mean log numerator - mean log denominator)``, where
    ``r`` and ``s`` are the numbers of tips in the numerator and denominator
    clades (children ``NUMERATOR`` / ``DENOMINATOR`` of that node).

    Always writes barplots.svg/.pdf, numerator.csv, denominator.csv and
    index.html into `output_dir`. When `metadata` is given, it also writes
    balance_metadata.svg/.pdf. Unless a categorical column has more than
    two categories, it also writes proportion_plot.svg/.pdf.

    Parameters
    ----------
    output_dir : str
        Destination directory for all outputs.
    table : pd.DataFrame
        Feature table (rows are samples, columns are features), matched
        against the tree tips after ``add_pseudocount``.
    tree : TreeNode
        Tree containing an internal node named `balance_name`.
    taxonomy : pd.DataFrame
        Must have a ``'Taxon'`` column of ';'-separated lineages, indexed
        by feature id.
    balance_name : str
        Name of the internal node whose balance is summarised.
    pseudocount : float
        Value handed to ``add_pseudocount``.
    taxa_level : int
        Taxonomy column (rank index) used for labels.
    n_features : int
        Number of top features per side shown in the proportion plot.
    threshold : float, optional
        Split point for a numeric metadata column (defaults to its mean).
        Must not be combined with a categorical column.
    metadata : qiime2.MetadataColumn, optional
        Numeric or categorical sample metadata to plot against the balance.

    Raises
    ------
    ValueError
        If `threshold` is given with categorical metadata, or if a
        categorical column contains only numeric values.
    NotImplementedError
        For any other kind of MetadataColumn.
    """
    if threshold is not None and isinstance(metadata,
                                            qiime2.CategoricalMetadataColumn):
        raise ValueError('Categorical metadata column detected. Only specify '
                         'a threshold when using a numerical metadata column.')

    # make sure that the table and tree match up
    table, tree = match_tips(add_pseudocount(table, pseudocount), tree)

    # parse out headers for taxonomy
    taxa_data = list(taxonomy['Taxon'].apply(lambda x: x.split(';')).values)
    taxa_df = pd.DataFrame(taxa_data, index=taxonomy.index)

    # fill in NAs
    # Shorter lineages are presumably padded with None by pd.DataFrame. Each
    # None in a row is replaced with the last non-None rank of that row.
    def f(x):
        y = np.array(list(map(lambda k: k is not None, x)))
        i = max(0, np.where(y)[0][-1])
        x[np.logical_not(y)] = [x[i]] * np.sum(np.logical_not(y))
        return x
    taxa_df = taxa_df.apply(f, axis=1)

    num_clade = tree.find(balance_name).children[NUMERATOR]
    denom_clade = tree.find(balance_name).children[DENOMINATOR]

    # Collect the taxonomy rows of each clade, plus its tip count (r, s),
    # which is used for the balance normalisation below.
    if num_clade.is_tip():
        num_features = pd.DataFrame(
            {num_clade.name: taxa_df.loc[num_clade.name]}
            ).T
        r = 1
    else:
        num_features = taxa_df.loc[num_clade.subset()]
        r = len(list(num_clade.tips()))

    if denom_clade.is_tip():
        denom_features = pd.DataFrame(
            {denom_clade.name: taxa_df.loc[denom_clade.name]}
            ).T
        s = 1
    else:
        denom_features = taxa_df.loc[denom_clade.subset()]
        s = len(list(denom_clade.tips()))

    # Per-sample balance: difference of mean logs, scaled by sqrt(rs/(r+s)).
    b = (np.log(table.loc[:, num_features.index]).mean(axis=1) -
         np.log(table.loc[:, denom_features.index]).mean(axis=1))

    b = b * np.sqrt(r * s / (r + s))
    balances = pd.DataFrame(b, index=table.index,
                            columns=[balance_name])

    # the actual colors for the numerator and denominator
    num_color = sns.color_palette("Paired")[0]
    denom_color = sns.color_palette("Paired")[1]

    fig, (ax_num, ax_denom) = plt.subplots(2)
    balance_barplots(tree, balance_name, taxa_level, taxa_df,
                     denom_color=denom_color, num_color=num_color,
                     axes=(ax_num, ax_denom))

    ax_num.set_title(
        r'$%s_{numerator} \; taxa \; (%d \; taxa)$' % (
            balance_name, len(num_features)))
    ax_denom.set_title(
        r'$%s_{denominator} \; taxa \; (%d \; taxa)$' % (
            balance_name, len(denom_features)))
    ax_denom.set_xlabel('Number of unique taxa')
    plt.tight_layout()
    fig.savefig(os.path.join(output_dir, 'barplots.svg'))
    fig.savefig(os.path.join(output_dir, 'barplots.pdf'))

    # dcat: two-level sample grouping used by the proportion plot.
    # multiple_cats: True when a categorical column has >2 levels, in which
    # case the proportion plot is skipped.
    dcat = None
    multiple_cats = False
    if metadata is not None:
        fig2, ax = plt.subplots()
        c = metadata.to_series()
        data, c = match(balances, c)
        data[c.name] = c
        y = data[balance_name]

        # check if continuous
        if isinstance(metadata, qiime2.NumericMetadataColumn):
            ax.scatter(c.values, y)
            ax.set_xlabel(c.name)
            if threshold is None:
                threshold = c.mean()
            # Binarise the numeric column around `threshold`; values equal
            # to the threshold land in the '>' group.
            dcat = c.apply(
                lambda x: '%s < %f' % (c.name, threshold)
                if x < threshold else
                '%s > %f' % (c.name, threshold)
            )
            sample_palette = pd.Series(sns.color_palette("Set2", 2),
                                       index=dcat.value_counts().index)

        elif isinstance(metadata, qiime2.CategoricalMetadataColumn):
            sample_palette = pd.Series(
                sns.color_palette("Set2", len(c.value_counts())),
                index=c.value_counts().index)

            # Reject a "categorical" column whose values all parse as
            # numbers.
            try:
                pd.to_numeric(metadata.to_series())
            except ValueError:
                pass
            else:
                raise ValueError('Categorical metadata column '
                                 f'{metadata.name!r} contains only numerical '
                                 'values. At least one value must be '
                                 'non-numerical.')

            balance_boxplot(balance_name, data, y=c.name, ax=ax,
                            palette=sample_palette)
            if len(c.value_counts()) > 2:
                warnings.warn(
                    'More than 2 categories detected in categorical metadata '
                    'column. Proportion plots will not be displayed',
                    stacklevel=2)
                multiple_cats = True
            else:
                dcat = c

        else:
            # Some other type of MetadataColumn
            raise NotImplementedError()

        ylabel = (r"$%s = \ln \frac{%s_{numerator}}"
                  "{%s_{denominator}}$") % (balance_name,
                                            balance_name,
                                            balance_name)
        ax.set_title(ylabel, rotation=0)
        ax.set_ylabel('log ratio')
        fig2.savefig(os.path.join(output_dir, 'balance_metadata.svg'))
        fig2.savefig(os.path.join(output_dir, 'balance_metadata.pdf'))

        if not multiple_cats:
            # Proportion plots
            # first sort by clr values and calculate average fold change
            ctable = pd.DataFrame(clr(centralize(table)),
                                  index=table.index, columns=table.columns)

            # NOTE(review): assumes dcat has two distinct values; with a
            # single group, `.index[1]` raises IndexError — confirm.
            left_group = dcat.value_counts().index[0]
            right_group = dcat.value_counts().index[1]
            lidx, ridx = (dcat == left_group), (dcat == right_group)
            # NOTE(review): `b` is indexed by all table samples, while
            # lidx/ridx only cover samples matched to the metadata. Boolean
            # indexing may fail if the two sets differ — verify.
            if b.loc[lidx].mean() > b.loc[ridx].mean():
                # double check ordering and switch if necessary
                # careful - the left group is also commonly associated with
                # the denominator.
                left_group = dcat.value_counts().index[1]
                right_group = dcat.value_counts().index[0]
                lidx, ridx = (dcat == left_group), (dcat == right_group)
            # we are not performing a statistical test here
            # we're just trying to figure out a way to sort the data.
            num_fold_change = ctable.loc[:, num_features.index].apply(
                lambda x: ttest_ind(x[ridx], x[lidx])[0])
            num_fold_change = num_fold_change.sort_values(
                ascending=False
            )

            denom_fold_change = ctable.loc[:, denom_features.index].apply(
                lambda x: ttest_ind(x[ridx], x[lidx])[0])
            denom_fold_change = denom_fold_change.sort_values(
                ascending=True
            )

            # `metadata` is rebound to a one-column DataFrame of the
            # grouping; it remains non-None for the HTML section below.
            metadata = pd.DataFrame({dcat.name: dcat})
            top_num_features = num_fold_change.index[:n_features]
            top_denom_features = denom_fold_change.index[:n_features]

            fig3, (ax_denom, ax_num) = plt.subplots(1, 2)
            proportion_plot(
                table, metadata,
                category=metadata.columns[0],
                left_group=left_group,
                right_group=right_group,
                feature_metadata=taxa_df,
                label_col=taxa_level,
                num_features=top_num_features,
                denom_features=top_denom_features,
                # Note that the syntax is funky and counter
                # intuitive. This will need to be properly
                # fixed here
                # https://github.com/biocore/gneiss/issues/244
                num_color=sample_palette.loc[right_group],
                denom_color=sample_palette.loc[left_group],
                axes=(ax_num, ax_denom))

            # The below is overriding the default colors in the
            # numerator / denominator this will also need to be fixed in
            # https://github.com/biocore/gneiss/issues/244
            # NOTE(review): Axes.get_ylim() returns (bottom, top), so these
            # names are swapped unless the axis is inverted — confirm.
            max_ylim, min_ylim = ax_denom.get_ylim()
            num_h, denom_h = n_features, n_features

            # Split the y-range proportionally to the number of numerator
            # and denominator rows, offset by half a row.
            space = (max_ylim - min_ylim) / (num_h + denom_h)
            ymid = (max_ylim - min_ylim) * num_h
            ymid = ymid / (num_h + denom_h) - 0.5 * space

            ax_denom.axhspan(min_ylim, ymid,
                             facecolor=num_color,
                             zorder=0)
            ax_denom.axhspan(ymid, max_ylim,
                             facecolor=denom_color,
                             zorder=0)

            ax_num.axhspan(min_ylim, ymid,
                           facecolor=num_color,
                           zorder=0)
            ax_num.axhspan(ymid, max_ylim,
                           facecolor=denom_color,
                           zorder=0)

            fig3.subplots_adjust(
                # the left side of the subplots of the figure
                left=0.3,
                # the right side of the subplots of the figure
                right=0.9,
                # the bottom of the subplots of the figure
                bottom=0.1,
                # the top of the subplots of the figure
                top=0.9,
                # the amount of width reserved for blank space
                # between subplots
                wspace=0,
                # the amount of height reserved for white space
                # between subplots
                hspace=0.2,
            )

            fig3.savefig(os.path.join(output_dir, 'proportion_plot.svg'))
            fig3.savefig(os.path.join(output_dir, 'proportion_plot.pdf'))

    index_fp = os.path.join(output_dir, 'index.html')
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        if metadata is not None:
            index_f.write('<h1>Balance vs %s </h1>\n' % c.name)
            index_f.write(('<img src="balance_metadata.svg" '
                           'alt="barplots">\n\n'
                           '<a href="balance_metadata.pdf">'
                           'Download as PDF</a><br>\n'))
            if not multiple_cats:
                index_f.write('<h1>Proportion Plot </h1>\n')
                index_f.write(('<img src="proportion_plot.svg" '
                               'alt="proportions">\n\n'
                               '<a href="proportion_plot.pdf">'
                               'Download as PDF</a><br>\n'))

        index_f.write(('<h1>Balance Taxonomy</h1>\n'
                       '<img src="barplots.svg" alt="barplots">\n\n'
                       '<a href="barplots.pdf">'
                       'Download as PDF</a><br>\n'
                       '<h3>Numerator taxa</h3>\n'
                       '<a href="numerator.csv">\n'
                       'Download as CSV</a><br>\n'
                       '<h3>Denominator taxa</h3>\n'
                       '<a href="denominator.csv">\n'
                       'Download as CSV</a><br>\n'))

        # The CSVs linked above are written here, while index.html is open.
        num_features.to_csv(os.path.join(output_dir, 'numerator.csv'),
                            header=True, index=True)
        denom_features.to_csv(os.path.join(output_dir, 'denominator.csv'),
                              header=True, index=True)
        index_f.write('</body></html>\n')
def heatmap(table, tree, mdvar, highlights=None, cmap='viridis',
            linewidth=0.5, grid_col='w', grid_width=2, highlight_width=0.02,
            figsize=(5, 5), **kwargs):
    """ Creates heatmap plotting object

    Parameters
    ----------
    table : pd.DataFrame
        Contain sample/feature labels along with table of values.
        Rows correspond to samples, and columns correspond to features.
    tree: skbio.TreeNode
        Tree representing the feature hierarchy.
    mdvar: pd.Series
        Metadata values for samples.  The index must correspond to the
        index of `table`.
    highlights: pd.DataFrame or dict of tuple of str
        List of internal nodes in the tree to highlight.
        Each internal node must contain two colors, one for the left
        subtree and the other for the right subtree highlight.
        The first color will always correspond to the left subtree,
        and the second color will always correspond to the right subtree.
    cmap : str
        Specifies the matplotlib colormap for the heatmap (default='viridis')
    linewidth : float
        Width of dendrogram lines. (default=0.5)
    grid_col: str
        Color of vertical lines for highlighting sample metadata.
        (default='w')
    grid_width: int
        Width of vertical lines for highlighting sample metadata.
        (default=2)
    highlight_width : float
        Width of each highlight column, as a fraction of the figure width.
        (default=0.02)
    figsize: tuple of int
        Specifies (width, height) for figure. (default=(5, 5))
    **kwargs : dict
        Arguments to be passed into the heatmap plotting.

    Returns
    -------
    matplotlib.pyplot.figure
        Matplotlib figure object

    Note
    ----
    The highlights parameter assumes that the tree is bifurcating.
    """
    # match the tips
    dendrogram_width = 20
    table, tree = match_tips(table, tree)
    # Transpose so that rows are features, aligned with the dendrogram tips.
    table = table.T

    # get edges from tree
    t = SquareDendrogram.from_tree(tree)
    t = _tree_coordinates(t)
    pts = t.coords(width=dendrogram_width, height=table.shape[0])
    edges = pts[['child0', 'child1']]
    edges = edges.dropna(subset=['child0', 'child1'])
    # unstack() yields a Series indexed by (child column, parent node);
    # level 1 of that index is therefore the parent (source) node.
    edges = edges.unstack()
    edges = pd.DataFrame({'src_node': edges.index.get_level_values(1),
                          'dest_node': edges.values})
    edge_list = []
    for src, dest in zip(edges['src_node'], edges['dest_node']):
        sx, sy = pts.loc[src].x, pts.loc[src].y
        dx, dy = pts.loc[dest].x, pts.loc[dest].y
        # Square dendrogram: a vertical segment at the parent's x, then a
        # horizontal segment out to the child.
        edge_list.append(
            {'x0': sx, 'y0': sy, 'x1': sx, 'y1': dy}
        )
        edge_list.append(
            {'x0': sx, 'y0': dy, 'x1': dx, 'y1': dy}
        )
    edge_list = pd.DataFrame(edge_list)

    # now plot the stuff
    fig = plt.figure(figsize=figsize, facecolor='white')
    # NOTE: this mutates matplotlib's global rcParams, so it also affects
    # axes created later in the process, not only this figure.
    plt.rcParams['axes.facecolor'] = 'white'

    # Axes rectangles below are [left, bottom, width, height] in figure
    # fractions.
    xwidth = 0.2
    top_buffer = 0.1
    height = 0.8
    cbar_pad = 0.95

    # dendrogram axes on the left side
    [axm_x, axm_y, axm_w, axm_h] = [0, top_buffer, xwidth, height]

    # create a split for the highlights
    if highlights is not None:
        h = len(highlights)
    else:
        h = 0
    hwidth = highlight_width
    [axs_x, axs_y, axs_w, axs_h] = [xwidth, top_buffer,
                                    hwidth * h, height]

    # heatmap axes
    hstart = xwidth + (h * hwidth)
    [ax1_x, ax1_y, ax1_w, ax1_h] = [hstart, top_buffer,
                                    cbar_pad - hstart, height]

    # plot dendrogram
    ax_dendrogram = fig.add_axes([axm_x, axm_y, axm_w, axm_h],
                                 frame_on=True)
    _plot_dendrogram(ax_dendrogram, table, edge_list, linewidth=linewidth)

    # plot highlights for dendrogram
    if highlights is not None:
        ax_highlights = fig.add_axes([axs_x, axs_y, axs_w, axs_h],
                                     frame_on=True, sharey=ax_dendrogram)
        _plot_highlights_dendrogram(ax_highlights, table, t, highlights)

    # plot heatmap
    ax_heatmap = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True,
                              sharey=ax_dendrogram)
    cbar = _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width,
                         cmap, **kwargs)

    # plot the colorbar
    fig.colorbar(cbar)
    return fig
def heatmap(table, tree, mdvar, highlights=None, cmap='viridis',
            linewidth=0.5, grid_col='w', grid_width=2, highlight_width=0.02,
            figsize=(5, 5), **kwargs):
    """Draw a dendrogram of `tree` next to a heatmap of `table`.

    Parameters
    ----------
    table : pd.DataFrame
        Values to plot; rows are samples and columns are features. The
        columns are matched against the tree tips first.
    tree : skbio.TreeNode
        Feature hierarchy drawn as a square dendrogram on the left.
    mdvar : pd.Series
        Sample metadata; its index must correspond to the index of `table`.
    highlights : pd.DataFrame or dict of tuple of str, optional
        Internal nodes to highlight. Each entry holds two colors: the first
        for the left subtree and the second for the right subtree.
    cmap : str
        Matplotlib colormap for the heatmap. (default='viridis')
    linewidth : float
        Width of dendrogram lines. (default=0.5)
    grid_col : str
        Color of the vertical lines separating sample metadata groups.
        (default='w')
    grid_width : int
        Width of those vertical lines. (default=2)
    highlight_width : float
        Width of each highlight column, as a fraction of the figure width.
        (default=0.02)
    figsize : tuple of int
        Specifies (width, height) for the figure. (default=(5, 5))
    **kwargs : dict
        Forwarded to the heatmap plotting routine.

    Returns
    -------
    matplotlib.pyplot.figure
        The assembled figure.

    Note
    ----
    The highlights parameter assumes that the tree is bifurcating.
    """
    table, tree = match_tips(table, tree)
    features_by_samples = table.T

    dendrogram = _tree_coordinates(SquareDendrogram.from_tree(tree))
    coords = dendrogram.coords(width=20, height=features_by_samples.shape[0])

    # Flatten (child0, child1) per parent into one (parent, child) pair per
    # row; level 1 of the unstacked index holds the parent node.
    children = coords[['child0', 'child1']].dropna(
        subset=['child0', 'child1']).unstack()
    parents = children.index.get_level_values(1)

    segments = []
    for parent, child in zip(parents, children.values):
        px, py = coords.loc[parent].x, coords.loc[parent].y
        cx, cy = coords.loc[child].x, coords.loc[child].y
        segments.append({'x0': px, 'y0': py, 'x1': px, 'y1': cy})
        segments.append({'x0': px, 'y0': cy, 'x1': cx, 'y1': cy})
    edge_list = pd.DataFrame(segments)

    fig = plt.figure(figsize=figsize, facecolor='white')
    plt.rcParams['axes.facecolor'] = 'white'

    # Layout in figure fractions: dendrogram strip, optional highlight
    # strip, then the heatmap up to the right edge.
    dendro_frac, bottom, span, right_edge = 0.2, 0.1, 0.8, 0.95
    n_highlights = 0 if highlights is None else len(highlights)
    heatmap_left = dendro_frac + (n_highlights * highlight_width)

    ax_dendro = fig.add_axes([0, bottom, dendro_frac, span], frame_on=True)
    _plot_dendrogram(ax_dendro, features_by_samples, edge_list,
                     linewidth=linewidth)

    if highlights is not None:
        ax_marks = fig.add_axes(
            [dendro_frac, bottom, highlight_width * n_highlights, span],
            frame_on=True, sharey=ax_dendro)
        _plot_highlights_dendrogram(ax_marks, features_by_samples,
                                    dendrogram, highlights)

    ax_heat = fig.add_axes(
        [heatmap_left, bottom, right_edge - heatmap_left, span],
        frame_on=True, sharey=ax_dendro)
    mappable = _plot_heatmap(ax_heat, features_by_samples, mdvar, grid_col,
                             grid_width, cmap, **kwargs)

    fig.colorbar(mappable)
    return fig