Example #1
    def subsample(self, sampling_threshold, **kwargs):
        if self.seqs is None:
            self.load_sequences()

        def sampling_category(x):
            # group sequences by region and calendar month
            return (x.attributes['region'],
                    x.attributes['date'].year,
                    x.attributes['date'].month)

        # load HI titer count to prioritize sequences
        HI_titer_count = {}
        with myopen(self.HI_strains_fname, 'r') as ifile:
            for line in ifile:
                strain, count = line.strip().split()
                HI_titer_count[strain] = int(count)

        def sampling_priority(seq):
            # prefer strains with many HI titer measurements and longer sequences;
            # penalize ambiguous nucleotide codes
            sname = seq.attributes['strain']
            if sname in HI_titer_count:
                pr = HI_titer_count[sname]
            else:
                pr = 0
            return pr + len(seq.seq)*0.0001 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])

        self.seqs.subsample(category = sampling_category,
                            threshold=sampling_threshold,
                            priority=sampling_priority, **kwargs)
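
A minimal usage sketch of the priority scoring used above, with made-up inputs (the strain name, titer count, and sequence are illustrative only, not taken from the original data):

import numpy as np

HI_titer_count = {'A/Perth/16/2009': 12}        # strain -> number of HI measurements

class FakeSeq(object):                          # stand-in for the sequence objects used above
    attributes = {'strain': 'A/Perth/16/2009'}
    seq = 'ACGT'*400 + 'N'*10                   # 1600 unambiguous bases plus 10 Ns

seq = FakeSeq()
pr = HI_titer_count.get(seq.attributes['strain'], 0)
score = pr + len(seq.seq)*0.0001 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])
print(score)                                    # 12 + 0.161 - 0.1 ~ 12.061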
Example #2
def load_titers(titer_prefix):
    HI_titer_count = {}
    with myopen(titer_prefix + "_strains.tsv", 'r') as ifile:
        for line in ifile:
            strain, count = line.strip().split()
            HI_titer_count[strain] = int(count)
    return HI_titer_count
Example #3
    def read_titers(self, fname):
        self.titer_fname = fname
        if "excluded_tables" in self.kwargs:
            self.excluded_tables = self.kwargs["excluded_tables"]
        else:
            self.excluded_tables = []

        strains = set()
        measurements = defaultdict(list)
        sources = set()
        with myopen(fname, 'r') as infile:
            for line in infile:
                entries = line.strip().split()
                try:
                    # guard the parsing itself: short or non-numeric lines are reported and skipped
                    test, ref_virus, serum, src_id, val = (entries[0], entries[1], entries[2],
                                                           entries[3], float(entries[4]))
                except (IndexError, ValueError):
                    print(line.strip())
                    continue
                ref = (ref_virus, serum)
                if src_id not in self.excluded_tables:
                    measurements[(test, ref)].append(val)
                    strains.update([test, ref_virus])
                    sources.add(src_id)
        self.titers = measurements
        self.strains = list(strains)
        self.sources = list(sources)
        print("Read titers from",self.titer_fname, 'found:')
        print(' ---', len(self.strains), "strains")
        print(' ---', len(self.sources), "data sources")
        print(' ---', sum([len(x) for x in measurements.values()]), " total measurements")
Example #4
    def subsample(self, sampling_threshold):

        def sampling_category(x):
            return (x.attributes['region'],
                    x.attributes['date'].year,
                    x.attributes['date'].month)

        # load HI titer count to prioritize sequences
        HI_titer_count = {}
        with myopen(self.HI_strains_fname, 'r') as ifile:
            for line in ifile:
                strain, count = line.strip().split()
                HI_titer_count[strain] = int(count)

        def sampling_priority(seq):
            sname = seq.attributes['strain'].upper()
            if sname in HI_titer_count:
                pr = HI_titer_count[sname]
            else:
                pr = 0
            return pr + len(seq.seq)*0.0001 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])

        # constrain to time interval
        self.seqs.raw_seqs = {k:s for k,s in self.seqs.raw_seqs.iteritems() if
                                        s.attributes['date']>=self.time_interval[0] and
                                        s.attributes['date']<self.time_interval[1]}
        self.seqs.subsample(category = sampling_category,
                            threshold=sampling_threshold,
                            priority=sampling_priority )
Example #5
    def read_titers(self, fname):
        self.titer_fname = fname
        if "excluded_tables" in self.kwargs:
            self.excluded_tables = self.kwargs["excluded_tables"]
        else:
            self.excluded_tables = []

        strains = set()
        measurements = defaultdict(list)
        sources = set()
        with myopen(fname, 'r') as infile:
            for line in infile:
                entries = line.strip().split()
                try:
                    # guard the parsing itself: short or non-numeric lines are reported and skipped
                    test, ref_virus, serum, src_id, val = (entries[0], entries[1],
                                                           entries[2], entries[3],
                                                           float(entries[4]))
                except (IndexError, ValueError):
                    print(line.strip())
                    continue
                ref = (ref_virus, serum)
                if src_id not in self.excluded_tables:
                    measurements[(test, ref)].append(val)
                    strains.update([test, ref_virus])
                    sources.add(src_id)
        self.titers = measurements
        self.strains = list(strains)
        self.sources = list(sources)
        print("Read titers from", self.titer_fname, 'found:')
        print(' ---', len(self.strains), "strains")
        print(' ---', len(self.sources), "data sources")
        print(' ---', sum([len(x) for x in measurements.values()]),
              " total measurements")
Example #6
    def dump(self):
        from cPickle import dump
        from Bio import Phylo
        for attr_name, fname in self.file_dumps.iteritems():
            if hasattr(self, attr_name):
                print("dumping", attr_name)
                if attr_name == 'seqs':
                    self.seqs.raw_seqs = None
                with myopen(fname, 'wb') as ofile:
                    if attr_name == 'tree':
                        Phylo.write(self.tree.tree, ofile, 'newick')
                    else:
                        dump(getattr(self, attr_name), ofile, -1)
Example #7
    def load(self):
        from cPickle import load
        for attr_name, fname in self.file_dumps.iteritems():
            if attr_name == 'tree':
                # the tree is rebuilt from its newick file below, not unpickled
                continue
            if os.path.isfile(fname):
                with myopen(fname, 'r') as ifile:
                    setattr(self, attr_name, load(ifile))
        fname = self.file_dumps['tree']
        if os.path.isfile(fname):
            self.build_tree(fname)
Example #8
    def dump(self):
        '''
        write the current state to file
        '''
        from cPickle import dump
        from Bio import Phylo
        for attr_name, fname in self.file_dumps.iteritems():
            if hasattr(self, attr_name):
                print("dumping", attr_name)
                #if attr_name=='seqs': self.seqs.all_seqs = None
                with myopen(fname, 'wb') as ofile:
                    if attr_name == 'nodes':
                        continue
                    elif attr_name == 'tree':
                        # biopython trees don't pickle well; write as newick + node info
                        self.tree.dump(fname, self.file_dumps['nodes'])
                    else:
                        dump(getattr(self, attr_name), ofile, -1)
Example #9
    def dump(self):
        '''
        write the current state to file
        '''
        from cPickle import dump
        from Bio import Phylo
        for attr_name, fname in self.file_dumps.iteritems():
            if hasattr(self, attr_name):
                print("dumping", attr_name)
                if attr_name == 'seqs':
                    self.seqs.raw_seqs = None
                with myopen(fname, 'wb') as ofile:
                    if attr_name == 'nodes':
                        continue
                    elif attr_name == 'tree':
                        # biopython trees don't pickle well; write as newick + node info
                        self.tree.dump(fname, self.file_dumps['nodes'])
                    else:
                        dump(getattr(self, attr_name), ofile, -1)
Example #10
    def load(self):
        '''
        reconstruct instance from files
        '''
        from cPickle import load
        for attr_name, fname in self.file_dumps.iteritems():
            if attr_name == 'tree':
                continue
            if os.path.isfile(fname):
                with myopen(fname, 'r') as ifile:
                    print('loading', attr_name, 'from file', fname)
                    setattr(self, attr_name, load(ifile))

        tree_name = self.file_dumps['tree']
        if os.path.isfile(tree_name):
            if os.path.isfile(self.file_dumps['nodes']):
                node_file = self.file_dumps['nodes']
            else:
                node_file = None
            # load the tree from the newick file, passing node data along when available
            self.build_tree(tree_name, node_file, root='none')
Example #11
    def load(self):
        '''
        reconstruct instance from files
        '''
        from cPickle import load
        for attr_name, fname in self.file_dumps.iteritems():
            if attr_name=='tree':
                continue
            if os.path.isfile(fname):
                with myopen(fname, 'r') as ifile:
                    print('loading',attr_name,'from file',fname)
                    setattr(self, attr_name, load(ifile))

        tree_name = self.file_dumps['tree']
        if os.path.isfile(tree_name):
            if os.path.isfile(self.file_dumps['nodes']):
                node_file = self.file_dumps['nodes']
            else:
                node_file = None
            # load the tree from the newick file, passing node data along when available
            self.build_tree(tree_name, node_file, root='none')
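
For orientation, dump() and load() above iterate over a file_dumps mapping from attribute name to output path; 'tree' and 'nodes' are the only keys treated specially. A hypothetical mapping (only the keys mirror the code above; the paths are made up for illustration):

file_dumps = {
    'seqs':  'data/flu_seqs.pkl.gz',     # pickled via cPickle.dump
    'tree':  'data/flu_tree.newick',     # written as newick (plus node info)
    'nodes': 'data/flu_nodes.pkl.gz',    # written by self.tree.dump, skipped when pickling
}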
Example #12
    def subsample(self, sampling_threshold, **kwargs):

        def sampling_category(x):
            '''
            Subsample per region, per month.
            '''
            return (x.attributes['region'],
                    x.attributes['date'].year,
                    x.attributes['date'].month)

        # load titer count to prioritize sequences
        titer_count = {}
        forced_strains = []
        with myopen(self.strains_fname, 'r') as ifile:
            for line in ifile:
                strain, count = line.strip().split()
                titer_count[strain] = int(count)
                # 'serotype' is supplied by the enclosing scope; matching strains
                # get a priority boost below
                if strain.startswith('DENV%s' % serotype):
                    forced_strains.append(strain)

        def sampling_priority(seq):
            '''
            Prefers more titer measurements, longer sequences. Penalizes ambiguity codes.
            '''
            sname = seq.attributes['strain']
            if sname in titer_count:
                pr = titer_count[sname]
            else:
                pr = 0

            if seq.id in forced_strains:
                pr += 10.0

            return pr + len(seq.seq)*0.00005 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])

        self.seqs.subsample(category = sampling_category,
                            threshold=sampling_threshold,
                            priority=sampling_priority,
                            **kwargs)
Example #13
    @staticmethod
    def load_from_file(filename, excluded_sources=None):
        """Load titers from a tab-delimited file.

        Args:
            filename (str): tab-delimited file containing titer strains, serum,
                            and values
            excluded_sources (list of str): sources in the titers file to exclude

        Returns:
            dict: titer measurements indexed by test strain, reference strain,
                  and serum with a list of raw floating point values per index
            list: distinct strains present as either test or reference viruses
            list: distinct sources of titers

        >>> measurements, strains, sources = TiterCollection.load_from_file("tests/titer_model/h3n2_titers_subset.tsv")
        >>> type(measurements)
        <type 'dict'>
        >>> len(measurements)
        11
        >>> measurements[("A/Acores/11/2013", ("A/Alabama/5/2010", "F27/10"))]
        [80.0]
        >>> len(strains)
        13
        >>> len(sources)
        5
        >>> measurements, strains, sources = TiterCollection.load_from_file("tests/titer_model/h3n2_titers_subset.tsv", excluded_sources=["NIMR_Sep2013_7-11.csv"])
        >>> len(measurements)
        5
        >>> measurements.get(("A/Acores/11/2013", ("A/Alabama/5/2010", "F27/10")))
        >>>
        """
        if excluded_sources is None:
            excluded_sources = []

        measurements = defaultdict(list)
        strains = set()
        sources = set()

        with myopen(filename, 'r') as infile:
            for line in infile:
                entries = line.strip().split()
                try:
                    val = float(entries[4])
                except (IndexError, ValueError):
                    # skip lines without a parseable titer value
                    continue
                test, ref_virus, serum, src_id = (entries[0], entries[1],
                                                  entries[2], entries[3])

                ref = (ref_virus, serum)
                if src_id not in excluded_sources:
                    try:
                        measurements[(test, (ref_virus, serum))].append(val)
                        strains.update([test, ref_virus])
                        sources.add(src_id)
                    except:
                        print(line.strip())

        logger.info("Read titers from %s, found:" % filename)
        logger.info(" --- %i strains" % len(strains))
        logger.info(" --- %i data sources" % len(sources))
        logger.info(" --- %i total measurements" %
                    sum([len(x) for x in measurements.values()]))

        return dict(measurements), list(strains), list(sources)
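
For reference, the parser above expects whitespace-delimited rows of test virus, reference virus, serum, source table, and titer value; the values below are taken from the doctest for h3n2_titers_subset.tsv:

# one input row as parsed above
line = "A/Acores/11/2013\tA/Alabama/5/2010\tF27/10\tNIMR_Sep2013_7-11.csv\t80.0"
entries = line.strip().split()
test, ref_virus, serum, src_id = entries[0], entries[1], entries[2], entries[3]
val = float(entries[4])
# measurements[("A/Acores/11/2013", ("A/Alabama/5/2010", "F27/10"))] ends up as [80.0]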
Example #14
    def subsample_priority_region(self, sampling_threshold, priority_region='north_america', fraction=0.5, **kwargs):
        self.sequence_count_total = defaultdict(int)
        self.sequence_count_region = defaultdict(int)
        if 'repeated' in kwargs and kwargs['repeated']:
            seqs_to_count = self.seqs.seqs.values()
        else:
            seqs_to_count = self.seqs.all_seqs.values()

        for seq in seqs_to_count:
            self.sequence_count_total[(seq.attributes['date'].year,
                                       seq.attributes['date'].month)] += 1
            self.sequence_count_region[(seq.attributes['region'],
                                        seq.attributes['date'].year,
                                        seq.attributes['date'].month)] += 1

        def sampling_category(x):
            return (x.attributes['region'],
                    x.attributes['date'].year,
                    x.attributes['date'].month)


        def threshold_func(x):
            # x is a (region, year, month) category tuple
            if x[0] == priority_region:
                # the priority region gets a fixed share of the monthly quota
                return int(sampling_threshold*fraction)
            else:
                # split the remaining quota evenly across the other regions
                # ('regions' is supplied by the enclosing scope)
                nregions = len(regions)-1
                total_threshold_world = sampling_threshold*(1-fraction)
                region_threshold = int(np.ceil(1.0*total_threshold_world/nregions))
                region_counts = sorted([self.sequence_count_region[(r, x[1], x[2])]
                                        for r in regions if r != priority_region])
                if region_counts[0] > region_threshold:
                    # every region has enough sequences to fill its even share
                    return region_threshold
                else:
                    # otherwise redistribute the quota that under-sampled regions
                    # cannot use to the larger regions (water-filling over the
                    # sorted per-region counts)
                    left_to_fill = total_threshold_world - nregions*region_counts[0]
                    thres = region_counts[0]
                    for ri, rc in zip(range(nregions-1, 0, -1), region_counts[1:]):
                        if left_to_fill - ri*(rc-thres) > 0:
                            left_to_fill -= ri*(rc-thres)
                            thres = rc
                        else:
                            thres += left_to_fill/ri
                            break
                    return max(1, int(thres))


        # load HI titer count to prioritize sequences
        HI_titer_count = {}
        with myopen(self.HI_strains_fname,'r') as ifile:
            for line in ifile:
                strain, count = line.strip().split()
                HI_titer_count[strain]=int(count)

        def sampling_priority(seq):
            sname = seq.attributes['strain']
            if sname in HI_titer_count:
                pr = HI_titer_count[sname]
            else:
                pr = 0
            return pr + len(seq.seq)*0.0001 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])

        self.seqs.subsample(category = sampling_category,
                            threshold=threshold_func,
                            priority=sampling_priority, **kwargs)
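
A self-contained sketch of the quota logic in threshold_func, using made-up numbers (sampling_threshold=100, fraction=0.5, and five non-priority regions with 2, 8, 20, 30 and 40 sequences in a given month); it reproduces the arithmetic above outside the class:

import numpy as np

sampling_threshold, fraction = 100, 0.5
region_counts = sorted([2, 8, 20, 30, 40])        # sequences available per non-priority region
nregions = len(region_counts)
total_threshold_world = sampling_threshold*(1 - fraction)             # 50 sequences for the rest of the world
region_threshold = int(np.ceil(1.0*total_threshold_world/nregions))   # even share: 10 per region

if region_counts[0] > region_threshold:
    thres = region_threshold                      # every region can fill its even share
else:
    # water-filling: regions with fewer sequences than the even share donate
    # their unused quota to the regions that have more
    left_to_fill = total_threshold_world - nregions*region_counts[0]
    thres = region_counts[0]
    for ri, rc in zip(range(nregions - 1, 0, -1), region_counts[1:]):
        if left_to_fill - ri*(rc - thres) > 0:
            left_to_fill -= ri*(rc - thres)
            thres = rc
        else:
            thres += left_to_fill/ri
            break
print(max(1, int(thres)))                         # 13: capped counts sum to 2+8+13+13+13 = 49 <= 50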