def test_transformed(self):
     """transformed() should return the expected counts"""
     raw_counts = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
     collection = RegionCollection(counts=raw_counts)

     # transformed() called with its default arguments
     default_counts, _, _ = collection.transformed()
     self.assertEqual(default_counts, [4, 5])

     # column sums of the collection after conversion with asfreqs()
     as_freqs = collection.asfreqs()
     summed, _, _ = as_freqs.transformed(counts_func=column_sum)
     expected = numpy.array([20. / 45., 25. / 45], dtype=summed.dtype)
     self.assertFloatEqual(summed, expected)
Exemple #2
0
class RegionStudy(object):
    """ Specifies the RegionCollection associated with an expression
            data set. Used to collate data for plot_counts.py.

    Members: data_collection (a RegionCollection), collection_path,
            collection_label, counts_func, window_upstream,
            window_downstream, feature_type
    Methods: filterByGenes, filterByCutoff, normaliseByRPM,
            asPlotLines
    """
    def __init__(self, collection_fn, counts_func, *args, **kwargs):
        """ Load the RegionCollection stored in collection_fn and record
            the settings needed to plot it.

            collection_fn: path of the collection file; the file name is
                    also used to build collection_label.
            counts_func: function stored as self.counts_func. When it is
                    column_sum the collection is converted with asfreqs().
            *args, **kwargs: forwarded to the parent constructor.

            rr.dieOnCritical is called if the file raises IOError on
            loading, or if 'window_upstream' / 'window_downstream' are
            missing from the collection's info['args']. A missing
            'feature_type' falls back to 'Unknown'.
        """
        super(RegionStudy, self).__init__(*args, **kwargs)
        rr = RunRecord('Study')
        # Keep the source file name for labelling purposes
        self.collection_path = collection_fn
        fn = collection_fn.split('/')[-1]
        # Drop a literal '.gz' suffix. str.rstrip('.gz') is not used because
        # it strips any trailing run of the characters '.', 'g' and 'z'.
        if fn.endswith('.gz'):
            fn = fn[:-len('.gz')]
        # e.g. 'my_study.pkl' -> 'my study' (last extension is discarded)
        self.collection_label = ' '.join(fn.replace('_', ' ').split('.')[:-1])
        try:
            self.data_collection = RegionCollection(filename=collection_fn)
        except IOError:
            rr.dieOnCritical('Collection will not load', collection_fn)

        # Frequency normalized counts need to be converted
        if counts_func is column_sum:
            self.data_collection = self.data_collection.asfreqs()
        self.counts_func = counts_func

        # Get feature window start and end
        try:
            self.window_upstream = (
                self.data_collection.info['args']['window_upstream'])
        except KeyError:
            rr.dieOnCritical('Collection value not defined', 'window_upstream')

        try:
            self.window_downstream = (
                self.data_collection.info['args']['window_downstream'])
        except KeyError:
            rr.dieOnCritical('Collection value not defined',
                             'window_downstream')

        # The feature type is optional, unlike the window bounds above
        try:
            self.feature_type = (
                self.data_collection.info['args']['feature_type'])
        except KeyError:
            self.feature_type = 'Unknown'

    def filterByGenes(self,
                      db_path,
                      chrom=None,
                      include_samples=None,
                      exclude_samples=None):
        """ keep only results that match selected genes

            db_path: passed to make_session() to open the database.
            chrom: chromosome to restrict to, or None.
            include_samples, exclude_samples: iterables of samples, passed
                    to get_gene_ids() as include_targets / exclude_targets.

            Returns immediately when no restriction is given. Otherwise
            self.data_collection is replaced by the result of
            filteredByLabel() and rr.dieOnCritical is called if that result
            is None or has no ranks.
        """

        rr = RunRecord('filterByGenes')
        if not include_samples and not exclude_samples and not chrom:
            return

        rr.addInfo('Starting no. of genes', self.data_collection.N)

        session = make_session(db_path)
        for sample in include_samples or []:
            rr.addInfo('Restricting plot by include sample', sample)

        for sample in exclude_samples or []:
            rr.addInfo('Restricting plot by exclude sample', sample)

        if chrom is not None:
            rr.addInfo('Restricting plot to chromosome', chrom)

        filter_gene_ids = get_gene_ids(session,
                                       chrom=chrom,
                                       include_targets=include_samples,
                                       exclude_targets=exclude_samples)

        self.data_collection = \
                self.data_collection.filteredByLabel(filter_gene_ids)

        # Only read .N once we know a collection came back; the None case
        # is reported by the critical check below instead of failing here.
        if self.data_collection is not None:
            rr.addInfo('Remaining genes', self.data_collection.N)

        if self.data_collection is None or \
                len(self.data_collection.ranks) == 0:
            rr.dieOnCritical('Genes remaining after filtering', '0')

    def filterByCutoff(self, cutoff=None):
        """ keep only results that pass Chebyshev cutoff

            cutoff: value convertible to float, expected in [0, 1).
                    None or 0.0 switches the filter off. A value that
                    cannot be converted, or is out of range, is logged as
                    an error and replaced by the default of 0.05.

            When filtering is on, self.data_collection is replaced by the
            result of filteredChebyshevUpper(p=cutoff). rr.dieOnCritical is
            called if no data remains afterwards.
        """
        rr = RunRecord('filterByCutoff')
        default_cutoff = 0.05

        rr.addInfo('Starting no. of genes', self.data_collection.N)

        # exclude outlier genes using one-sided Chebyshev
        if cutoff is not None and cutoff != 0.0:
            try:
                cutoff = float(cutoff)
            except (TypeError, ValueError):
                rr.addError('Cutoff not given as float', cutoff)
                rr.addInfo('Cutoff set to default', default_cutoff)
                cutoff = default_cutoff
            else:
                if cutoff < 0.0 or cutoff >= 1.0:
                    rr.addError('Cutoff out of range', cutoff)
                    rr.addInfo('Cutoff set to default', default_cutoff)
                    cutoff = default_cutoff

            # Do Chebyshev filtering
            self.data_collection = \
                    self.data_collection.filteredChebyshevUpper(p=cutoff)
            rr.addInfo('Used Chebyshev filter cutoff', cutoff)
            # Only read .N once we know a collection came back
            if self.data_collection is not None:
                rr.addInfo('No. genes after normalisation filter',
                           self.data_collection.N)
        else:
            rr.addInfo('Outlier cutoff filtering', 'Off')

        # Test the length before calling max(): max() of an empty array
        # raises instead of returning a value (assumes ranks is a numpy
        # array -- TODO confirm), which would bypass the critical report.
        if self.data_collection is None or \
                len(self.data_collection.ranks) == 0 or \
                self.data_collection.ranks.max() == 0:
            rr.dieOnCritical('No data after filtering', 'Failure')

    def normaliseByRPM(self):
        """ Rescale the collection's counts to counts per million mapped
            tags, using the 'mapped tags' entry of info['args'].
            Mapped tags is the total experimental mapped tags.

            If that entry is missing an error is logged and the counts are
            left as they are.
        """
        rr = RunRecord('normaliseByRPM')
        try:
            mapped_tags = self.data_collection.info['args']['mapped tags']
            rr.addInfo("'mapped tags' value", mapped_tags)
        except KeyError:
            rr.addError('Info field not found', 'mapped tags')
            return

        per_million = 1000000.0 / mapped_tags
        rr.addInfo('normalising by RPMs', per_million)
        # Scale each entry of counts and store the result as a new array
        self.data_collection.counts = numpy.array(
            [row * per_million for row in self.data_collection.counts])

    def _groupAllGeneCounts(self):
        """ Collapse the whole collection into one PlotLine using
            transformed() with this study's counts_func.
            Called by asPlotLines() or _groupNGeneCounts().
            Returns a one-element list.
        """
        rr = RunRecord('_groupAllGeneCounts')
        counts, _, stderr = self.data_collection.transformed(
            counts_func=self.counts_func)
        if len(counts) == 0:
            rr.dieOnCritical('No counts data in', 'Study._groupAllGeneCounts')

        # A single line is always named after its collection, and its rank
        # is fixed at 0 because rank is irrelevant for 'all' genes.
        name = self.collection_label
        return [PlotLine(counts, 0, name, study=name, stderr=stderr)]

    def _groupNoGeneCounts(self):
        """ Build one PlotLine per set of counts, with no grouping.
            Called by asPlotLines()

            When counts_func equals stdev each set of counts is
            standardised (mean subtracted, divided by its standard
            deviation); sets whose standard deviation is not positive are
            left out.
        """
        rr = RunRecord('_groupNoGeneCounts')
        collection = self.data_collection
        entries = zip(collection.counts, collection.ranks, collection.labels)
        plot_lines = []
        for values, rank, name in entries:
            if self.counts_func == stdev:
                spread = values.std()
                if not spread > 0:
                    # cannot standardise a flat line, so drop it
                    continue
                values = (values - values.mean()) / spread
            plot_lines.append(
                PlotLine(values, rank, name, study=self.collection_label))

        # Having nothing to plot is reported as a critical failure
        if len(plot_lines) == 0:
            rr.dieOnCritical('No data in collection', 'Failure')

        # A lone line is relabelled with the collection name, stored as a
        # one-item list
        if len(plot_lines) == 1:
            plot_lines[0].label = [self.collection_label]

        return plot_lines

    def _groupNGeneCounts(self, group_size, p=0.0):
        """ Build one PlotLine per group yielded by the collection's
            iterTransformedGroups(group_size, counts_func, p). Falls back
            to _groupAllGeneCounts() when no group is yielded.
            Called by asPlotLines()
        """
        rr = RunRecord('_groupNGeneCounts')
        grouped = self.data_collection.iterTransformedGroups(
            group_size=group_size, counts_func=self.counts_func, p=p)
        plot_lines = [
            PlotLine(counts,
                     rank=rank,
                     label=label,
                     study=self.collection_label,
                     stderr=stderr)
            for counts, rank, label, stderr in grouped]

        if plot_lines:
            return plot_lines

        # No group was produced, so plot all features as a single line
        rr.addWarning('Defaulting to ALL features. Not enough '
                      'features for group of size', group_size)
        return self._groupAllGeneCounts()

    def asPlotLines(self, group_size, group_location, p=0.0):
        """
            Convert this study into a list of PlotLine objects.

            group_size: the str 'all' (any case) for a single line, the
                    int 1 for one line per set of counts, any other int
                    for grouped lines. Anything else is reported through
                    rr.dieOnCritical.
            group_location: 'all' keeps every line; 'top', 'middle' or
                    'bottom' keeps one line picked after sorting by rank.
            'p' is the Chebyshev cut-off handed to _groupNGeneCounts
        """
        rr = RunRecord('asPlotLines')
        if p > 0.0:
            rr.addInfo('Applying per-line Chebyshev filtering', p)

        size_type = type(group_size)
        if size_type is str and group_size.lower() == 'all':
            plot_lines = self._groupAllGeneCounts()
        elif size_type is not int:
            rr.dieOnCritical('group_size, wrong type or value',
                             [type(group_size), group_size])
        elif group_size == 1:
            plot_lines = self._groupNoGeneCounts()
        else:
            plot_lines = self._groupNGeneCounts(group_size, p=p)

        location = group_location.lower()
        if location != 'all':
            rr.addInfo('grouping genes from location', group_location)
            plot_lines.sort(key=lambda line: line.rank)
            # keep a single line from the rank-ordered list
            if location == 'top':
                plot_lines = [plot_lines[0]]
            elif location == 'middle':
                plot_lines = [plot_lines[len(plot_lines) // 2]]
            elif location == 'bottom':
                plot_lines = [plot_lines[-1]]

        rr.addInfo('Plottable lines from study', len(plot_lines))
        return plot_lines
Exemple #3
0
def main():
    """ Plot the counts of two RegionCollections on one figure.

        Options come from parse_command_line_parameters(**script_info).
        Both collections are Chebyshev-filtered with opts.cutoff, reduced
        to the labels they share, truncated to the first opts.sample_top
        entries (all entries when that option is None) and drawn in blue
        and red respectively. The figure is saved to opts.plot_filename
        unless opts.test_run is set or no file name was given, in which
        case the file name is printed instead.

        Raises RuntimeError if opts.ylim contains no comma or opts.cutoff
        is outside [0, 1].
    """
    option_parser, opts, args =\
       parse_command_line_parameters(**script_info)

    # Check the options that are cheap to validate first, so that bad
    # input fails before any collection is loaded.
    if ',' not in opts.ylim:
        raise RuntimeError('ylim must be comma separated')

    # A list comprehension rather than map() so that ylim is a real list
    # on Python 3 as well as Python 2.
    ylim = [float(bound) for bound in opts.ylim.strip().split(',')]

    if opts.cutoff < 0 or opts.cutoff > 1:
        raise RuntimeError('The cutoff must be between 0 and 1')

    print('Loading counts data')
    data_collection1 = RegionCollection(filename=opts.collection1)
    # The plotted window is taken from the first collection only
    window_size = data_collection1.info['args']['window_size']
    data_collection2 = RegionCollection(filename=opts.collection2)

    # filter both
    data_collection1 = data_collection1.filteredChebyshevUpper(opts.cutoff)
    data_collection2 = data_collection2.filteredChebyshevUpper(opts.cutoff)

    # make sure each collection consists of the same genes
    shared_labels = set(data_collection1.labels) & \
                    set(data_collection2.labels)

    data_collection1 = data_collection1.filteredByLabel(shared_labels)
    data_collection2 = data_collection2.filteredByLabel(shared_labels)
    assert set(data_collection1.labels) == set(data_collection2.labels)

    if opts.sample_top is None:
        sample_top = data_collection1.N
    else:
        sample_top = opts.sample_top

    # Same leading indices are taken from both collections
    indices = list(range(sample_top))
    data_collection1 = data_collection1.take(indices)
    data_collection2 = data_collection2.take(indices)

    print('Starting to plot')
    # Grid and vertical line share one colour that contrasts with the
    # chosen background.
    if opts.bgcolor == 'black':
        bgcolor = '0.1'
        line_color = 'w'
    else:
        bgcolor = '1.0'
        line_color = 'k'

    grid = {'color': line_color}
    vline = dict(x=0,
                 linewidth=opts.vline_width,
                 linestyle=opts.vline_style,
                 color=line_color)

    plot = PlottableSingle(height=opts.fig_height / 2.5,
                           width=opts.fig_width / 2.5,
                           bgcolor=bgcolor,
                           grid=grid,
                           ylim=ylim,
                           xlim=(-window_size, window_size),
                           xtick_space=opts.xgrid_lines,
                           ytick_space=opts.ygrid_lines,
                           xtick_interval=opts.xlabel_interval,
                           ytick_interval=opts.ylabel_interval,
                           xlabel_fontsize=opts.xfontsize,
                           ylabel_fontsize=opts.yfontsize,
                           vline=vline,
                           ioff=True)

    x = numpy.arange(-window_size, window_size)

    # Any metric other than 'Mean counts' is plotted from frequencies
    if opts.metric == 'Mean counts':
        stat = averaged
    else:
        data_collection1 = data_collection1.asfreqs()
        data_collection2 = data_collection2.asfreqs()
        stat = summed

    plot_sample(plot, data_collection1, stat_maker(stat, data_collection1), x,
                opts.title, opts.xlabel, opts.ylabel, 'b', opts.legend1,
                opts.plot_stderr)
    plot_sample(plot, data_collection2, stat_maker(stat, data_collection2), x,
                opts.title, opts.xlabel, opts.ylabel, 'r', opts.legend2,
                opts.plot_stderr)

    plot.legend()
    plot.show()
    if opts.plot_filename and not opts.test_run:
        plot.savefig(opts.plot_filename)
    else:
        print(opts.plot_filename)
 def test_asfreqs(self):
     """frequencies produced by asfreqs() should have a Total of 1.0"""
     raw_counts = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
     as_freqs = RegionCollection(counts=raw_counts).asfreqs()
     self.assertEqual(as_freqs.Total, 1.0)