Exemple #1
0
def pls_balances_cmd(table_file, metadata_file, category, output_file):
    """Write the features picked out by a one-component PLS model.

    The feature table is loaded, given a pseudocount of 1, centralized and
    CLR-transformed, then regressed against ``category`` with a single PLS
    component.  The feature weights are split by ``round_balance`` into
    two thresholds; features above the upper threshold and below the lower
    threshold are written to ``output_file`` as one comma-separated line.

    Parameters
    ----------
    table_file : str
        Path passed to ``load_table``.
    metadata_file : str
        Path to a tab-delimited metadata file; its first column becomes
        the index.
    category : str
        Metadata column used as the response.  A float64 column is used
        as-is; a column of any other dtype is binarized (1 where the value
        equals the first sorted unique value, 0 otherwise).
    output_file : str
        Path of the text file to write.
    """
    metadata = pd.read_table(metadata_file, index_col=0)
    table = load_table(table_file)
    # samples as rows, observations as columns
    table = pd.DataFrame(np.array(table.matrix_data.todense()).T,
                         index=table.ids(axis='sample'),
                         columns=table.ids(axis='observation'))

    ctable = pd.DataFrame(clr(centralize(table + 1)),
                          index=table.index,
                          columns=table.columns)

    rfc = PLSRegression(n_components=1)
    # np.float / np.int were removed in NumPy 1.24; np.float64 and the
    # builtin int are what those aliases resolved to.
    if metadata[category].dtype != np.float64:
        cats = np.unique(metadata[category])
        groups = (metadata[category] == cats[0]).astype(int)
    else:
        groups = metadata[category]

    # NOTE(review): the metadata rows are not reordered to match the
    # table's samples here; presumably the inputs already agree - confirm.
    # Positional arguments keep this working across scikit-learn versions
    # (the ``Y`` keyword was deprecated in favour of ``y``).
    rfc.fit(ctable.values, groups)

    pls_df = pd.DataFrame(rfc.x_weights_,
                          index=ctable.columns,
                          columns=['PLS1'])
    l, r = round_balance(pls_df.values,
                         means_init=[[pls_df.PLS1.min()], [0],
                                     [pls_df.PLS1.max()]],
                         n_init=100)
    num = pls_df.loc[pls_df.PLS1 > r]
    denom = pls_df.loc[pls_df.PLS1 < l]
    diff_features = list(num.index.values)
    diff_features += list(denom.index.values)

    with open(output_file, 'w') as f:
        f.write(','.join(diff_features))
Exemple #2
0
 def test_biplot(self):
     """Sample and feature factors multiply back to the expected matrix."""
     expected = np.array(clr(centralize(clr_inv(self.beta))))

     ordination = regression_biplot(self.beta)
     self.assertIsInstance(ordination, OrdinationResults)

     sample_scores = ordination.samples.values
     feature_loadings_t = ordination.features.values.T
     # loose tolerances: only an approximate reconstruction is required
     npt.assert_allclose(sample_scores @ feature_loadings_t, expected,
                         atol=0.5, rtol=0.5)
Exemple #3
0
def dendrogram_heatmap(output_dir: str, table: pd.DataFrame,
                       tree: TreeNode, metadata: MetadataCategory,
                       ndim=10, method='clr', color_map='viridis'):
    """Render a dendrogram heatmap and an HTML page that embeds it.

    Writes ``heatmap.svg``, ``heatmap.pdf`` and ``index.html`` into
    ``output_dir``.

    Parameters
    ----------
    output_dir : str
        Directory that receives the output files.
    table : pd.DataFrame
        Data to plot; transformed according to ``method`` first.
    tree : TreeNode
        Tree handed to ``heatmap``; its internal nodes (level order) are
        the candidates for highlighting.
    metadata : MetadataCategory
        Converted with ``to_series()`` and passed to ``heatmap``.
    ndim : int
        Maximum number of internal nodes to highlight.
    method : str
        ``'clr'`` (centralize then CLR) or ``'log'`` (natural log).
    color_map : str
        Colormap name forwarded to ``heatmap`` as ``cmap``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'clr'`` nor ``'log'``.
    """
    nodes = [n.name for n in tree.levelorder() if not n.is_tip()]

    nlen = min(ndim, len(nodes))
    numerator_color, denominator_color = '#fb9a99', '#e31a1c'
    highlights = pd.DataFrame([[numerator_color, denominator_color]] * nlen,
                              index=nodes[:nlen])
    if method == 'clr':
        mat = pd.DataFrame(clr(centralize(table)),
                           index=table.index,
                           columns=table.columns)
    elif method == 'log':
        mat = pd.DataFrame(np.log(table),
                           index=table.index,
                           columns=table.columns)
    else:
        # Fail here, before anything is written, instead of hitting an
        # unbound ``mat`` further down.
        raise ValueError("Unknown method %r; expected 'clr' or 'log'."
                         % (method,))

    # TODO: There are a few hard-coded constants here
    # will need to have some adaptive defaults set in the future
    fig = heatmap(mat, tree, metadata.to_series(), highlights, cmap=color_map,
                  highlight_width=0.01, figsize=(12, 8))
    fig.savefig(os.path.join(output_dir, 'heatmap.svg'))
    fig.savefig(os.path.join(output_dir, 'heatmap.pdf'))

    # legend swatches; the two %s slots take the highlight colors
    css = r"""
        .square {
          float: left;
          width: 100px;
          height: 20px;
          margin: 5px;
          border: 1px solid rgba(0, 0, 0, .2);
        }

        .numerator {
          background: %s;
        }

        .denominator {
          background: %s;
        }
    """ % (numerator_color, denominator_color)

    index_fp = os.path.join(output_dir, 'index.html')
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        index_f.write('<h1>Dendrogram heatmap</h1>\n')
        index_f.write('<img src="heatmap.svg" alt="heatmap">')
        index_f.write('<a href="heatmap.pdf">')
        index_f.write('Download as PDF</a><br>\n')
        index_f.write('<style>%s</style>' % css)
        index_f.write('<div class="square numerator">'
                      'Numerator<br/></div>')
        index_f.write('<div class="square denominator">'
                      'Denominator<br/></div>')
        index_f.write('</body></html>\n')
Exemple #4
0
def balance_classify(table, cats, num_folds, **init_kwds):
    """Build a PLS-derived balance and cross-validate it as a classifier.

    If ``cats`` is categorical it is assumed that the classes are binary.

    Parameters
    ----------
    table : pd.DataFrame
        Feature table; rows are matched positionally with ``cats``.
    cats : pd.Series
        Response values, one per row of ``table``.
    num_folds : int
        Number of shuffled K-fold splits.
    **init_kwds
        Forwarded unchanged to ``round_balance``.

    Returns
    -------
    num : pd.DataFrame
        ``PLS1`` weights of the features above the upper threshold.
    denom : pd.DataFrame
        ``PLS1`` weights of the features below the lower threshold.
    pls_balance : pd.Series
        Mean log of the numerator features minus mean log of the
        denominator features, for every row of ``table``.
    cv : pd.DataFrame
        One row per fold with columns ``Q2`` and ``AUROC``.
    """
    skf = KFold(n_splits=num_folds, shuffle=True)

    ctable = pd.DataFrame(clr(centralize(table)),
                          index=table.index,
                          columns=table.columns)

    cv = pd.DataFrame(columns=['Q2', 'AUROC'], index=np.arange(num_folds))
    for i, (train, test) in enumerate(skf.split(ctable.values, cats.values)):

        X_train = ctable.iloc[train]
        Y_train, Y_test = cats.iloc[train], cats.iloc[test]
        plsc = PLSRegression(n_components=1)
        # positional arguments: the ``Y`` keyword is deprecated in newer
        # scikit-learn releases
        plsc.fit(X_train, Y_train)
        pls_df = pd.DataFrame(plsc.x_weights_,
                              index=ctable.columns,
                              columns=['PLS1'])

        l, r = round_balance(pls_df, **init_kwds)
        denom = pls_df.loc[pls_df.PLS1 < l]
        num = pls_df.loc[pls_df.PLS1 > r]

        # make the prediction and evaluate the accuracy
        idx = table.index[test]
        pls_balance = (np.log(table.loc[idx, num.index] + 1).mean(axis=1) -
                       np.log(table.loc[idx, denom.index] + 1).mean(axis=1))

        # NOTE(review): the labels are inverted here and ``auc`` is then
        # called with (tpr, fpr); the two swaps look like they are meant
        # to cancel each other - confirm.
        group_fpr, group_tpr, thresholds = roc_curve(y_true=1 -
                                                     (Y_test == 1).astype(int),
                                                     y_score=pls_balance)

        auroc = auc(group_tpr, group_fpr)
        press = ((pls_balance - Y_test)**2).sum()
        tss = ((Y_test.mean() - Y_test)**2).sum()
        Q2 = 1 - (press / tss)

        cv.loc[i, 'Q2'] = Q2
        cv.loc[i, 'AUROC'] = auroc

    # build model on entire dataset; fit on the same CLR-transformed table
    # used in every fold so the returned model is the one that was scored
    plsc = PLSRegression(n_components=1)
    plsc.fit(ctable.values, cats.values)
    pls_df = pd.DataFrame(plsc.x_weights_,
                          index=ctable.columns,
                          columns=['PLS1'])
    l, r = round_balance(pls_df, **init_kwds)
    denom = pls_df.loc[pls_df.PLS1 < l]
    num = pls_df.loc[pls_df.PLS1 > r]
    pls_balance = (np.log(table.loc[:, num.index]).mean(axis=1) -
                   np.log(table.loc[:, denom.index]).mean(axis=1))

    return num, denom, pls_balance, cv
Exemple #5
0
def balance_regression(table, cats, num_folds, **init_kwds):
    """Build a PLS-derived balance and cross-validate it as a regressor.

    Parameters
    ----------
    table : pd.DataFrame
        Feature table; rows are matched positionally with ``cats``.
    cats : pd.Series
        Numeric response, one value per row of ``table``.  It is negated
        before any model is fitted.
    num_folds : int
        Number of shuffled K-fold splits.
    **init_kwds
        Forwarded unchanged to ``round_balance``.

    Returns
    -------
    num : pd.DataFrame
        ``PLS1`` weights of the features above the upper threshold.
    denom : pd.DataFrame
        ``PLS1`` weights of the features below the lower threshold.
    pls_balance : pd.Series
        Mean log of the numerator features minus mean log of the
        denominator features, for every row of ``table``.
    cv : pd.DataFrame
        One row per fold with the single column ``Q2``.
    """
    skf = KFold(n_splits=num_folds, shuffle=True)
    # NOTE(review): the response is sign-flipped and the original author
    # marked this line as unexplained - confirm why it is needed.
    cats = cats * -1
    ctable = pd.DataFrame(clr(centralize(table)),
                          index=table.index,
                          columns=table.columns)

    cv = pd.DataFrame(columns=['Q2'], index=np.arange(num_folds))
    for i, (train, test) in enumerate(skf.split(ctable.values, cats.values)):

        X_train = ctable.iloc[train]
        Y_train, Y_test = cats.iloc[train], cats.iloc[test]
        plsc = PLSRegression(n_components=1)
        # positional arguments: the ``Y`` keyword is deprecated in newer
        # scikit-learn releases
        plsc.fit(X_train, Y_train)
        pls_df = pd.DataFrame(plsc.x_weights_,
                              index=ctable.columns,
                              columns=['PLS1'])

        l, r = round_balance(pls_df, **init_kwds)
        denom = pls_df.loc[pls_df.PLS1 < l]
        num = pls_df.loc[pls_df.PLS1 > r]

        # calibrate balance -> response on the training rows
        idx = table.index[train]
        pls_balance = (np.log(table.loc[idx, num.index] + 1).mean(axis=1) -
                       np.log(table.loc[idx, denom.index] + 1).mean(axis=1))
        b_, int_, _, _, _ = linregress(pls_balance, Y_train)

        # predict the held-out rows with that linear fit
        idx = table.index[test]
        pls_balance = (np.log(table.loc[idx, num.index] + 1).mean(axis=1) -
                       np.log(table.loc[idx, denom.index] + 1).mean(axis=1))
        pred = pls_balance * b_ + int_

        press = ((pred - Y_test)**2).sum()
        tss = ((Y_test.mean() - Y_test)**2).sum()
        Q2 = 1 - (press / tss)

        cv.loc[i, 'Q2'] = Q2

    # build model on entire dataset; fit on the same CLR-transformed table
    # used in every fold so the returned model is the one that was scored
    plsc = PLSRegression(n_components=1)
    plsc.fit(ctable.values, cats.values)
    pls_df = pd.DataFrame(plsc.x_weights_,
                          index=ctable.columns,
                          columns=['PLS1'])
    l, r = round_balance(pls_df, **init_kwds)
    denom = pls_df.loc[pls_df.PLS1 < l]
    num = pls_df.loc[pls_df.PLS1 > r]
    pls_balance = (np.log(table.loc[:, num.index]).mean(axis=1) -
                   np.log(table.loc[:, denom.index]).mean(axis=1))

    return num, denom, pls_balance, cv
Exemple #6
0
 def test_center_log(self):
     """center_log agrees with clr, with and without centralizing."""
     # reference data carries a pseudocount of 1
     shifted = 1 + np.array([[10, 20, 1, 20, 5, 100, 844, 100],
                             [10, 20, 2, 19, 0, 100, 849, 200],
                             [10, 20, 3, 18, 5, 100, 844, 300],
                             [10, 20, 4, 17, 0, 100, 849, 400],
                             [10, 20, 5, 16, 4, 100, 845, 500],
                             [10, 20, 6, 15, 0, 100, 849, 600],
                             [10, 20, 7, 14, 3, 100, 846, 700],
                             [10, 20, 8, 13, 0, 100, 849, 800],
                             [10, 20, 9, 12, 7, 100, 842, 900]])

     plain = self.test2.center_log()
     assert_array_almost_equal(clr(shifted), plain.data)

     centered = self.test2.center_log(centralize=True)
     assert_array_almost_equal(clr(centralize(shifted)), centered.data)
Exemple #7
0
def regression_biplot(coefficients: pd.DataFrame) -> skbio.OrdinationResults:
    """Turn a coefficient matrix into an ordination via SVD.

    ``coefficients`` is passed through ``clr_inv``, ``centralize`` and
    ``clr`` and then decomposed.  The left singular vectors scaled by the
    singular values become the sample coordinates (indexed like the rows
    of ``coefficients``); the right singular vectors become the feature
    coordinates (indexed like its columns).  The singular values are
    stored as ``eigvals`` and normalized to sum to one for
    ``proportion_explained``.  Axes are named ``PC0``, ``PC1``, ...
    """
    centered = clr(centralize(clr_inv(coefficients)))
    left, singular_values, right_t = np.linalg.svd(centered)

    rank = len(singular_values)
    axis_names = ['PC%d' % k for k in range(rank)]

    eigvals = pd.Series(singular_values, index=axis_names)
    sample_scores = pd.DataFrame(
        left[:, :rank] @ np.diag(singular_values),
        columns=axis_names, index=coefficients.index)
    feature_loadings = pd.DataFrame(
        right_t.T[:, :rank],
        columns=axis_names, index=coefficients.columns)

    return OrdinationResults(
        'regression_biplot', 'Multinomial regression biplot', eigvals,
        samples=sample_scores, features=feature_loadings,
        proportion_explained=eigvals / eigvals.sum())
    def test_centralize(self):
        """centralize gives the expected values, rejects bad input and
        leaves its argument unchanged."""
        expected = np.array([[0.22474487, 0.22474487, 0.55051026],
                             [0.41523958, 0.41523958, 0.16952085]])

        centered = centralize(closure(self.data1))
        npt.assert_allclose(centered, expected)
        centered = centralize(closure(self.data5))
        npt.assert_allclose(centered, expected)

        self.assertRaises(ValueError, centralize, self.bad1)
        self.assertRaises(ValueError, centralize, self.bad2)

        # the input must still hold its original values afterwards
        centralize(self.data1)
        untouched = np.array([[2, 2, 6],
                              [4, 4, 2]])
        npt.assert_allclose(self.data1, untouched)
Exemple #9
0
    def test_center_log_ration(self):
        """center_log_ratio agrees with skbio's clr, with and without
        centralizing."""
        from skbio.stats.composition import clr, centralize

        # reference data carries a pseudocount of 1
        shifted = 1 + np.array([[10, 20, 1, 20, 5, 100, 844, 100],
                                [10, 20, 2, 19, 0, 100, 849, 200],
                                [10, 20, 3, 18, 5, 100, 844, 300],
                                [10, 20, 4, 17, 0, 100, 849, 400],
                                [10, 20, 5, 16, 4, 100, 845, 500],
                                [10, 20, 6, 15, 0, 100, 849, 600],
                                [10, 20, 7, 14, 3, 100, 846, 700],
                                [10, 20, 8, 13, 0, 100, 849, 800],
                                [10, 20, 9, 12, 7, 100, 842, 900]])

        plain = self.test2.center_log_ratio()
        assert_array_almost_equal(clr(shifted), plain.data)

        centered = self.test2.center_log_ratio(centralize=True)
        assert_array_almost_equal(clr(centralize(shifted)), centered.data)
    def test_centralize(self):
        """Check centralize's values, its input validation and that it
        does not modify its argument."""
        expected = np.array([[0.22474487, 0.22474487, 0.55051026],
                             [0.41523958, 0.41523958, 0.16952085]])

        first = centralize(closure(self.cdata1))
        npt.assert_allclose(first, expected)
        second = centralize(closure(self.cdata5))
        npt.assert_allclose(second, expected)

        with self.assertRaises(ValueError):
            centralize(self.bad1)
        with self.assertRaises(ValueError):
            centralize(self.bad2)

        # make sure that inplace modification is not occurring
        original_values = np.array([[2, 2, 6], [4, 4, 2]])
        centralize(self.cdata1)
        npt.assert_allclose(self.cdata1, original_values)
    def test_centralize(self):
        """centralize: expected output, ValueError on bad input, and no
        in-place modification of the argument."""
        target = np.array([[0.22474487, 0.22474487, 0.55051026],
                           [0.41523958, 0.41523958, 0.16952085]])

        npt.assert_allclose(centralize(closure(self.cdata1)), target)
        npt.assert_allclose(centralize(closure(self.cdata5)), target)

        self.assertRaises(ValueError, centralize, self.bad1)
        self.assertRaises(ValueError, centralize, self.bad2)

        # the call below must not alter self.cdata1
        centralize(self.cdata1)
        before = np.array([[2, 2, 6],
                           [4, 4, 2]])
        npt.assert_allclose(self.cdata1, before)
Exemple #12
0
def dendrogram_heatmap(output_dir: str,
                       table: pd.DataFrame,
                       tree: TreeNode,
                       metadata: MetadataCategory,
                       ndim=10,
                       method='clr',
                       color_map='viridis'):
    """Render a dendrogram heatmap and an HTML page that embeds it.

    Writes ``heatmap.svg`` and ``index.html`` into ``output_dir``.

    Parameters
    ----------
    output_dir : str
        Directory that receives the output files.
    table : pd.DataFrame
        Data to plot; transformed according to ``method`` first.
    tree : TreeNode
        Tree handed to ``heatmap``.  The first ``ndim`` node names of its
        level-order traversal are highlighted (tips are not filtered out
        in this function).
    metadata : MetadataCategory
        Converted with ``to_series()`` and passed to ``heatmap``.
    ndim : int
        Maximum number of nodes to highlight.
    method : str
        ``'clr'`` (centralize then CLR) or ``'log'`` (natural log).
    color_map : str
        Colormap name forwarded to ``heatmap`` as ``cmap``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'clr'`` nor ``'log'``.
    """
    nodes = [n.name for n in tree.levelorder()]
    nlen = min(ndim, len(nodes))
    highlights = pd.DataFrame([['#00FF00', '#FF0000']] * nlen,
                              index=nodes[:nlen])
    if method == 'clr':
        mat = pd.DataFrame(clr(centralize(table)),
                           index=table.index,
                           columns=table.columns)
    elif method == 'log':
        mat = pd.DataFrame(np.log(table),
                           index=table.index,
                           columns=table.columns)
    else:
        # Fail here, before anything is written, instead of hitting an
        # unbound ``mat`` further down.
        raise ValueError("Unknown method %r; expected 'clr' or 'log'."
                         % (method,))

    # TODO: There are a few hard-coded constants here
    # will need to have some adaptive defaults set in the future
    fig = heatmap(mat,
                  tree,
                  metadata.to_series(),
                  highlights,
                  cmap=color_map,
                  highlight_width=0.01,
                  figsize=(12, 8))
    fig.savefig(os.path.join(output_dir, 'heatmap.svg'))

    index_fp = os.path.join(output_dir, 'index.html')
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        index_f.write('<h1>Dendrogram heatmap</h1>\n')
        index_f.write('<img src="heatmap.svg" alt="heatmap">')
        index_f.write('</body></html>\n')
Exemple #13
0
def balance_taxonomy(output_dir: str, table: pd.DataFrame, tree: TreeNode,
                     taxonomy: pd.DataFrame,
                     balance_name: str,
                     pseudocount: float = 0.5,
                     taxa_level: int = 0,
                     n_features: int = 10,
                     threshold: float = None,
                     metadata: qiime2.MetadataColumn = None) -> None:
    """Write plots, tables and an HTML index describing one balance.

    Files written to ``output_dir``: ``barplots.svg``/``.pdf``,
    ``numerator.csv``, ``denominator.csv`` and ``index.html``.  When
    ``metadata`` is given, ``balance_metadata.svg``/``.pdf`` are added,
    and - unless a categorical column has more than two categories -
    ``proportion_plot.svg``/``.pdf`` as well.

    Parameters
    ----------
    output_dir : str
        Directory that receives the output files.
    table : pd.DataFrame
        Feature table; ``pseudocount`` is added and it is matched against
        the tips of ``tree`` before use.
    tree : TreeNode
        Tree containing a node named ``balance_name``.
    taxonomy : pd.DataFrame
        Must have a ``'Taxon'`` column of ``;``-separated lineage strings.
    balance_name : str
        Name of the tree node whose two children define the balance.
    pseudocount : float
        Value passed to ``add_pseudocount``.
    taxa_level : int
        Column of the split taxonomy used for labelling.
    n_features : int
        Number of top numerator / denominator features in the proportion
        plot.
    threshold : float, optional
        Cut point used to split a numeric metadata column into two
        groups; defaults to the column mean.
    metadata : qiime2.MetadataColumn, optional
        Sample metadata to plot the balance against.

    Raises
    ------
    ValueError
        If ``threshold`` is given with a categorical column, or if a
        categorical column holds only numerical values.
    NotImplementedError
        If ``metadata`` is neither numeric nor categorical.
    """
    if threshold is not None and isinstance(metadata,
                                            qiime2.CategoricalMetadataColumn):
        raise ValueError('Categorical metadata column detected. Only specify '
                         'a threshold when using a numerical metadata column.')

    # make sure that the table and tree match up
    table, tree = match_tips(add_pseudocount(table, pseudocount), tree)

    # parse out headers for taxonomy
    taxa_data = list(taxonomy['Taxon'].apply(lambda x: x.split(';')).values)
    taxa_df = pd.DataFrame(taxa_data, index=taxonomy.index)

    # fill in NAs
    def f(x):
        # every missing level takes the value of the last level present
        y = np.array(list(map(lambda k: k is not None, x)))
        i = max(0, np.where(y)[0][-1])
        x[np.logical_not(y)] = [x[i]] * np.sum(np.logical_not(y))
        return x
    taxa_df = taxa_df.apply(f, axis=1)

    num_clade = tree.find(balance_name).children[NUMERATOR]
    denom_clade = tree.find(balance_name).children[DENOMINATOR]

    # r and s count the tips under the numerator and denominator clades
    if num_clade.is_tip():
        num_features = pd.DataFrame(
            {num_clade.name: taxa_df.loc[num_clade.name]}
            ).T
        r = 1
    else:
        num_features = taxa_df.loc[num_clade.subset()]
        r = len(list(num_clade.tips()))

    if denom_clade.is_tip():
        denom_features = pd.DataFrame(
            {denom_clade.name: taxa_df.loc[denom_clade.name]}
            ).T
        s = 1
    else:
        denom_features = taxa_df.loc[denom_clade.subset()]
        s = len(list(denom_clade.tips()))

    # difference of mean logs, scaled by sqrt(r * s / (r + s))
    b = (np.log(table.loc[:, num_features.index]).mean(axis=1) -
         np.log(table.loc[:, denom_features.index]).mean(axis=1))

    b = b * np.sqrt(r * s / (r + s))
    balances = pd.DataFrame(b, index=table.index,
                            columns=[balance_name])

    # the actual colors for the numerator and denominator
    num_color = sns.color_palette("Paired")[0]
    denom_color = sns.color_palette("Paired")[1]

    fig, (ax_num, ax_denom) = plt.subplots(2)
    balance_barplots(tree, balance_name, taxa_level, taxa_df,
                     denom_color=denom_color, num_color=num_color,
                     axes=(ax_num, ax_denom))

    ax_num.set_title(
        r'$%s_{numerator} \; taxa \; (%d \; taxa)$' % (
            balance_name, len(num_features)))
    ax_denom.set_title(
        r'$%s_{denominator} \; taxa \; (%d \; taxa)$' % (
            balance_name, len(denom_features)))
    ax_denom.set_xlabel('Number of unique taxa')
    plt.tight_layout()
    fig.savefig(os.path.join(output_dir, 'barplots.svg'))
    fig.savefig(os.path.join(output_dir, 'barplots.pdf'))

    dcat = None
    multiple_cats = False
    if metadata is not None:
        fig2, ax = plt.subplots()
        c = metadata.to_series()
        data, c = match(balances, c)
        data[c.name] = c
        y = data[balance_name]

        # check if continuous
        if isinstance(metadata, qiime2.NumericMetadataColumn):
            ax.scatter(c.values, y)
            ax.set_xlabel(c.name)
            if threshold is None:
                threshold = c.mean()
            # two-level labelling of the samples around the threshold
            dcat = c.apply(
                lambda x: '%s < %f' % (c.name, threshold)
                if x < threshold
                else '%s > %f' % (c.name, threshold)
            )
            sample_palette = pd.Series(sns.color_palette("Set2", 2),
                                       index=dcat.value_counts().index)

        elif isinstance(metadata, qiime2.CategoricalMetadataColumn):

            sample_palette = pd.Series(
                sns.color_palette("Set2", len(c.value_counts())),
                index=c.value_counts().index)

            # a column that converts cleanly to numbers is rejected
            try:
                pd.to_numeric(metadata.to_series())
            except ValueError:
                pass
            else:
                raise ValueError('Categorical metadata column '
                                 f'{metadata.name!r} contains only numerical '
                                 'values. At least one value must be '
                                 'non-numerical.')

            balance_boxplot(balance_name, data, y=c.name, ax=ax,
                            palette=sample_palette)
            if len(c.value_counts()) > 2:
                warnings.warn(
                    'More than 2 categories detected in categorical metadata '
                    'column. Proportion plots will not be displayed',
                    stacklevel=2)
                multiple_cats = True
            else:
                dcat = c

        else:
            # Some other type of MetadataColumn
            raise NotImplementedError()

        ylabel = (r"$%s = \ln \frac{%s_{numerator}}"
                  "{%s_{denominator}}$") % (balance_name,
                                            balance_name,
                                            balance_name)
        ax.set_title(ylabel, rotation=0)
        ax.set_ylabel('log ratio')
        fig2.savefig(os.path.join(output_dir, 'balance_metadata.svg'))
        fig2.savefig(os.path.join(output_dir, 'balance_metadata.pdf'))

        if not multiple_cats:
            # Proportion plots
            # first sort by clr values and calculate average fold change
            ctable = pd.DataFrame(clr(centralize(table)),
                                  index=table.index, columns=table.columns)

            left_group = dcat.value_counts().index[0]
            right_group = dcat.value_counts().index[1]

            lidx, ridx = (dcat == left_group), (dcat == right_group)
            if b.loc[lidx].mean() > b.loc[ridx].mean():
                # double check ordering and switch if necessary
                # careful - the left group is also commonly associated with
                # the denominator.
                left_group = dcat.value_counts().index[1]
                right_group = dcat.value_counts().index[0]
                lidx, ridx = (dcat == left_group), (dcat == right_group)
            # we are not performing a statistical test here
            # we're just trying to figure out a way to sort the data.
            num_fold_change = ctable.loc[:, num_features.index].apply(
                lambda x: ttest_ind(x[ridx], x[lidx])[0])
            num_fold_change = num_fold_change.sort_values(
                ascending=False
            )

            denom_fold_change = ctable.loc[:, denom_features.index].apply(
                lambda x: ttest_ind(x[ridx], x[lidx])[0])
            denom_fold_change = denom_fold_change.sort_values(
                ascending=True
            )

            # separate name so the ``metadata`` argument is not shadowed
            group_metadata = pd.DataFrame({dcat.name: dcat})
            top_num_features = num_fold_change.index[:n_features]
            top_denom_features = denom_fold_change.index[:n_features]

            fig3, (ax_denom, ax_num) = plt.subplots(1, 2)
            proportion_plot(
                table, group_metadata,
                category=group_metadata.columns[0],
                left_group=left_group,
                right_group=right_group,
                feature_metadata=taxa_df,
                label_col=taxa_level,
                num_features=top_num_features,
                denom_features=top_denom_features,
                # Note that the syntax is funky and counter
                # intuitive. This will need to be properly
                # fixed here
                # https://github.com/biocore/gneiss/issues/244
                num_color=sample_palette.loc[right_group],
                denom_color=sample_palette.loc[left_group],
                axes=(ax_num, ax_denom))
            # The below is overriding the default colors in the
            # numerator / denominator this will also need to be fixed in
            # https://github.com/biocore/gneiss/issues/244
            max_ylim, min_ylim = ax_denom.get_ylim()
            num_h, denom_h = n_features, n_features

            space = (max_ylim - min_ylim) / (num_h + denom_h)
            ymid = (max_ylim - min_ylim) * num_h
            ymid = ymid / (num_h + denom_h) - 0.5 * space

            ax_denom.axhspan(min_ylim, ymid,
                             facecolor=num_color,
                             zorder=0)
            ax_denom.axhspan(ymid, max_ylim,
                             facecolor=denom_color,
                             zorder=0)

            ax_num.axhspan(min_ylim, ymid,
                           facecolor=num_color,
                           zorder=0)
            ax_num.axhspan(ymid, max_ylim,
                           facecolor=denom_color,
                           zorder=0)

            fig3.subplots_adjust(
                # the left side of the subplots of the figure
                left=0.3,
                # the right side of the subplots of the figure
                right=0.9,
                # the bottom of the subplots of the figure
                bottom=0.1,
                # the top of the subplots of the figure
                top=0.9,
                # the amount of width reserved for blank space
                # between subplots
                wspace=0,
                # the amount of height reserved for white space
                # between subplots
                hspace=0.2,
            )

            fig3.savefig(os.path.join(output_dir, 'proportion_plot.svg'))
            fig3.savefig(os.path.join(output_dir, 'proportion_plot.pdf'))

    index_fp = os.path.join(output_dir, 'index.html')
    with open(index_fp, 'w') as index_f:
        index_f.write('<html><body>\n')
        if metadata is not None:
            index_f.write('<h1>Balance vs %s </h1>\n' % c.name)
            index_f.write(('<img src="balance_metadata.svg" '
                           'alt="barplots">\n\n'
                           '<a href="balance_metadata.pdf">'
                           'Download as PDF</a><br>\n'))

        # The proportion plot only exists when metadata was supplied;
        # without this guard the page linked to files never written.
        if metadata is not None and not multiple_cats:
            index_f.write('<h1>Proportion Plot </h1>\n')
            index_f.write(('<img src="proportion_plot.svg" '
                           'alt="proportions">\n\n'
                           '<a href="proportion_plot.pdf">'
                           'Download as PDF</a><br>\n'))

        index_f.write(('<h1>Balance Taxonomy</h1>\n'
                       '<img src="barplots.svg" alt="barplots">\n\n'
                       '<a href="barplots.pdf">'
                       'Download as PDF</a><br>\n'
                       '<h3>Numerator taxa</h3>\n'
                       '<a href="numerator.csv">\n'
                       'Download as CSV</a><br>\n'
                       '<h3>Denominator taxa</h3>\n'
                       '<a href="denominator.csv">\n'
                       'Download as CSV</a><br>\n'))

        num_features.to_csv(os.path.join(output_dir, 'numerator.csv'),
                            header=True, index=True)
        denom_features.to_csv(os.path.join(output_dir, 'denominator.csv'),
                              header=True, index=True)
        index_f.write('</body></html>\n')
# translate the raw group labels through catdict
microbe_iv['group'] = microbe_iv['group'].map(catdict)
metabolite_iv['group'] = metabolite_iv['group'].map(catdict)

# highlight features with p-value <= 0.001
max_pval = 0.001

# everything above the cutoff is relabelled 'None' and left out of the count
microbe_iv.loc[microbe_iv.pval > max_pval, 'group'] = 'None'
print('Number of significant microbes: %d' %
      microbe_iv[microbe_iv['group'] != 'None'].shape[0])

metabolite_iv.loc[metabolite_iv.pval > max_pval, 'group'] = 'None'
print('Number of significant metabolites: %d' %
      metabolite_iv[metabolite_iv['group'] != 'None'].shape[0])

# Both tables go through zero replacement, centralize and CLR before the
# fit.  Arguments are positional because newer scikit-learn releases
# deprecate the ``Y`` keyword in favour of ``y``.
plssvd = PLSSVD(n_components=3)
plssvd.fit(clr(centralize(multiplicative_replacement(microbes))),
           clr(centralize(multiplicative_replacement(metabolites))))


def standardize(A):
    """Return ``A`` with every column shifted to zero mean and divided by
    its (population) standard deviation."""
    column_means = np.mean(A, axis=0)
    column_stds = np.std(A, axis=0)
    return (A - column_means) / column_stds


# Standardized weight matrices of the fitted model, one row per feature
# and one column per component.
pls_microbes = pd.DataFrame(
    standardize(plssvd.x_weights_),
    index=microbes.columns,
    columns=['PCA1', 'PCA2', 'PCA3'],
)
pls_metabolites = pd.DataFrame(
    standardize(plssvd.y_weights_),
    index=metabolites.columns,
    columns=['PCA1', 'PCA2', 'PCA3'],
)