예제 #1
0
def test_biom_lib():
    """Smoke-test that the biom-format library is importable and usable.

    The biom-format library has dependencies that are difficult to get
    right, so this builds a small table simply to make sure the module is
    available. The two id axes of the table are printed; nothing is
    returned.

    This example is taken from:
    http://biom-format.org/documentation/table_objects.html
    """
    import biom  # imported only to confirm the package itself is importable
    from biom.table import Table as BiomTable

    import numpy as np

    # 10 observations (rows) x 4 samples (columns).
    data = np.arange(40).reshape(10, 4)
    sample_ids = ['S%d' % i for i in range(4)]
    observ_ids = ['O%d' % i for i in range(10)]

    # One metadata dict per sample, in the same order as sample_ids.
    sample_metadata = [
        {'environment': 'A'},
        {'environment': 'B'},
        {'environment': 'A'},
        {'environment': 'B'},
    ]
    # One metadata dict per observation, in the same order as observ_ids.
    observ_metadata = [
        {'taxonomy': ['Bacteria', 'Firmicutes']},
        {'taxonomy': ['Bacteria', 'Firmicutes']},
        {'taxonomy': ['Bacteria', 'Proteobacteria']},
        {'taxonomy': ['Bacteria', 'Proteobacteria']},
        {'taxonomy': ['Bacteria', 'Proteobacteria']},
        {'taxonomy': ['Bacteria', 'Bacteroidetes']},
        {'taxonomy': ['Bacteria', 'Bacteroidetes']},
        {'taxonomy': ['Bacteria', 'Firmicutes']},
        {'taxonomy': ['Bacteria', 'Firmicutes']},
        {'taxonomy': ['Bacteria', 'Firmicutes']},
    ]
    table = BiomTable(data,
                      observ_ids,
                      sample_ids,
                      observ_metadata,
                      sample_metadata,
                      table_id='Example Table')

    # The call form of print works on both Python 2 (single argument) and
    # Python 3; the bare `print x` statement is a SyntaxError on Python 3.
    print(table.ids(axis='observation'))
    print(table.ids(axis='sample'))
예제 #2
0
def load_klumpp_biom_table(data_dir, file_owner):
    """Download the Klumpp OTU and taxa files and build a biom table.

    Both files are fetched with ``box_downloads.download_from_box_no_auth``
    into the paths derived from ``data_dir``, then parsed by the loader
    helpers of this module. The sample and observation ids of the resulting
    table are printed.

    Parameters
    ----------
    data_dir
        Passed to ``otus_filename`` / ``taxa_filename`` to obtain the local
        file paths.
    file_owner
        Forwarded as ``file_owner`` to the download helper.

    Returns
    -------
    BiomTable
        Table with ``table_id='Klumpp'``.

    Raises
    ------
    ValueError
        If the last space-separated token of an OTU id is not an integer.
    """
    otus_fn = otus_filename(data_dir)
    taxa_fn = taxa_filename(data_dir)

    # Fetch the OTU file first, then the taxa file.
    for url, destination in ((OTUS_URL, otus_fn), (TAXA_URL, taxa_fn)):
        box_downloads.download_from_box_no_auth(url,
                                                destination,
                                                file_owner=file_owner)

    sample_ids, otu_ids = load_ids(otus_fn)

    # Observation metadata (taxonomy) and sample metadata (control or ic).
    otu_metadata = construct_otu_metadata(taxa_fn, otu_ids)
    sample_metadata = construct_sample_metadata(otus_fn)

    counts = load_data_matrix(otus_fn,
                              sample_ids,
                              otu_ids,
                              num_skip_columns=2)

    # Keep only the trailing token of each OTU id; it is kept as a string,
    # int() is called purely to validate it.
    numeric_ids = []
    for label in otu_ids:
        suffix = label.split(' ')[-1]
        int(suffix)
        numeric_ids.append(suffix)

    biom_table = BiomTable(counts,
                           numeric_ids,
                           sample_ids,
                           otu_metadata,
                           sample_metadata,
                           table_id='Klumpp')

    sample_axis_ids = biom_table.ids(axis='sample')
    print('sample ids: ' + str(sample_axis_ids))
    observation_axis_ids = biom_table.ids(axis='observation')
    print('otu ids: ' + str(observation_axis_ids))
    return biom_table
예제 #3
0
def make_modules_on_correlations(correlation_table: pd.DataFrame, feature_table: Table, min_r: float=.35) -> \
                                     (Table, nx.Graph, pd.Series):
    """Group features into modules and build the matching network.

    Modules come from ``ma.make_modules_naive``; ``min_r`` is forwarded both
    to that call and to ``filter_correls``.

    Returns a 3-tuple of:
      * the feature table collapsed by module (``ma.collapse_modules``),
      * the network built by ``correls_to_net`` from the filtered
        correlations, and
      * a Series mapping every observation id of ``feature_table`` to its
        module, with ``None`` for observations that belong to no module.
    """
    modules = ma.make_modules_naive(correlation_table, min_r=min_r)

    # Invert module -> members into member -> module.
    membership = {}
    for module, asvs in modules.items():
        for asv in asvs:
            membership[asv] = module
    # Observations that were not assigned to any module map to None.
    for asv in feature_table.ids(axis='observation'):
        membership.setdefault(asv, None)
    module_membership = pd.Series(membership)

    coll_table = ma.collapse_modules(feature_table, modules)

    base_metadata = get_metadata_from_table(feature_table)
    metadata = ma.add_modules_to_metadata(modules, base_metadata)

    kept_correlations = filter_correls(correlation_table, conet=True, min_r=min_r)
    net = correls_to_net(kept_correlations, metadata=metadata)

    return coll_table, net, module_membership
예제 #4
0
class FunctionTests(TestCase):
    """Tests of the rarefaction helpers (RarefactionMaker, get_rare_data)."""

    def setUp(self):
        """Build a small OTU table and write it to temporary .biom files."""
        self.tmp_dir = get_qiime_temp_dir()

        # 4 observations (rows) x 3 samples (columns); the third sample
        # ('Z') contains no counts at all.
        self.otu_table_data = np.array([[2, 1, 0],
                                        [0, 5, 0],
                                        [0, 3, 0],
                                        [1, 2, 0]])
        self.sample_names = list('YXZ')
        self.taxon_names = list('bacd')
        self.otu_metadata = [{'domain': 'Archaea'},
                             {'domain': 'Bacteria'},
                             {'domain': 'Bacteria'},
                             {'domain': 'Bacteria'}]

        # Same data twice: once without and once with observation metadata.
        self.otu_table = Table(self.otu_table_data,
                               self.taxon_names,
                               self.sample_names)

        self.otu_table_meta = Table(self.otu_table_data,
                                    self.taxon_names, self.sample_names,
                                    observation_metadata=self.otu_metadata)

        # mkstemp returns an open descriptor; close it straight away because
        # only the path is needed.
        fd, self.otu_table_fp = mkstemp(dir=self.tmp_dir,
                                        prefix='test_rarefaction',
                                        suffix='.biom')
        close(fd)
        fd, self.otu_table_meta_fp = mkstemp(dir=self.tmp_dir,
                                             prefix='test_rarefaction',
                                             suffix='.biom')
        close(fd)

        self.rare_dir = mkdtemp(dir=self.tmp_dir,
                                prefix='test_rarefaction_dir', suffix='')

        write_biom_table(self.otu_table, self.otu_table_fp)
        write_biom_table(self.otu_table_meta, self.otu_table_meta_fp)

        self._paths_to_clean_up = [self.otu_table_fp, self.otu_table_meta_fp]
        self._dirs_to_clean_up = [self.rare_dir]

    def tearDown(self):
        """Remove the temporary files and directories created by setUp."""
        # Explicit loop instead of map(): on Python 3 map() is lazy, so
        # map(remove, ...) would never delete anything.
        for path in self._paths_to_clean_up:
            remove(path)
        for d in self._dirs_to_clean_up:
            if os.path.exists(d):
                rmtree(d)

    def test_rarefy_to_list(self):
        """rarefy_to_list should rarefy correctly, same names"""
        maker = RarefactionMaker(self.otu_table_fp, 0, 1, 1, 1)
        res = maker.rarefy_to_list(include_full=True)

        # With include_full=True the last entry must equal the input table.
        self.assertItemsEqual(res[-1][2].ids(), self.otu_table.ids())
        self.assertItemsEqual(
            res[-1][2].ids(axis='observation'),
            self.otu_table.ids(axis='observation'))
        self.assertEqual(res[-1][2], self.otu_table)

        # Every sample kept in res[1] must sum to exactly one count.
        sample_value_sum = [val.sum()
                            for val in res[1][2].iter_data(axis='sample')]
        npt.assert_almost_equal(sample_value_sum, [1.0, 1.0])

    def test_rarefy_to_files(self):
        """rarefy_to_files should write valid files"""
        maker = RarefactionMaker(self.otu_table_fp, 1, 2, 1, 1)
        maker.rarefy_to_files(
            self.rare_dir,
            include_full=True,
            include_lineages=False)

        fname = os.path.join(self.rare_dir, "rarefaction_1_0.biom")
        otu_table = load_table(fname)

        # third sample had 0 seqs, so it's gone
        self.assertItemsEqual(
            otu_table.ids(),
            self.otu_table.ids()[:2])

    def test_rarefy_to_files2(self):
        """rarefy_to_files should write valid files with some metadata on otus
        """
        maker = RarefactionMaker(self.otu_table_meta_fp, 1, 2, 1, 1)
        maker.rarefy_to_files(
            self.rare_dir,
            include_full=True,
            include_lineages=False)

        fname = os.path.join(self.rare_dir, "rarefaction_1_0.biom")
        otu_table = load_table(fname)

        # third sample had 0 seqs, so it's gone
        self.assertItemsEqual(
            otu_table.ids(),
            self.otu_table.ids()[:2])

    def test_get_empty_rare(self):
        """get_rare_data should be empty when depth > # seqs in any sample"""
        self.assertRaises(TableException, get_rare_data, self.otu_table,
                          50, include_small_samples=False)

    def test_get_overfull_rare(self):
        """get_rare_data should be identical to given in this case

        here, rare depth > any sample, and include_small... = True"""
        rare_otu_table = get_rare_data(self.otu_table,
                                       50, include_small_samples=True)
        self.assertEqual(len(rare_otu_table.ids()), 3)
        # 4 observations times 3 samples = size 12 before
        self.assertEqual(len(rare_otu_table.ids(axis='observation')), 4)
        for sam in self.otu_table.ids():
            for otu in self.otu_table.ids(axis='observation'):
                self.assertEqual(rare_otu_table.get_value_by_ids(otu, sam),
                                 self.otu_table.get_value_by_ids(otu, sam))

    def test_get_11depth_rare(self):
        """get_rare_data should get only sample X"""
        rare_otu_table = get_rare_data(self.otu_table,
                                       11, include_small_samples=False)
        # 'X' is the only sample whose counts add up to 11 (1 + 5 + 3 + 2).
        self.assertEqual(rare_otu_table.ids(), ('X',))

        # a very complicated way to test things
        rare_values = [
            val[0]
            for (val, _otu_id, _meta) in rare_otu_table.iter(axis='observation')
        ]
        self.assertEqual(rare_values, [1.0, 5.0, 3.0, 2.0])
예제 #5
0
class FunctionTests(TestCase):
    """Tests of the rarefaction helpers (RarefactionMaker, get_rare_data)."""

    def setUp(self):
        """Build a small OTU table and write it to temporary .biom files."""
        self.tmp_dir = get_qiime_temp_dir()

        # 4 observations (rows) x 3 samples (columns); the third sample
        # ('Z') contains no counts at all.
        self.otu_table_data = np.array([[2, 1, 0], [0, 5, 0], [0, 3, 0],
                                        [1, 2, 0]])
        self.sample_names = list('YXZ')
        self.taxon_names = list('bacd')
        self.otu_metadata = [{
            'domain': 'Archaea'
        }, {
            'domain': 'Bacteria'
        }, {
            'domain': 'Bacteria'
        }, {
            'domain': 'Bacteria'
        }]

        # Same data twice: once without and once with observation metadata.
        self.otu_table = Table(self.otu_table_data, self.taxon_names,
                               self.sample_names)

        self.otu_table_meta = Table(self.otu_table_data,
                                    self.taxon_names,
                                    self.sample_names,
                                    observation_metadata=self.otu_metadata)

        # mkstemp returns an open descriptor; close it straight away because
        # only the path is needed.
        fd, self.otu_table_fp = mkstemp(dir=self.tmp_dir,
                                        prefix='test_rarefaction',
                                        suffix='.biom')
        close(fd)
        fd, self.otu_table_meta_fp = mkstemp(dir=self.tmp_dir,
                                             prefix='test_rarefaction',
                                             suffix='.biom')
        close(fd)

        self.rare_dir = mkdtemp(dir=self.tmp_dir,
                                prefix='test_rarefaction_dir',
                                suffix='')

        write_biom_table(self.otu_table, self.otu_table_fp)
        write_biom_table(self.otu_table_meta, self.otu_table_meta_fp)

        self._paths_to_clean_up = [self.otu_table_fp, self.otu_table_meta_fp]
        self._dirs_to_clean_up = [self.rare_dir]

    def tearDown(self):
        """Remove the temporary files and directories created by setUp."""
        # Explicit loop instead of map(): on Python 3 map() is lazy, so
        # map(remove, ...) would never delete anything.
        for path in self._paths_to_clean_up:
            remove(path)
        for d in self._dirs_to_clean_up:
            if os.path.exists(d):
                rmtree(d)

    def test_rarefy_to_list(self):
        """rarefy_to_list should rarefy correctly, same names"""
        maker = RarefactionMaker(self.otu_table_fp, 0, 1, 1, 1)
        res = maker.rarefy_to_list(include_full=True)

        # With include_full=True the last entry must equal the input table.
        self.assertItemsEqual(res[-1][2].ids(), self.otu_table.ids())
        self.assertItemsEqual(res[-1][2].ids(axis='observation'),
                              self.otu_table.ids(axis='observation'))
        self.assertEqual(res[-1][2], self.otu_table)

        # Every sample kept in res[1] must sum to exactly one count.
        sample_value_sum = [
            val.sum() for val in res[1][2].iter_data(axis='sample')
        ]
        npt.assert_almost_equal(sample_value_sum, [1.0, 1.0])

    def test_rarefy_to_files(self):
        """rarefy_to_files should write valid files"""
        maker = RarefactionMaker(self.otu_table_fp, 1, 2, 1, 1)
        maker.rarefy_to_files(self.rare_dir,
                              include_full=True,
                              include_lineages=False)

        fname = os.path.join(self.rare_dir, "rarefaction_1_0.biom")
        otu_table = load_table(fname)

        # third sample had 0 seqs, so it's gone
        self.assertItemsEqual(otu_table.ids(), self.otu_table.ids()[:2])

    def test_rarefy_to_files2(self):
        """rarefy_to_files should write valid files with some metadata on otus
        """
        maker = RarefactionMaker(self.otu_table_meta_fp, 1, 2, 1, 1)
        maker.rarefy_to_files(self.rare_dir,
                              include_full=True,
                              include_lineages=False)

        fname = os.path.join(self.rare_dir, "rarefaction_1_0.biom")
        otu_table = load_table(fname)

        # third sample had 0 seqs, so it's gone
        self.assertItemsEqual(otu_table.ids(), self.otu_table.ids()[:2])

    def test_get_empty_rare(self):
        """get_rare_data should be empty when depth > # seqs in any sample"""
        self.assertRaises(TableException,
                          get_rare_data,
                          self.otu_table,
                          50,
                          include_small_samples=False)

    def test_get_overfull_rare(self):
        """get_rare_data should be identical to given in this case

        here, rare depth > any sample, and include_small... = True"""
        rare_otu_table = get_rare_data(self.otu_table,
                                       50,
                                       include_small_samples=True)
        self.assertEqual(len(rare_otu_table.ids()), 3)
        # 4 observations times 3 samples = size 12 before
        self.assertEqual(len(rare_otu_table.ids(axis='observation')), 4)
        for sam in self.otu_table.ids():
            for otu in self.otu_table.ids(axis='observation'):
                self.assertEqual(rare_otu_table.get_value_by_ids(otu, sam),
                                 self.otu_table.get_value_by_ids(otu, sam))

    def test_get_11depth_rare(self):
        """get_rare_data should get only sample X"""
        rare_otu_table = get_rare_data(self.otu_table,
                                       11,
                                       include_small_samples=False)
        # 'X' is the only sample whose counts add up to 11 (1 + 5 + 3 + 2).
        self.assertEqual(rare_otu_table.ids(), ('X', ))

        # a very complicated way to test things
        rare_values = [
            val[0]
            for (val, _otu_id, _meta) in rare_otu_table.iter(axis='observation')
        ]
        self.assertEqual(rare_values, [1.0, 5.0, 3.0, 2.0])
예제 #6
0
class TopLevelTests(TestCase):
    """Tests of top-level functions"""
    def setUp(self):
        """define some top-level data"""

        # 3 observations (rows) x 6 samples (columns).
        self.otu_table_values = array([[0, 0, 9, 5, 3, 1], [1, 5, 4, 0, 3, 2],
                                       [2, 3, 1, 1, 2, 5]])
        sample_ids = [
            'Sample1', 'Sample2', 'Sample3', 'Sample4', 'Sample5', 'Sample6'
        ]

        # Table whose taxonomy holds a single level per OTU.
        self.otu_table = Table(
            self.otu_table_values, ['OTU1', 'OTU2', 'OTU3'],
            list(sample_ids),
            [{
                "taxonomy": ['Bacteria']
            }, {
                "taxonomy": ['Archaea']
            }, {
                "taxonomy": ['Streptococcus']
            }], [None, None, None, None, None, None])
        # Same counts, but with a four-level lineage per OTU.
        self.otu_table_f = Table(
            self.otu_table_values, ['OTU1', 'OTU2', 'OTU3'],
            list(sample_ids),
            [{
                "taxonomy": ['1A', '1B', '1C', 'Bacteria']
            }, {
                "taxonomy": ['2A', '2B', '2C', 'Archaea']
            }, {
                "taxonomy": ['3A', '3B', '3C', 'Streptococcus']
            }], [None, None, None, None, None, None])

        self.full_lineages = [['1A', '1B', '1C', 'Bacteria'],
                              ['2A', '2B', '2C', 'Archaea'],
                              ['3A', '3B', '3C', 'Streptococcus']]
        # (rows, header, comments) triple describing the samples.
        self.metadata = [[['Sample1', 'NA', 'A'], ['Sample2', 'NA', 'B'],
                          ['Sample3', 'NA', 'A'], ['Sample4', 'NA', 'B'],
                          ['Sample5', 'NA', 'A'], ['Sample6', 'NA', 'B']],
                         ['SampleID', 'CAT1', 'CAT2'], []]
        self.tree_text = ["('OTU3',('OTU1','OTU2'))"]
        # mkstemp returns an open descriptor; only the path is needed.
        fh, self.tmp_heatmap_fpath = mkstemp(prefix='test_heatmap_',
                                             suffix='.pdf')
        close(fh)

    def tearDown(self):
        """Delete the temporary heatmap file created by setUp."""
        # setUp creates this file for every test, so it has to be removed
        # after every test, not only after test_plot_heatmap.
        if exists(self.tmp_heatmap_fpath):
            remove_files(set([self.tmp_heatmap_fpath]))

    def test_extract_metadata_column(self):
        """Extracts correct column from mapping file"""
        obs = extract_metadata_column(self.otu_table.ids(),
                                      self.metadata,
                                      category='CAT2')
        exp = ['A', 'B', 'A', 'B', 'A', 'B']
        self.assertEqual(obs, exp)

    def test_get_order_from_categories(self):
        """Sample indices should be clustered within each category"""
        category_labels = ['A', 'B', 'A', 'B', 'A', 'B']
        obs = get_order_from_categories(self.otu_table, category_labels)
        group_string = "".join([category_labels[i] for i in obs])
        # Either category may come first, as long as each one is contiguous.
        self.assertTrue("AAABBB" == group_string or group_string == "BBBAAA")

    def test_get_order_from_tree(self):
        """Observation order should follow the tip order of the tree"""
        obs = get_order_from_tree(self.otu_table.ids(axis='observation'),
                                  self.tree_text)
        exp = [2, 0, 1]
        assert_almost_equal(obs, exp)

    def test_make_otu_labels(self):
        """Labels combine the last n_levels of the lineage with the OTU id"""
        lineages = []
        for _val, _obs_id, meta in self.otu_table.iter(axis='observation'):
            lineages.append(list(meta['taxonomy']))
        obs = make_otu_labels(self.otu_table.ids(axis='observation'),
                              lineages,
                              n_levels=1)
        exp = ['Bacteria (OTU1)', 'Archaea (OTU2)', 'Streptococcus (OTU3)']
        self.assertEqual(obs, exp)

        full_lineages = []
        for _val, _obs_id, meta in self.otu_table_f.iter(axis='observation'):
            full_lineages.append(list(meta['taxonomy']))
        obs = make_otu_labels(self.otu_table_f.ids(axis='observation'),
                              full_lineages,
                              n_levels=3)
        exp = [
            '1B;1C;Bacteria (OTU1)', '2B;2C;Archaea (OTU2)',
            '3B;3C;Streptococcus (OTU3)'
        ]
        self.assertEqual(obs, exp)

    def test_names_to_indices(self):
        """Names are mapped to their positions in the original id list"""
        new_order = [
            'Sample4', 'Sample2', 'Sample3', 'Sample6', 'Sample5', 'Sample1'
        ]
        obs = names_to_indices(self.otu_table.ids(), new_order)
        exp = [3, 1, 2, 5, 4, 0]
        assert_almost_equal(obs, exp)

    def test_get_log_transform(self):
        """Non-zero counts are log10-transformed, zeros are left alone"""
        obs = get_log_transform(self.otu_table)

        data = list(self.otu_table.iter_data(axis='observation'))
        xform = asarray(data, dtype=float64)

        # Build the expected values row by row and compare as we go.
        for (i, val) in enumerate(obs.iter_data(axis='observation')):
            non_zeros = argwhere(xform[i] != 0)
            xform[i, non_zeros] = log10(xform[i, non_zeros])
            assert_almost_equal(val, xform[i])

    def test_get_clusters(self):
        """Clustering returns the expected row and column orders"""
        data = asarray(list(self.otu_table.iter_data(axis='observation')))
        obs = get_clusters(data, axis='row')
        # Two row orders are accepted.
        self.assertTrue([0, 1, 2] == obs or obs == [1, 2, 0])
        obs = get_clusters(data, axis='column')
        exp = [2, 3, 1, 4, 0, 5]
        self.assertEqual(obs, exp)

    def test_plot_heatmap(self):
        """plot_heatmap should write the requested output file"""
        plot_heatmap(self.otu_table,
                     self.otu_table.ids(axis='observation'),
                     self.otu_table.ids(),
                     filename=self.tmp_heatmap_fpath)
        # The file itself is removed by tearDown.
        self.assertTrue(exists(self.tmp_heatmap_fpath))
예제 #7
0
def gibbs(table_fp, mapping_fp, output_dir, loo, jobs, alpha1, alpha2, beta,
          source_rarefaction_depth, sink_rarefaction_depth, restarts,
          draws_per_restart, burnin, delay, cluster_start_delay,
          source_sink_column, source_column_value, sink_column_value,
          source_category_column):
    '''Gibbs sampler for Bayesian estimation of microbial sample sources.

    For details, see the project README file.

    Parameters
    ----------
    table_fp : path of the biom table, read with `load_table`.
    mapping_fp : path of the mapping file, parsed with `parse_mapping_file`.
    output_dir : directory that is created to hold the results.
    loo : if true, the leave-one-out runner is applied to each source
        sample; otherwise the prediction runner is applied to each sink.
    jobs : when greater than 1, an ipcluster with this many engines is
        started and the samples are processed in parallel.
    alpha1, alpha2, beta, restarts, draws_per_restart, burnin, delay :
        forwarded unchanged to the selected runner.
    source_rarefaction_depth, sink_rarefaction_depth : forwarded to
        `subsample_sources_sinks`.
    cluster_start_delay : seconds to sleep after launching the ipcluster.
    source_sink_column, source_column_value, sink_column_value : forwarded
        to `sinks_and_sources` to select the source and sink samples.
    source_category_column : forwarded to `collapse_sources` (and to the
        leave-one-out runner) to group the source samples.

    Raises
    ------
    OSError
        If `output_dir` cannot be created (for example, it already exists).
    ValueError
        If no source samples are found.
    '''
    # Create results directory. Click has already checked if it exists, and
    # failed if so.
    os.mkdir(output_dir)

    # Load the mapping file and biom table and remove samples which are not
    # shared. The legacy 'U' open mode is not used: it was removed in Python
    # 3.11 and universal newlines are already the text-mode default.
    with open(mapping_fp) as mapping_fh:
        sample_metadata_lines = mapping_fh.readlines()

    sample_metadata, biom_table = \
        _cli_sync_biom_and_sample_metadata(
            parse_mapping_file(sample_metadata_lines),
            load_table(table_fp))

    # If biom table has fractional counts, it can produce problems in indexing
    # later on.
    biom_table.transform(lambda data, id, metadata: np.ceil(data))

    # If biom table has sample metadata, there will be pickling errors when
    # submitting multiple jobs. We remove the metadata by making a copy of the
    # table without metadata.
    biom_table = Table(biom_table._data.toarray(),
                       biom_table.ids(axis='observation'),
                       biom_table.ids(axis='sample'))

    # Parse the mapping file and options to get the samples requested for
    # sources and sinks.
    source_samples, sink_samples = sinks_and_sources(
        sample_metadata,
        column_header=source_sink_column,
        source_value=source_column_value,
        sink_value=sink_column_value)

    # If we have no source samples neither normal operation or loo will work.
    # Will also likely get strange errors.
    if len(source_samples) == 0:
        raise ValueError('Mapping file or biom table passed contain no '
                         '`source` samples.')

    # Prepare the 'sources' matrix by collapsing the `source_samples` by their
    # metadata values.
    sources_envs, sources_data = collapse_sources(source_samples,
                                                  sample_metadata,
                                                  source_category_column,
                                                  biom_table,
                                                  sort=True)

    # Rarefy data if requested.
    sources_data, biom_table = \
        subsample_sources_sinks(sources_data, sink_samples, biom_table,
                                source_rarefaction_depth,
                                sink_rarefaction_depth)

    # Build function that require only a single parameter -- sample -- to
    # enable parallel processing if requested.
    if loo:
        f = partial(_cli_loo_runner,
                    source_category=source_category_column,
                    alpha1=alpha1,
                    alpha2=alpha2,
                    beta=beta,
                    restarts=restarts,
                    draws_per_restart=draws_per_restart,
                    burnin=burnin,
                    delay=delay,
                    sample_metadata=sample_metadata,
                    sources_data=sources_data,
                    sources_envs=sources_envs,
                    biom_table=biom_table,
                    output_dir=output_dir)
        sample_iter = source_samples
    else:
        f = partial(_cli_sink_source_prediction_runner,
                    alpha1=alpha1,
                    alpha2=alpha2,
                    beta=beta,
                    restarts=restarts,
                    draws_per_restart=draws_per_restart,
                    burnin=burnin,
                    delay=delay,
                    sources_data=sources_data,
                    biom_table=biom_table,
                    output_dir=output_dir)
        sample_iter = sink_samples

    if jobs > 1:
        # Launch the ipcluster and wait for it to come up.
        # NOTE(review): the command is built as a shell string; `jobs` is
        # assumed to be an integer validated by the CLI layer -- confirm.
        subprocess.Popen('ipcluster start -n %s --quiet' % jobs, shell=True)
        time.sleep(cluster_start_delay)
        c = Client()
        try:
            c[:].map(f, sample_iter, block=True)
        finally:
            # Shut the cluster down even if a job failed, so that no engines
            # are left running. Answer taken from SO:
            # http://stackoverflow.com/questions/30930157/stopping-ipcluster-engines-ipython-parallel
            c.shutdown(hub=True)
    else:
        for sample in sample_iter:
            f(sample)

    # Format results for output.
    samples = []
    samples_data = []
    for sample_fp in glob.glob(os.path.join(output_dir, '*')):
        # The sample name is the file name up to '.txt'; basename() is used
        # so that this does not depend on '/' being the path separator.
        samples.append(
            os.path.basename(sample_fp.strip()).split('.txt')[0])
        samples_data.append(np.loadtxt(sample_fp, delimiter='\t'))
    mp, mps = _cli_collate_results(samples, samples_data, sources_envs)

    with open(os.path.join(output_dir, 'mixing_proportions.txt'),
              'w') as out_fh:
        out_fh.writelines(mp)
    with open(os.path.join(output_dir, 'mixing_proportions_stds.txt'),
              'w') as out_fh:
        out_fh.writelines(mps)
class TopLevelTests(TestCase):

    """Tests of top-level functions"""

    def setUp(self):
        """define some top-level data"""

        # 3 observations (rows) x 6 samples (columns).
        self.otu_table_values = array([[0, 0, 9, 5, 3, 1],
                                       [1, 5, 4, 0, 3, 2],
                                       [2, 3, 1, 1, 2, 5]])
        # Table whose taxonomy holds a single level per OTU.
        self.otu_table = Table(self.otu_table_values,
                               ['OTU1', 'OTU2', 'OTU3'],
                               ['Sample1', 'Sample2', 'Sample3',
                                'Sample4', 'Sample5', 'Sample6'],
                               [{"taxonomy": ['Bacteria']},
                                {"taxonomy": ['Archaea']},
                                {"taxonomy": ['Streptococcus']}],
                               [None, None, None, None, None, None])
        # Same counts, but with a four-level lineage per OTU.
        self.otu_table_f = Table(
            self.otu_table_values,
            ['OTU1', 'OTU2', 'OTU3'],
            ['Sample1', 'Sample2', 'Sample3',
             'Sample4', 'Sample5', 'Sample6'],
            [{"taxonomy": ['1A', '1B', '1C', 'Bacteria']},
             {"taxonomy": ['2A', '2B', '2C', 'Archaea']},
             {"taxonomy": ['3A', '3B', '3C', 'Streptococcus']}],
            [None, None, None, None, None, None])

        self.full_lineages = [['1A', '1B', '1C', 'Bacteria'],
                              ['2A', '2B', '2C', 'Archaea'],
                              ['3A', '3B', '3C', 'Streptococcus']]
        # (rows, header, comments) triple describing the samples.
        self.metadata = [[['Sample1', 'NA', 'A'],
                          ['Sample2', 'NA', 'B'],
                          ['Sample3', 'NA', 'A'],
                          ['Sample4', 'NA', 'B'],
                          ['Sample5', 'NA', 'A'],
                          ['Sample6', 'NA', 'B']],
                         ['SampleID', 'CAT1', 'CAT2'], []]
        self.tree_text = ["('OTU3',('OTU1','OTU2'))"]
        # mkstemp returns an open descriptor; only the path is needed.
        fh, self.tmp_heatmap_fpath = mkstemp(prefix='test_heatmap_',
                                             suffix='.pdf')
        close(fh)

    def tearDown(self):
        """Delete the temporary heatmap file created by setUp."""
        # setUp creates this file for every test, so it has to be removed
        # after every test, not only after test_plot_heatmap.
        if exists(self.tmp_heatmap_fpath):
            remove_files(set([self.tmp_heatmap_fpath]))

    def test_extract_metadata_column(self):
        """Extracts correct column from mapping file"""
        obs = extract_metadata_column(self.otu_table.ids(),
                                      self.metadata, category='CAT2')
        exp = ['A', 'B', 'A', 'B', 'A', 'B']
        self.assertEqual(obs, exp)

    def test_get_order_from_categories(self):
        """Sample indices should be clustered within each category"""
        category_labels = ['A', 'B', 'A', 'B', 'A', 'B']
        obs = get_order_from_categories(self.otu_table, category_labels)
        group_string = "".join([category_labels[i] for i in obs])
        # Either category may come first, as long as each one is contiguous.
        self.assertTrue("AAABBB" == group_string or group_string == "BBBAAA")

    def test_get_order_from_tree(self):
        """Observation order should follow the tip order of the tree"""
        obs = get_order_from_tree(
            self.otu_table.ids(axis='observation'),
            self.tree_text)
        exp = [2, 0, 1]
        assert_almost_equal(obs, exp)

    def test_make_otu_labels(self):
        """Labels combine the last n_levels of the lineage with the OTU id"""
        lineages = []
        for _val, _obs_id, meta in self.otu_table.iter(axis='observation'):
            lineages.append(list(meta['taxonomy']))
        obs = make_otu_labels(self.otu_table.ids(axis='observation'),
                              lineages, n_levels=1)
        exp = ['Bacteria (OTU1)', 'Archaea (OTU2)', 'Streptococcus (OTU3)']
        self.assertEqual(obs, exp)

        full_lineages = []
        for _val, _obs_id, meta in self.otu_table_f.iter(axis='observation'):
            full_lineages.append(list(meta['taxonomy']))
        obs = make_otu_labels(self.otu_table_f.ids(axis='observation'),
                              full_lineages, n_levels=3)
        exp = ['1B;1C;Bacteria (OTU1)',
               '2B;2C;Archaea (OTU2)',
               '3B;3C;Streptococcus (OTU3)']
        self.assertEqual(obs, exp)

    def test_names_to_indices(self):
        """Names are mapped to their positions in the original id list"""
        new_order = ['Sample4', 'Sample2', 'Sample3',
                     'Sample6', 'Sample5', 'Sample1']
        obs = names_to_indices(self.otu_table.ids(), new_order)
        exp = [3, 1, 2, 5, 4, 0]
        assert_almost_equal(obs, exp)

    def test_get_log_transform(self):
        """Non-zero counts are log10-transformed, zeros are left alone"""
        obs = get_log_transform(self.otu_table)

        data = list(self.otu_table.iter_data(axis='observation'))
        xform = asarray(data, dtype=float64)

        # Build the expected values row by row and compare as we go.
        for (i, val) in enumerate(obs.iter_data(axis='observation')):
            non_zeros = argwhere(xform[i] != 0)
            xform[i, non_zeros] = log10(xform[i, non_zeros])
            assert_almost_equal(val, xform[i])

    def test_get_clusters(self):
        """Clustering returns the expected row and column orders"""
        data = asarray(list(self.otu_table.iter_data(axis='observation')))
        obs = get_clusters(data, axis='row')
        # Two row orders are accepted.
        self.assertTrue([0, 1, 2] == obs or obs == [1, 2, 0])
        obs = get_clusters(data, axis='column')
        exp = [2, 3, 1, 4, 0, 5]
        self.assertEqual(obs, exp)

    def test_plot_heatmap(self):
        """plot_heatmap should write the requested output file"""
        plot_heatmap(
            self.otu_table, self.otu_table.ids(axis='observation'),
            self.otu_table.ids(), filename=self.tmp_heatmap_fpath)
        # The file itself is removed by tearDown.
        self.assertTrue(exists(self.tmp_heatmap_fpath))
예제 #9
0
def gibbs(table_fp, mapping_fp, output_dir, loo, jobs, alpha1, alpha2, beta,
          source_rarefaction_depth, sink_rarefaction_depth,
          restarts, draws_per_restart, burnin, delay, cluster_start_delay,
          source_sink_column, source_column_value, sink_column_value,
          source_category_column):
    '''Gibbs sampler for Bayesian estimation of microbial sample sources.

    For details, see the project README file.

    Parameters
    ----------
    table_fp : path of the biom table, read with `load_table`.
    mapping_fp : path of the mapping file, parsed with `parse_mapping_file`.
    output_dir : directory that is created to hold the results.
    loo : if true, the leave-one-out runner is applied to each source
        sample; otherwise the prediction runner is applied to each sink.
    jobs : when greater than 1, an ipcluster with this many engines is
        started and the samples are processed in parallel.
    alpha1, alpha2, beta, restarts, draws_per_restart, burnin, delay :
        forwarded unchanged to the selected runner.
    source_rarefaction_depth, sink_rarefaction_depth : forwarded to
        `subsample_sources_sinks`.
    cluster_start_delay : seconds to sleep after launching the ipcluster.
    source_sink_column, source_column_value, sink_column_value : forwarded
        to `sinks_and_sources` to select the source and sink samples.
    source_category_column : forwarded to `collapse_sources` (and to the
        leave-one-out runner) to group the source samples.

    Raises
    ------
    OSError
        If `output_dir` cannot be created (for example, it already exists).
    ValueError
        If no source samples are found.
    '''
    # Create results directory. Click has already checked if it exists, and
    # failed if so.
    os.mkdir(output_dir)

    # Load the mapping file and biom table and remove samples which are not
    # shared. The legacy 'U' open mode is not used: it was removed in Python
    # 3.11 and universal newlines are already the text-mode default.
    with open(mapping_fp) as mapping_fh:
        sample_metadata_lines = mapping_fh.readlines()

    sample_metadata, biom_table = \
        _cli_sync_biom_and_sample_metadata(
            parse_mapping_file(sample_metadata_lines),
            load_table(table_fp))

    # If biom table has fractional counts, it can produce problems in indexing
    # later on.
    biom_table.transform(lambda data, id, metadata: np.ceil(data))

    # If biom table has sample metadata, there will be pickling errors when
    # submitting multiple jobs. We remove the metadata by making a copy of the
    # table without metadata.
    biom_table = Table(biom_table._data.toarray(),
                       biom_table.ids(axis='observation'),
                       biom_table.ids(axis='sample'))

    # Parse the mapping file and options to get the samples requested for
    # sources and sinks.
    source_samples, sink_samples = sinks_and_sources(
        sample_metadata, column_header=source_sink_column,
        source_value=source_column_value, sink_value=sink_column_value)

    # If we have no source samples neither normal operation or loo will work.
    # Will also likely get strange errors.
    if len(source_samples) == 0:
        raise ValueError('Mapping file or biom table passed contain no '
                         '`source` samples.')

    # Prepare the 'sources' matrix by collapsing the `source_samples` by their
    # metadata values.
    sources_envs, sources_data = collapse_sources(source_samples,
                                                  sample_metadata,
                                                  source_category_column,
                                                  biom_table, sort=True)

    # Rarefy data if requested.
    sources_data, biom_table = \
        subsample_sources_sinks(sources_data, sink_samples, biom_table,
                                source_rarefaction_depth,
                                sink_rarefaction_depth)

    # Build function that require only a single parameter -- sample -- to
    # enable parallel processing if requested.
    if loo:
        f = partial(_cli_loo_runner, source_category=source_category_column,
                    alpha1=alpha1, alpha2=alpha2, beta=beta,
                    restarts=restarts, draws_per_restart=draws_per_restart,
                    burnin=burnin, delay=delay,
                    sample_metadata=sample_metadata,
                    sources_data=sources_data, sources_envs=sources_envs,
                    biom_table=biom_table, output_dir=output_dir)
        sample_iter = source_samples
    else:
        f = partial(_cli_sink_source_prediction_runner, alpha1=alpha1,
                    alpha2=alpha2, beta=beta, restarts=restarts,
                    draws_per_restart=draws_per_restart, burnin=burnin,
                    delay=delay, sources_data=sources_data,
                    biom_table=biom_table, output_dir=output_dir)
        sample_iter = sink_samples

    if jobs > 1:
        # Launch the ipcluster and wait for it to come up.
        # NOTE(review): the command is built as a shell string; `jobs` is
        # assumed to be an integer validated by the CLI layer -- confirm.
        subprocess.Popen('ipcluster start -n %s --quiet' % jobs, shell=True)
        time.sleep(cluster_start_delay)
        c = Client()
        try:
            c[:].map(f, sample_iter, block=True)
        finally:
            # Shut the cluster down even if a job failed, so that no engines
            # are left running. Answer taken from SO:
            # http://stackoverflow.com/questions/30930157/stopping-ipcluster-engines-ipython-parallel
            c.shutdown(hub=True)
    else:
        for sample in sample_iter:
            f(sample)

    # Format results for output.
    samples = []
    samples_data = []
    for sample_fp in glob.glob(os.path.join(output_dir, '*')):
        # The sample name is the file name up to '.txt'; basename() is used
        # so that this does not depend on '/' being the path separator.
        samples.append(
            os.path.basename(sample_fp.strip()).split('.txt')[0])
        samples_data.append(np.loadtxt(sample_fp, delimiter='\t'))
    mp, mps = _cli_collate_results(samples, samples_data, sources_envs)

    with open(os.path.join(output_dir, 'mixing_proportions.txt'),
              'w') as out_fh:
        out_fh.writelines(mp)
    with open(os.path.join(output_dir, 'mixing_proportions_stds.txt'),
              'w') as out_fh:
        out_fh.writelines(mps)