def main(otu_table, mapping_data, categories, output_dir,
         samples_to_plot=None, legend=False, xaxis=True):
    """Creates stacked bar plots for an otu table

    INPUTS:
        otu_table -- an open OTU table. It is summarized at taxonomic level 2
                    with summarize_human_taxa.

        mapping_data -- a tab delimited string containing the mapping data
                    passed from the mapping file. It is parsed with
                    map_to_2D_dict and looked up as
                    map_dict[sample_id][category], so every plotted sample
                    needs a value for every key in `categories`.

        categories -- a dictionary keying a mapping category to a dict with
                    a 'Groups' list (the group names) and a 'Taxa Summary'
                    2D array whose columns line up with 'Groups'. At most
                    four categories fit in a plot.

        output_dir -- the location of the directory where output files should
                    be saved. If this directory does not exist, it will be
                    created.

        samples_to_plot -- a list of sample ids to plot. If no value is
                    passed, then all samples in the biom table are analyzed.
                    Ids that are not in the table are ignored.

        legend, xaxis -- currently unused; kept for interface compatibility.
                    NOTE(review): presumably meant to be forwarded to
                    plot_american_gut -- confirm its signature.

    OUTPUTS:
        A pdf of stacked taxonomy will be generated for each plotted sample
        and saved in the output directory. These will follow the file name
        format Figure_4_<SAMPLEID>.pdf

    RAISES:
        ValueError -- if more categories are passed than fit in the plot, if
                    the Michael Pollan reference sample is not in the table,
                    or if a sample's mapping value for a category is not one
                    of that category's 'Groups'.
    """
    import os

    # Sets constants
    LEVEL = 2
    FILEPREFIX = 'Figure_4_'
    MICHAEL_POLLAN = '000007108.1075657'
    NUM_TAXA = 9
    NUM_CATS_TO_PLOT = 7

    # Column 0 is 'You', column 1 is 'Average' and the last column is
    # 'Michael Pollan', so only the columns in between are free for the
    # categories. An extra category would silently be overwritten by the
    # Michael Pollan column (or index past the array), so reject it early.
    max_categories = NUM_CATS_TO_PLOT - 3
    if len(categories) > max_categories:
        raise ValueError('At most %d categories can be plotted, but %d were '
                         'passed.' % (max_categories, len(categories)))

    # Loads the mapping file
    map_dict = map_to_2D_dict(mapping_data)

    # Only the sample ids and the summary array are needed for plotting.
    (_, whole_sample_ids, whole_summary) = \
        summarize_human_taxa(otu_table, LEVEL)

    # Decides which samples get a figure; a set keeps lookups O(1).
    if samples_to_plot is None:
        sample_ids = set(whole_sample_ids)
    else:
        sample_ids = set(samples_to_plot)

    # Identifies Michael Pollan's pre-ABX sample (last column of every plot).
    mp_sample_pos = whole_sample_ids.index(MICHAEL_POLLAN)
    mp_sample_taxa = whole_summary[:, mp_sample_pos]

    # The across-sample average is identical for every figure.
    average_taxa = mean(whole_summary, 1)

    # Creates the output directory if needed, as promised above.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Generates a figure for each sample
    for idx, sample_id in enumerate(whole_sample_ids):
        if sample_id not in sample_ids:
            continue

        meta_data = map_dict[sample_id]

        # Intended column labels: 'You', 'Average', 'Similar Diet',
        # 'Similar BMI', 'Same Gender', 'Similar Age', 'Michael Pollan'.
        tax_array = zeros((NUM_TAXA, NUM_CATS_TO_PLOT))
        tax_array[:, 0] = whole_summary[:, idx]
        tax_array[:, 1] = average_taxa

        # NOTE(review): category columns follow the iteration order of
        # `categories`; for a plain dict under Python 2 that order is
        # arbitrary and may not match the labels above -- confirm.
        for cat_watch, cat in enumerate(categories, 2):
            # Pulls metadata for the sample and category
            mapping_key = meta_data[cat]
            # Pulls taxonomic summary and group descriptions for the category
            tax_summary = categories[cat]['Taxa Summary']
            group_descriptions = categories[cat]['Groups']
            try:
                mapping_col = group_descriptions.index(mapping_key)
            except ValueError:
                raise ValueError('The %s cannot be found in %s.'
                                 % (mapping_key, cat))
            tax_array[:, cat_watch] = tax_summary[:, mapping_col]

        tax_array[:, -1] = mp_sample_taxa

        # Plots the data
        filename = pjoin(output_dir, '%s%s.pdf' % (FILEPREFIX, sample_id))
        plot_american_gut(tax_array, filename)
    def test_summarize_human_taxa(self):
        """Checks summarize_human_taxa(otu_table, 2) on a small OTU table.

        A 14-observation x 6-sample table is summarized at taxonomic level 2.
        The test checks the returned sample ids, the (kingdom, phylum) labels
        in their expected order, and the shape of the summary array.
        """
        # Defines the known values
        known_ids = ['00010', '00100', '00200', '00111', '00112', '00211']

        known_taxa = [(u'k__Bacteria', u' p__Firmicutes'),
                      (u'k__Bacteria', u' p__Bacteroidetes'),
                      (u'k__Bacteria', u' p__Proteobacteria'),
                      (u'k__Bacteria', u' p__Actinobacteria'),
                      (u'k__Bacteria', u' p__Verrucomicrobia'),
                      (u'k__Bacteria', u' p__Tenericutes'),
                      (u'k__Bacteria', u' p__Cyanobacteria'),
                      (u'k__Bacteria', u' p__Fusobacteria'),
                      (u'k__Bacteria', u' p__Other')]

        table_known = array([[0.5000, 0.2500, 0.5500, 0.3000, 0.45000, 0.1000],
                             [0.3000, 0.5000, 0.3500, 0.6200, 0.45000, 0.7500],
                             [0.0000, 0.0000, 0.0000, 0.0000, 0.00000, 0.0000],
                             [0.0800, 0.0000, 0.0000, 0.0030, 0.00000, 0.0000],
                             [0.0500, 0.0100, 0.0020, 0.0010, 0.00000, 0.0000],
                             [0.0200, 0.0800, 0.0500, 0.0020, 0.05000, 0.0000],
                             [0.0010, 0.0200, 0.0000, 0.0000, 0.00010, 0.0000],
                             [0.0080, 0.0000, 0.0020, 0.0000, 0.00000, 0.0500],
                             [0.0210, 0.1100, 0.0460, 0.0240, 0.01990, 0.0000]])

        # Creates an otu table for testing whose samples match known_ids. Its
        # observations span the expected phyla plus TM7, Acidobacteria and
        # an unassigned phylum ('p__').
        sample_ids = ['00010', '00100', '00200', '00111', '00112', '00211']

        observation_ids = ['1001', '2001', '2002', '2003', '3001', '3003',
                           '4001', '5001', '6001', '7001', '8001', '9001',
                           '9002', '9003']

        observation_md = [{'taxonomy': (u'k__Bacteria', u' p__Bacteroidetes',
                                        u'c__Bacteroidia')},
                          {'taxonomy': (u'k__Bacteria', u' p__Firmicutes',
                                        u'c__Clostridia')},
                          {'taxonomy': (u'k__Bacteria', u' p__Firmicutes',
                                        u'c__Erysipelotrichi')},
                          {'taxonomy': (u'k__Bacteria', u' p__Firmicutes',
                                        u'c__Bacilli')},
                          {'taxonomy': (u'k__Bacteria', u' p__Proteobacteria',
                                        u'c__Alphaproteobacteria')},
                          {'taxonomy': (u'k__Bacteria', u' p__Proteobacteria',
                                        u'c__Gammaproteobacteria')},
                          {'taxonomy': (u'k__Bacteria', u' p__Tenericutes',
                                        u'c__Mollicutes')},
                          {'taxonomy': (u'k__Bacteria', u' p__Actinobacteria',
                                        u'c__Coriobacteriia')},
                          {'taxonomy': (u'k__Bacteria', u' p__Verrucomicrobia',
                                        u'c__Verrucomicrobiae')},
                          {'taxonomy': (u'k__Bacteria', u' p__Cyanobacteria',
                                        u'c__4C0d-2')},
                          {'taxonomy': (u'k__Bacteria', u' p__Fusobacteria',
                                        u'c__Fusobacteriia')},
                          {'taxonomy': (u'k__Bacteria', u' p__TM7',
                                        u'c__TM7-2')},
                          {'taxonomy': (u'k__Bacteria', u' p__Acidobacteria',
                                        u'c__Chloracidobacteria')},
                          {'taxonomy': (u'k__Bacteria', u' p__', u'c__')}]

        # One row per observation id, one column per sample id.
        data = array([[ 1691,  3004, 18606,  6914,  1314, 22843],
                      [ 2019,  1091,  8163,  1112,   738,  2362],
                      [   67,     4,  2835,   310,    85,   161],
                      [  731,   407, 18240,  1924,   492,   522],
                      [    8,     1,     0,    53,     8,   275],
                      [  105,   179,     0,   504,    79,  2771],
                      [  451,     0,     0,    33,     0,     0],
                      [  282,    60,   106,    11,     0,     0],
                      [  113,   481,  2658,    22,   146,     0],
                      [    6,   120,     0,     0,     0,     0],
                      [   45,     0,   106,     0,     0,  1523],
                      [   39,   341,  1761,   139,    18,     0],
                      [   21,   268,   153,     8,    15,     0],
                      [   59,    51,   531,   120,    25,     0]])

        otu_table = table_factory(data=data, sample_ids=sample_ids,
                                  observation_ids=observation_ids,
                                  observation_metadata=observation_md,
                                  constructor=SparseOTUTable)

        # Tests that the table is summarized correctly
        (test_taxa, test_ids, test_table) = summarize_human_taxa(otu_table, 2)

        # Checks that all the outputs are correct
        self.assertEqual(test_ids, known_ids)
        self.assertEqual(test_taxa, known_taxa)
        # The previous check compared test_table.all() with table_known.all().
        # That reduces each array to a single bool (both contain zeros), so it
        # could never fail. The check below pins one taxon row per known_taxa
        # entry and one column per sample.
        # TODO(review): add an element-wise comparison (e.g.
        # numpy.testing.assert_allclose) once table_known is verified. It does
        # not look consistent with `data`: the Proteobacteria row is all zeros
        # although observations 3001/3003 are nonzero in every sample.
        self.assertEqual(test_table.shape, table_known.shape)