Example #1
def filter_by_abundance(df, abundance_threshold=0.1, percent_axis=0, filter_axis=1):
    """
    Calculates taxa percent abundance per sample (taxa abundance / total abundance in that sample), then filters taxa/ASVs by the abundance threshold.

    Input:
        - df: pandas dataframe describing taxa/ASV abundance per sample.
        - abundance_threshold: % value threshold (0: 0% cutoff, 1: 100% cutoff)
        - percent_axis: axis used to calculate % abundance
        - filter_axis: axis used to filter (filter if at least one of the samples has % abundance >= threshold)

        If samples as columns, and taxa/ASV as rows, do percent_axis=0, filter_axis=1.
        If samples as rows, and taxa/ASV as columns, do percent_axis=1, filter_axis=0.

        Default options assume samples as columns and taxa/ASV as rows.
    """
    percent_df = calculate_percent_value(df, percent_axis)

    # keep taxa/ASVs if at least one of the samples has % abundance >= threshold
    to_keep = (percent_df >= abundance_threshold).any(filter_axis)

    # Raise an error if no entries remain after filtering
    if not to_keep.any():
        raise AXIOME3Error("Zero entries after filtering by abundance at {} threshold".format(abundance_threshold))

    # Note that the original row index must be preserved;
    # this will be buggy if the row index has been reindexed
    filtered = df.loc[to_keep]

    return filtered
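
A minimal usage sketch, assuming calculate_percent_value normalizes each sample to sum to 1 and AXIOME3Error is the package's own exception class (both live elsewhere in AXIOME3; the stand-ins below are illustrative only):

import pandas as pd

# Illustrative stand-ins for names defined elsewhere in AXIOME3
class AXIOME3Error(Exception):
    pass

def calculate_percent_value(df, axis=0):
    # Divide each value by its column (axis=0) or row (axis=1) total
    return df.div(df.sum(axis=axis), axis=1 - axis)

# Taxa as rows, samples as columns (the default orientation)
df = pd.DataFrame({"S1": [90, 10], "S2": [50, 50]}, index=["taxonA", "taxonB"])
print(filter_by_abundance(df, abundance_threshold=0.6))
# taxonA is kept (90% in S1); taxonB never reaches 60% and is dropped
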
Example #2
def filter_by_keyword(taxa, keyword=None):
    """
	Filter taxa by a user-specified keyword.

	Input:
		- taxa: original taxa name (pd.Series)
		- keyword: keyword string to search

	Returns:
		- boolean; True if match, False otherwise (pd.Series)
	"""
    # If keyword is not specified, match everything
    if keyword is None:
        return pd.Series(True, index=taxa.index)

    match = taxa.str.contains(str(keyword), case=False, regex=False)

    # Raise AXIOME3Error if there are no matches
    if not match.any():
        message = "Specified search term, {term}, is NOT found in any entries".format(
            term=keyword)

        raise AXIOME3Error(message)

    return match
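
A short usage sketch (matching is case-insensitive and literal, since case=False and regex=False are passed to str.contains):

import pandas as pd

taxa = pd.Series(["D_0__Bacteria;D_1__Proteobacteria",
                  "D_0__Bacteria;D_1__Firmicutes"])
print(filter_by_keyword(taxa, keyword="proteo").tolist())
# [True, False]
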
Example #3
def convert_qiime2_2_skbio(pcoa_artifact):
    """
    Convert QIIME2 PCoA artifact to skbio OrdinationResults object.

    ** Will throw errors if the artifact type is NOT PCoAResults **
    You may check Artifact type by checking the "type" property of the Artifact
    object after loading the artifact via 'Artifact.load(artifact)'
    """
    pcoa_artifact = Artifact.load(pcoa_artifact)

    # Check Artifact type
    if str(pcoa_artifact.type) != "PCoAResults":
        raise AXIOME3Error("Input QIIME2 Artifact is not of the type 'PCoAResults'!")

    pcoa = pcoa_artifact.view(ordination.OrdinationResults)

    # Rename PCoA coordinates index (so left join can be performed later)
    coords = pcoa.samples

    coords.index.names = ['SampleID']

    # Rename columns to have more meaningful names
    num_col = coords.shape[1]
    col_names = ['Axis ' + str(i) for i in range(1, num_col + 1)]
    coords.columns = col_names

    pcoa.samples = coords

    return pcoa
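
Usage might look like the following, assuming a PCoAResults artifact produced by QIIME2's diversity plugin (the file name is hypothetical):

pcoa = convert_qiime2_2_skbio("bray_curtis_pcoa_results.qza")
print(pcoa.samples.head())  # coordinates indexed by 'SampleID' with columns 'Axis 1', 'Axis 2', ...
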
Example #4
def find_sample_intersection(feature_table_df, abundance_df,
                             sample_metadata_df, environmental_metadata_df):
    """
	Find the samples shared by the feature table, abundance table, sample metadata, and environmental metadata.

	Inputs:
		- feature_table_df: feature table in pandas DataFrame (samples as row, taxa/ASV as columns)
		- abundance_df: feature table in pandas DataFrame (samples as row, taxa/ASV as columns; to overlay as taxa bubbles later)
		- sample_metadata_df: sample metadata in pandas DataFrame (samples as row, metadata as columns)
		- environmental_metadata_df: environmental metadata in pandas DataFrame (samples as row, metadata as columns)

		Assumes sampleID as index
	"""
    combined_df = pd.concat(
        [feature_table_df, abundance_df, sample_metadata_df, environmental_metadata_df],
        join="inner",
        axis=1)

    intersection_samples = combined_df.index

    if (len(intersection_samples) == 0):
        raise AXIOME3Error(
            "Feature table, sample metadata, and environmental metadata do NOT share any samples..."
        )

    intersection_feature_table_df = feature_table_df.loc[intersection_samples]
    intersection_abundance_df = abundance_df.loc[intersection_samples]
    intersection_sample_metadata_df = sample_metadata_df.loc[intersection_samples]
    intersection_environmental_metadata_df = environmental_metadata_df.loc[intersection_samples]

    # summary about samples not used
    feature_table_omitted_samples = ','.join([
        str(sample) for sample in feature_table_df.index
        if sample not in intersection_samples
    ])
    sample_metadata_omitted_samples = ','.join([
        str(sample) for sample in sample_metadata_df.index
        if sample not in intersection_samples
    ])
    environmental_metadata_omitted_samples = ','.join([
        str(sample) for sample in environmental_metadata_df.index
        if sample not in intersection_samples
    ])

    sample_summary = dedent("""\
		Omitted samples in feature table,{feature_table_omitted_samples}
		Omitted samples in sample metadata,{sample_metadata_omitted_samples}
		Omitted samples in environmental metadata,{environmental_metadata_omitted_samples}
	""".format(feature_table_omitted_samples=feature_table_omitted_samples,
            sample_metadata_omitted_samples=sample_metadata_omitted_samples,
            environmental_metadata_omitted_samples=
            environmental_metadata_omitted_samples))

    return (intersection_feature_table_df, intersection_abundance_df,
            intersection_sample_metadata_df,
            intersection_environmental_metadata_df, sample_summary)
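
A small sketch with toy inputs; S1 appears only in the feature/abundance tables and S4 only in the metadata, so the intersection is S2 and S3:

import pandas as pd
from textwrap import dedent  # needed by find_sample_intersection

ft = pd.DataFrame({"taxonA": [1, 2, 3]}, index=["S1", "S2", "S3"])
ab = pd.DataFrame({"taxonA": [0.1, 0.2, 0.3]}, index=["S1", "S2", "S3"])
meta = pd.DataFrame({"site": ["a", "b", "c"]}, index=["S2", "S3", "S4"])
env = pd.DataFrame({"pH": [7.0, 6.5, 8.0]}, index=["S2", "S3", "S4"])

ft_i, ab_i, meta_i, env_i, summary = find_sample_intersection(ft, ab, meta, env)
print(list(ft_i.index))  # ['S2', 'S3']
print(summary)           # CSV-style lines naming the omitted samples (S1 and S4)
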
Example #5
def filter_by_abundance(df, abundance_col, cutoff=0.2):
    """
	Filter the dataframe by abundance values in the specified column.
	"""
    if (abundance_col not in df.columns):
        raise AXIOME3Error(
            "Column {col} does not exist in the dataframe".format(
                col=abundance_col))

    filtered_df = df[df[abundance_col] >= cutoff]

    if (filtered_df.shape[0] == 0):
        raise AXIOME3Error(
            "No entries left with {cutoff} abundance threshold".format(
                cutoff=cutoff))

    return filtered_df
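
For example:

import pandas as pd

df = pd.DataFrame({"Taxon": ["a", "b"], "abundance": [0.5, 0.1]})
print(filter_by_abundance(df, "abundance", cutoff=0.2))  # only taxon 'a' survives
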
Example #6
def check_column_exists(metadata_df, target_primary, target_secondary=None):
    """
	Check if metadata has specified target columns
	"""
    # Make sure user specified target columns actually exist in the dataframe
    if (target_primary not in metadata_df.columns):
        msg = "Column '{column}' does NOT exist in the metadata!".format(
            column=target_primary)

        raise AXIOME3Error(msg)

    if (target_secondary is not None
            and target_secondary not in metadata_df.columns):
        msg = "Column '{column}' does NOT exist in the metadata!".format(
            column=target_secondary)

        raise AXIOME3Error(msg)
Example #7
def calculate_vector_magnitude_df(df: pd.DataFrame, col1: str, col2: str):
    """Given df with coordinate columns col1 and col2, calculate magnitude"""
    for col in [col1, col2]:
        if col not in df.columns:
            raise AXIOME3Error(f"{col} not present in the df!")

    df['magnitude'] = np.sqrt(df[col1]**2 + df[col2]**2)

    return df
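
A quick sketch:

import numpy as np
import pandas as pd

df = pd.DataFrame({"Axis 1": [3.0, 0.0], "Axis 2": [4.0, 1.0]})
print(calculate_vector_magnitude_df(df, "Axis 1", "Axis 2")["magnitude"].tolist())
# [5.0, 1.0]
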
Example #8
def collapse_taxa(feature_table_artifact,
                  taxonomy_artifact,
                  collapse_level="asv"):
    """
	Collapse feature table to user specified taxa level (ASV by default).

	Input:
		- QIIME2 artifact of type FeatureData[Taxonomy]

	Returns:
		- pd.DataFrame
			(taxa/ASV as rows, samples as columns, numeric index, appends 'Taxon' column)
	"""
    collapse_level = collapse_level.lower()

    if (collapse_level not in VALID_COLLAPSE_LEVELS):
        raise AXIOME3Error(
            "Specified collapse level, {collapse_level}, is NOT valid!".format(
                collapse_level=collapse_level))

    # handle ASV case
    if (collapse_level == "asv"):
        # By default, feature table has samples as rows, and ASV as columns
        feature_table_df = feature_table_artifact.view(pd.DataFrame)

        # Transpose feature table
        feature_table_df_T = feature_table_df.T

        # By default, taxonomy has ASV as rows, and metadata as columns
        taxonomy_df = taxonomy_artifact.view(pd.DataFrame)

        # Combine the two df (joins on index (ASV))
        combined_df = feature_table_df_T.join(taxonomy_df)

        # Drop "Confidence" column and use numeric index
        final_df = combined_df.drop(["Confidence"],
                                    axis="columns").reset_index(drop=True)

        return final_df

    table_artifact = collapse(table=feature_table_artifact,
                              taxonomy=taxonomy_artifact,
                              level=VALID_COLLAPSE_LEVELS[collapse_level])

    # By default, it has samples as rows, and taxa as columns
    collapsed_df = table_artifact.collapsed_table.view(pd.DataFrame)

    # Transpose
    collapsed_df_T = collapsed_df.T

    # Append "Taxon" column
    collapsed_df_T["Taxon"] = collapsed_df_T.index

    # Reset index
    final_df = collapsed_df_T.reset_index(drop=True)

    return final_df
Example #9
def check_artifact_type(artifact_path, artifact_type):
    q2_artifact = Artifact.load(artifact_path)

    # Raise AXIOME3Error if the artifact is not of the expected type
    if (str(q2_artifact.type) != ARTIFACT_TYPES[artifact_type]):
        msg = "Input QIIME2 Artifact is not of the type '{}'".format(ARTIFACT_TYPES[artifact_type])
        raise AXIOME3Error(msg)

    return q2_artifact
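
ARTIFACT_TYPES is a module-level mapping defined elsewhere in AXIOME3; a hypothetical version and usage might look like this (the artifact path is also hypothetical):

# Hypothetical mapping from short names to QIIME2 semantic types
ARTIFACT_TYPES = {
    "FeatureTable": "FeatureTable[Frequency]",
    "Taxonomy": "FeatureData[Taxonomy]",
}

table_artifact = check_artifact_type("feature_table.qza", "FeatureTable")
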
Example #10
def filter_by_wascore_threshold(normalized_wascores_df, wa_threshold):
	"""
	Filter weighted average DataFrame by normalized abundance
	"""
	if('abundance' not in normalized_wascores_df.columns):
		raise AXIOME3Error("normalized taxa count column does not exist")

	filtered_df = normalized_wascores_df[normalized_wascores_df['abundance'] > wa_threshold]

	return filtered_df
Example #11
def convert_col_dtype(df, col, dtype):
    """
	Convert given column type to specified dtype
	"""
    if (col not in df.columns):
        raise AXIOME3Error(
            "Column, '{}', does not exist in the dataframe".format(col))

    try:
        df[col] = df[col].astype(dtype)
    except ValueError:
        raise AXIOME3Error(
            "Column, '{col}', cannot be converted to data type, '{dtype}'".
            format(col=col, dtype=dtype))
    except TypeError:
        raise AXIOME3Error(
            "Data type, '{dtype}', is not supported by pandas".format(
                dtype=dtype))

    return df
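
For example:

import pandas as pd

df = pd.DataFrame({"Depth": ["1", "2", "3"]})
df = convert_col_dtype(df, "Depth", "float64")
print(df["Depth"].dtype)  # float64
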
Example #12
def round_percentage(df, abundance_col, num_decimal=3):
    if (abundance_col not in df.columns):
        raise AXIOME3Error(
            "Column {col} does not exist in the dataframe".format(
                col=abundance_col))

    # Display the value as a percentage instead of a decimal
    df[abundance_col] = df[abundance_col] * 100
    df[abundance_col] = df[abundance_col].round(num_decimal)

    return df
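
For example:

import pandas as pd

df = pd.DataFrame({"abundance": [0.1234, 0.5]})
print(round_percentage(df, "abundance", num_decimal=2)["abundance"].tolist())
# [12.34, 50.0] -- values are now percentages
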
Example #13
def alphabetical_sort_df(df, cols):
    """
	Alphabetically sort the dataframe by the given columns.

	Input:
		cols: list of columns to sort the dataframe by
	"""
    for col in cols:
        if (col not in df.columns):
            raise AXIOME3Error(
                "Column {col} does not exist in the dataframe".format(col=col))

    sorted_df = df.sort_values(by=cols)

    return sorted_df
Example #14
def load_env_metadata(env_metadata_path):
    # Use QIIME2 Metadata API to load metadata
    env_metadata_obj = Metadata.load(env_metadata_path)
    env_metadata_df = env_metadata_obj.to_dataframe()

    # Rename index
    env_metadata_df.index.names = ['SampleID']
    # environmental metadata columns MUST be numeric type
    # Drop all non-numeric columns
    numeric_env_df = env_metadata_df.select_dtypes(include='number')

    if (len(numeric_env_df.columns) == 0):
        raise AXIOME3Error(
            "Environmental metadata must contain at least one numeric column!")

    return numeric_env_df
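
Usage, assuming a QIIME2-formatted metadata file (the path is hypothetical):

env_df = load_env_metadata("env_metadata.tsv")
print(env_df.dtypes)  # only numeric columns survive
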
Example #15
def calculate_dissimilarity_matrix(feature_table, method="Bray-Curtis"):
	"""
	Calculates a dissimilarity matrix from the feature table.
	It uses R's vegan package (using rpy2 interface)

	Inputs:
		- feature_table: feature table (rpy2.robjects)
		- method: dissimilarity index (see 'vegdist' R documentation for supported methods)

	Outputs:
		- distance matrix (rpy2.robjects)
	"""
	vegan = importr('vegan')

	if (method not in VEGDIST_OPTIONS):
		raise AXIOME3Error("Specified dissmilarity method, {method} is not supported!".format(method=method))

	return vegan.vegdist(feature_table, VEGDIST_OPTIONS[method])
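
VEGDIST_OPTIONS is defined elsewhere in AXIOME3; a sketch with a hypothetical version of that mapping, assuming rpy2 (>= 3) and R's vegan package are installed:

import pandas as pd
from rpy2.robjects import pandas2ri

# Hypothetical mapping from display names to vegan's vegdist method strings
VEGDIST_OPTIONS = {"Bray-Curtis": "bray", "Jaccard": "jaccard"}

feature_df = pd.DataFrame({"taxonA": [3, 0], "taxonB": [1, 2]}, index=["S1", "S2"])
dist = calculate_dissimilarity_matrix(pandas2ri.py2rpy(feature_df), method="Bray-Curtis")
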
Example #16
def split_manifest(manifest_path, output_dir):
    manifest_df = pd.read_csv(manifest_path)

    if ('run_ID' not in manifest_df.columns):
        raise AXIOME3Error("'run_ID' column must exist in the manifes file!")

    # Get unique run IDs
    unique_run_IDs = manifest_df['run_ID'].unique()

    # Split table and write
    for run_ID in unique_run_IDs:
        output_filename = "manifest_" + str(run_ID) + ".csv"
        output_filepath = os.path.join(output_dir, output_filename)

        single_run_table = manifest_df[manifest_df['run_ID'] == run_ID]
        # Drop run_ID column
        single_run_table = single_run_table.drop(['run_ID'], axis=1)
        if (single_run_table.shape[0] != 0):
            single_run_table.to_csv(output_filepath, index=False)
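
For example:

import os
import pandas as pd

manifest = pd.DataFrame({
    "sample-id": ["S1", "S2", "S3"],
    "run_ID": [1, 1, 2],
})
manifest.to_csv("manifest.csv", index=False)
split_manifest("manifest.csv", ".")  # writes manifest_1.csv and manifest_2.csv
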
Example #17
def process_input_with_R(intersection_feature_table_df, intersection_abundance_df,
	intersection_sample_metadata_df, intersection_environmental_metadata_df,
	dissimilarity_index, R2_threshold, pval_threshold, wa_threshold,
	PC_axis_one, PC_axis_two, output_dir):

	feature_table_filepath = os.path.join(output_dir, "feature_table.csv")
	taxa_filepath = os.path.join(output_dir, "abundance_df.csv")
	metadata_path = os.path.join(output_dir, "metadata_df.csv")
	env_filepath = os.path.join(output_dir, "env_metadata_df.csv")
	
	intersection_feature_table_df.to_csv(feature_table_filepath, index_label="SampleID")
	intersection_abundance_df.to_csv(taxa_filepath, index_label="SampleID")
	intersection_sample_metadata_df.to_csv(metadata_path, index_label="SampleID")
	intersection_environmental_metadata_df.to_csv(env_filepath, index_label="SampleID")

	cmd = [
		'Rscript',
		'/pipeline/AXIOME3/scripts/qiime2_helper/pcoa_triplot.R',
		feature_table_filepath,
		taxa_filepath,
		metadata_path,
		env_filepath,
		dissimilarity_index,
		str(R2_threshold),
		str(pval_threshold),
		str(wa_threshold),
		str(PC_axis_one),
		str(PC_axis_two),
		output_dir
	]

	proc = subprocess.Popen(
		cmd,
		stdout=subprocess.PIPE,
		stderr=subprocess.PIPE
	)

	stdout, stderr = proc.communicate()

	return_code = proc.returncode

	if return_code != 0:
		raise AXIOME3Error("R error: " + stderr.decode('utf-8'))
Example #18
def group_by_taxa(taxa, groupby="phylum", collapse_level="asv"):
    """
	Extract taxa name from SILVA taxa format at the user specified level.

	Input:
	- taxa: list of SILVA taxa names (pd.Series)
	- groupby: taxa to group by
	- collapse_level: taxa level to collapse feature table at
	"""
    groupby = groupby.lower()
    collapse_level = collapse_level.lower()

    # Taxa to group by should be more general than collapsed level
    # e.g. groupby="asv", collapse_level="phylum" is not allowed
    if (VALID_COLLAPSE_LEVELS[groupby] >
            VALID_COLLAPSE_LEVELS[collapse_level]):
        raise AXIOME3Error(
            "taxa to groupby must be more general than the taxa to collapse!\nspecified groupby:{groupby}\nspecified collapse level:{collapse_level}"
            .format(groupby=groupby, collapse_level=collapse_level))

    def groupby_helper(taxa_name, groupby):
        """
		Expected input format: domain;phylum;class ... ;genus;species
		Not all entries may exist
		"""
        split = taxa_name.split(';')

        # If entry does not exist, return it as unclassified
        if (len(split) < VALID_COLLAPSE_LEVELS[groupby]):
            return "unclassified"

        selected = split[VALID_COLLAPSE_LEVELS[groupby] - 1]

        new_taxa = re.sub(r"\s*D_[0-6]__", "", selected)

        return new_taxa

    grouped_taxa = [groupby_helper(t, groupby) for t in taxa]

    return grouped_taxa
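
VALID_COLLAPSE_LEVELS maps taxonomic ranks to their depth in the SILVA string; a sketch with a hypothetical version of that mapping:

import re
import pandas as pd

# Hypothetical rank-to-depth mapping; the real one lives elsewhere in AXIOME3
VALID_COLLAPSE_LEVELS = {"domain": 1, "phylum": 2, "class": 3, "order": 4,
                         "family": 5, "genus": 6, "species": 7, "asv": 8}

taxa = pd.Series(["D_0__Bacteria;D_1__Proteobacteria;D_2__Gammaproteobacteria",
                  "D_0__Bacteria"])
print(group_by_taxa(taxa, groupby="phylum", collapse_level="class"))
# ['Proteobacteria', 'unclassified']
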
Example #19
def generate_pcoa_plot(pcoa,
                       metadata,
                       colouring_variable,
                       shape_variable=None,
                       primary_dtype="category",
                       secondary_dtype="category",
                       palette='Paired',
                       brewer_type='qual',
                       alpha=0.9,
                       stroke=0.6,
                       point_size=6,
                       x_axis_text_size=10,
                       y_axis_text_size=10,
                       legend_title_size=10,
                       legend_text_size=10,
                       PC_axis1=1,
                       PC_axis2=2):

    # raise AXIOME3Error if PC_axis1 == PC_axis2
    if (PC_axis1 == PC_axis2):
        raise AXIOME3Error("PC axis one and PC axis two cannot be equal!")

    # Load metadata file
    metadata_df = load_metadata(metadata)

    # Inner join metadata file with ordinations
    pcoa_coords = pcoa.samples
    pcoa_data_samples = pd.merge(pcoa_coords,
                                 right=metadata_df,
                                 left_index=True,
                                 right_index=True)

    # Make x and y axis labels
    proportions = pcoa.proportion_explained

    x_explained_idx = PC_axis1 - 1
    y_explained_idx = PC_axis2 - 1
    pc_1 = 'Axis ' + str(PC_axis1)
    pc_2 = 'Axis ' + str(PC_axis2)

    x_explained = str(round(proportions[x_explained_idx] * 100, 1))
    y_explained = str(round(proportions[y_explained_idx] * 100, 1))

    # Convert user specified columns to the requested dtypes
    # (big assumption: 'category' is an appropriate dtype for both by default)
    pcoa_data_samples = convert_col_dtype(pcoa_data_samples,
                                          colouring_variable, primary_dtype)

    if (shape_variable is not None):
        pcoa_data_samples = convert_col_dtype(pcoa_data_samples,
                                              shape_variable, secondary_dtype)

    # Pre-format target variables
    primary_target_fill = str(colouring_variable)

    if (shape_variable is not None):
        secondary_target_fill = str(shape_variable)

        ggplot_obj = ggplot(
            pcoa_data_samples,
            aes(x=pc_1,
                y=pc_2,
                fill=primary_target_fill,
                shape=secondary_target_fill))
    else:
        ggplot_obj = ggplot(pcoa_data_samples,
                            aes(x=pc_1, y=pc_2, fill=primary_target_fill))

    # Plot the data
    pcoa_plot = (
        ggplot_obj + geom_point(size=point_size, alpha=alpha, stroke=stroke) +
        theme_bw() +
        theme(panel_grid=element_blank(),
              line=element_line(colour='black'),
              panel_border=element_rect(colour='black'),
              legend_title=element_text(size=legend_title_size, face='bold'),
              legend_key=element_blank(),
              legend_text=element_text(size=legend_text_size),
              axis_title_x=element_text(size=x_axis_text_size),
              axis_title_y=element_text(size=y_axis_text_size),
              legend_key_height=5,
              text=element_text(family='Arial', colour='black')) +
        xlab(pc_1 + ' (' + x_explained + '%)') +
        ylab(pc_2 + ' (' + y_explained + '%)'))

    # Custom colours
    color_name = str(colouring_variable)
    pcoa_plot = add_fill_colours_from_users(pcoa_plot, color_name, palette,
                                            brewer_type)

    # Custom shapes
    if (shape_variable is not None):
        shape_len = len(pcoa_data_samples[shape_variable].unique())
        shape_name = str(shape_variable)

        pcoa_plot = add_discrete_shape(pcoa_plot, shape_len, shape_name)

    return pcoa_plot
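
Putting it together with the converter from Example #3 (the paths and the 'Site' metadata column are hypothetical):

pcoa = convert_qiime2_2_skbio("bray_curtis_pcoa_results.qza")
plot = generate_pcoa_plot(pcoa, "metadata.tsv", colouring_variable="Site")
plot.save("pcoa_plot.pdf")  # plotnine ggplot object
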
Example #20
def make_triplot(merged_df,
                 vector_arrow_df,
                 wascores_df,
                 proportion_explained,
                 fill_variable,
                 PC_axis_one=1,
                 PC_axis_two=2,
                 alpha=0.9,
                 stroke=0.6,
                 point_size=6,
                 x_axis_text_size=10,
                 y_axis_text_size=10,
                 legend_title_size=10,
                 legend_text_size=10,
                 fill_variable_dtype="category",
                 palette='Paired',
                 brewer_type='qual',
                 sample_text_size=6,
                 taxa_text_size=6,
                 vector_arrow_text_size=6):
    """

	"""
    # raise AXIOME3Error if PC_axis1 == PC_axis2
    if (PC_axis_one == PC_axis_two):
        raise AXIOME3Error("PC axis one and PC axis two cannot be equal!")

    # convert data type to user specified value
    merged_df = convert_col_dtype(merged_df, fill_variable,
                                  fill_variable_dtype)

    if (str(merged_df[fill_variable].dtype) == 'category'):
        # Remove unused categories
        merged_df[fill_variable] = merged_df[
            fill_variable].cat.remove_unused_categories()

    # PC axes to visualize
    pc1 = 'Axis ' + str(PC_axis_one)
    pc2 = 'Axis ' + str(PC_axis_two)

    # Plot the data
    base_plot = ggplot(
        merged_df, aes(x=pc1, y=pc2, label=merged_df.index,
                       fill=fill_variable))
    base_points = geom_point(size=point_size, alpha=alpha, stroke=stroke)

    PC_axis_one_variance = str(
        round(proportion_explained.loc[pc1, 'proportion_explained'], 1))
    PC_axis_two_variance = str(
        round(proportion_explained.loc[pc2, 'proportion_explained'], 1))
    x_label_placeholder = pc1 + " (" + PC_axis_one_variance + "%)"
    y_label_placeholder = pc2 + " (" + PC_axis_two_variance + "%)"
    x_lab = xlab(x_label_placeholder)
    y_lab = ylab(y_label_placeholder)

    my_themes = theme(
        panel_grid=element_blank(),  # No grid
        panel_border=element_rect(colour='black'),  # black outline
        legend_key=element_blank(),  # No legend background
        axis_title_x=element_text(size=x_axis_text_size),  # x axis label size
        axis_title_y=element_text(size=y_axis_text_size),  # y axis label size
        legend_title=element_text(size=legend_title_size,
                                  face='bold'),  # legend title size
        legend_text=element_text(size=legend_text_size),  # legend text size
        aspect_ratio=1,
        text=element_text(family='Arial',
                          colour='black')  # Arial font, black colour
    )

    plot = (base_plot + base_points + x_lab + y_lab + theme_bw() + my_themes)

    # Add point colour if categorical
    if (str(merged_df[fill_variable].dtype) == 'category'):
        fill_name = str(fill_variable)
        plot = add_fill_colours_from_users(plot, fill_name, palette,
                                           brewer_type)

    # Taxa points
    if (wascores_df.shape[0] > 0):
        taxa_points = geom_point(aes(x=pc1, y=pc2, size='abundance'),
                                 colour="black",
                                 fill='none',
                                 data=wascores_df,
                                 stroke=stroke,
                                 inherit_aes=False,
                                 show_legend=True)

        # Taxa annotation
        taxa_anno = geom_text(aes(x=pc1, y=pc2, label=wascores_df.index),
                              colour="black",
                              data=wascores_df,
                              inherit_aes=False,
                              size=taxa_text_size)

        if (point_size <= 5):
            taxa_max_size = point_size * 4
        elif (point_size <= 10):
            taxa_max_size = point_size * 3
        else:
            taxa_max_size = point_size * 1.5

        breakpoints, labels = get_axis_breakpoints(
            wascores_df["abundance"].min(), wascores_df["abundance"].max(), 4)
        plot = plot + taxa_points + scale_size_area(
            max_size=taxa_max_size,
            breaks=breakpoints,
            labels=labels,
        ) + taxa_anno

    # if vector arrows pass the threshold
    if (vector_arrow_df.shape[0] > 0):
        env_arrow = geom_segment(aes(x=0,
                                     xend=pc1,
                                     y=0,
                                     yend=pc2,
                                     colour=vector_arrow_df.index),
                                 data=vector_arrow_df,
                                 arrow=arrow(length=0.1),
                                 inherit_aes=False,
                                 show_legend=False)

        env_anno = geom_text(aes(x=pc1,
                                 y=pc2,
                                 label=vector_arrow_df.index,
                                 colour=vector_arrow_df.index),
                             size=vector_arrow_text_size,
                             data=vector_arrow_df,
                             inherit_aes=False,
                             show_legend=False)

        plot = plot + env_arrow + env_anno

    return plot
Example #21
def collapse_taxa(feature_table_artifact,
                  taxonomy_artifact,
                  sampling_depth=0,
                  collapse_level="asv"):
    """
	Collapse feature table to user specified taxa level (ASV by default).

	Input:
		- QIIME2 artifact of type FeatureData[Taxonomy]

	Returns:
		- pd.DataFrame
			(taxa/ASV as rows, samples as columns, numeric index, appends 'Taxon' column)
	"""
    collapse_level = collapse_level.lower()

    if (collapse_level not in VALID_COLLAPSE_LEVELS):
        raise AXIOME3Error(
            "Specified collapse level, {collapse_level}, is NOT valid!".format(
                collapse_level=collapse_level))

    # Rarefy the table to user specified sampling depth
    if (sampling_depth < 0):
        raise AXIOME3Error("Sampling depth cannot be a negative number!")
    # don't rarefy if sampling depth equals 0
    if (sampling_depth > 0):
        try:
            rarefied = rarefy(feature_table_artifact,
                              sampling_depth=sampling_depth)
        except ValueError:
            raise AXIOME3Error(
                "No samples or features left after rarefying at {}".format(
                    sampling_depth))
        feature_table_artifact = rarefied.rarefied_table

    # handle ASV case
    if (collapse_level == "asv"):
        # By default, feature table has samples as rows, and ASV as columns
        feature_table_df = feature_table_artifact.view(pd.DataFrame)

        # Transpose feature table
        feature_table_df_T = feature_table_df.T

        # By default, taxonomy has ASV as rows, and metadata as columns
        taxonomy_df = taxonomy_artifact.view(pd.DataFrame)

        # Combine the two df (joins on index (ASV))
        combined_df = feature_table_df_T.join(taxonomy_df, how='inner')

        # Drop "Confidence" column and use numeric index
        final_df = combined_df.drop(["Confidence"],
                                    axis="columns").reset_index(drop=True)

        if (final_df.shape[0] == 0 or final_df.shape[1] == 0):
            raise AXIOME3Error(
                "No data to process. Please check if 1. input feature table is empty, 2. input taxonomy is empty, 3. input feature table and taxonomy share common ASVs."
            )

        return final_df

    try:
        table_artifact = collapse(table=feature_table_artifact,
                                  taxonomy=taxonomy_artifact,
                                  level=VALID_COLLAPSE_LEVELS[collapse_level])
    except ValueError:
        raise AXIOME3Error(
            "No data to process. Please check if 1. input feature table is empty, 2. input taxonomy is empty, 3. input feature table and taxonomy share common features."
        )

    # By default, it has samples as rows, and taxa as columns
    collapsed_df = table_artifact.collapsed_table.view(pd.DataFrame)

    # Transpose
    collapsed_df_T = collapsed_df.T

    # Append "Taxon" column
    collapsed_df_T["Taxon"] = collapsed_df_T.index

    # Reset index
    final_df = collapsed_df_T.reset_index(drop=True)

    return final_df
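
Usage, assuming rarefy and collapse are the actions from QIIME2's feature_table and taxa plugins (artifact paths are hypothetical):

from qiime2 import Artifact

table = Artifact.load("feature_table.qza")  # FeatureTable[Frequency]
taxonomy = Artifact.load("taxonomy.qza")    # FeatureData[Taxonomy]

genus_df = collapse_taxa(table, taxonomy, sampling_depth=1000, collapse_level="genus")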