Example No. 1
def get_numerical_dates(meta_dict,
                        name_col=None,
                        date_col='date',
                        fmt=None,
                        min_max_year=None):
    if fmt:
        from datetime import datetime
        numerical_dates = {}
        for k, m in meta_dict.items():
            v = m[date_col]
            if not isinstance(v, str):
                print("WARNING: %s has an invalid date string:" % k, v)
                continue
            elif 'XX' in v:
                ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year)
                if ambig_date is None or None in ambig_date:
                    # don't send to numeric_date or the date will be set to today
                    numerical_dates[k] = [None, None]
                else:
                    numerical_dates[k] = [numeric_date(d) for d in ambig_date]
            else:
                try:
                    numerical_dates[k] = numeric_date(datetime.strptime(v, fmt))
                except ValueError:
                    numerical_dates[k] = None
    else:
        numerical_dates = {k: float(v) for k, v in meta_dict.items()}

    return numerical_dates
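
A minimal usage sketch for get_numerical_dates (hypothetical strain records; assumes numeric_date from treetime.utils and augur's ambiguous_date_to_date_range are in scope, as in the function above):

meta = {
    "strainA": {"date": "2018-03-04"},   # exact date -> single float
    "strainB": {"date": "2018-XX-XX"},   # ambiguous date -> [start, end] floats
    "strainC": {"date": None},           # non-string -> warned about and skipped
}
dates = get_numerical_dates(meta, fmt="%Y-%m-%d")
# dates["strainA"] is a float near 2018.17; dates["strainB"] is a two-element
# list spanning all of 2018; "strainC" is absent from the result.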
Example No. 2
def flu_subsampling(metadata, viruses_per_month, time_interval, titer_fname=None):
    # Filter metadata by date using the given time interval. Using numeric dates
    # here allows users to define time intervals to the day and filter viruses
    # at that same level of precision.
    time_interval_start = round(numeric_date(time_interval[1]), 2)
    time_interval_end = round(numeric_date(time_interval[0]), 2)
    metadata = {
        strain: record
        for strain, record in metadata.items()
        if time_interval_start <= record["num_date"] <= time_interval_end
    }

    #### DEFINE THE PRIORITY
    if titer_fname:
        HI_titer_count = count_titer_measurements(titer_fname)
        def priority(strain):
            return HI_titer_count[strain]
    else:
        print("No titer counts provided - using random priorities")
        def priority(strain):
            return np.random.random()

    subcat_threshold = int(np.ceil(1.0*viruses_per_month/len(subcats)))

    virus_by_super_category, virus_by_category = populate_categories(metadata)
    def threshold_fn(x):
        #x is the subsampling category, in this case a tuple of (region, year, month)

        # if there are not enough viruses by super category, take everything
        if len(virus_by_super_category[x[1:]]) < viruses_per_month:
            return viruses_per_month

        # otherwise, sort sub categories by strain count
        sub_counts = sorted([(r, virus_by_category[(r, x[1], x[2])]) for r in subcats],
                            key=lambda y: len(y[1]))

        # if all (the smallest) subcat has more strains than the threshold, return threshold
        if len(sub_counts[0][1]) > subcat_threshold:
            return subcat_threshold


        strains_selected = 0
        tmp_subcat_threshold = subcat_threshold
        for ri, (r, strains) in enumerate(sub_counts):
            current_threshold = int(np.ceil(1.0*(viruses_per_month-strains_selected)/(len(subcats)-ri)))
            if r==x[0]:
                return current_threshold
            else:
                strains_selected += min(len(strains), current_threshold)
        return subcat_threshold

    selected_strains = []
    for cat, val in virus_by_category.items():
        val.sort(key=priority, reverse=True)
        selected_strains.extend(val[:threshold_fn(cat)])

    return selected_strains
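
To make the rebalancing in threshold_fn concrete, here is a worked sketch with made-up counts (none of these numbers come from the source). Suppose viruses_per_month = 90 and three subcategories holding 10, 40, and 200 strains:

import numpy as np

# (region, number of available strains), sorted ascending by count
sub_counts = [("region_small", 10), ("region_mid", 40), ("region_big", 200)]
strains_selected = 0
for ri, (r, n) in enumerate(sub_counts):
    current_threshold = int(np.ceil(1.0 * (90 - strains_selected) / (3 - ri)))
    print(r, current_threshold)  # -> 30, then 40, then 40
    strains_selected += min(n, current_threshold)

The smallest region contributes all 10 of its strains; its unused quota is redistributed, so the larger regions are asked for 40 each rather than a flat 30.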
Example No. 3
from datetime import datetime
from treetime.utils import numeric_date

def convert_date_fields_to_numeric(date_fields):
    nXX = date_fields.count("XX")
    if nXX == 0:
        return numeric_date(datetime.strptime("-".join(date_fields), "%Y-%m-%d"))
    elif nXX == 1:
        # day unknown: use the middle of the month
        return numeric_date(datetime.strptime("-".join(date_fields[0:2]) + "-15", "%Y-%m-%d"))
    elif nXX == 2:
        # only the year is known
        return date_fields[0]
    else:
        raise ValueError("Unknown date format: %s" % (date_fields,))
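
Hypothetical calls showing the three return shapes (date_fields is a [year, month, day] list of strings):

convert_date_fields_to_numeric(["2019", "11", "27"])  # full date -> float near 2019.9
convert_date_fields_to_numeric(["2019", "11", "XX"])  # day unknown -> mid-November float
convert_date_fields_to_numeric(["2019", "XX", "XX"])  # month unknown -> the year string "2019"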
Example No. 4
    def _distribution_to_human_readable(self, dist):

        if dist.is_delta:
            date = utils.numeric_date() - self.date2dist.get_date(
                dist.peak_pos)

            return [date - 0.5, date - 1e-10, date, date + 1e-10,
                    date + 0.5], [0, 0, 1.0, 0, 0]

        peak_pos = dist.peak_pos
        fwhm = dist.fwhm
        raw_x = dist.x[(dist.x > peak_pos - 3 * fwhm)
                       & (dist.x < peak_pos + 3 * fwhm)]
        # use a list comprehension: in Python 3, np.array(map(...)) would wrap
        # the map object itself instead of its values
        dates_x = utils.numeric_date() - np.array(
            [self.date2dist.get_date(x) for x in raw_x])
        y = dist.prob_relative(raw_x)
        return dates_x, y
Example No. 5
def get_numerical_date_from_value(value, fmt=None, min_max_year=None):
    value = str(value)
    if re.match(r'^-*\d+\.\d+$', value):
        # already a numeric date, which can be negative
        return float(value)
    if value.isnumeric():
        # year-only date is ambiguous
        value = fmt.replace('%Y', value).replace('%m', 'XX').replace('%d', 'XX')
    if 'XX' in value:
        ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year)
        if ambig_date is None or None in ambig_date:
            # don't send to numeric_date or the date will be set to today
            return [None, None]
        return [numeric_date(d) for d in ambig_date]
    try:
        return numeric_date(datetime.strptime(value, fmt))
    except (ValueError, TypeError):
        return None
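
Hypothetical inputs and the resulting return shapes (assumes re, datetime, numeric_date, and ambiguous_date_to_date_range are in scope as above):

get_numerical_date_from_value("2020.35", fmt="%Y-%m-%d")     # already numeric -> 2020.35
get_numerical_date_from_value("2020", fmt="%Y-%m-%d")        # year only -> [start, end] range
get_numerical_date_from_value("2020-04-01", fmt="%Y-%m-%d")  # exact -> single float
get_numerical_date_from_value("not-a-date", fmt="%Y-%m-%d")  # unparseable -> None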
Example No. 6
def get_numerical_date_from_value(value,
                                  fmt=None,
                                  min_max_year=None,
                                  raise_error=True):
    if not isinstance(value, str):
        if raise_error:
            raise ValueError(value)
        else:
            numerical_date = None
    elif 'XX' in value:
        ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year)
        if ambig_date is None or None in ambig_date:
            # don't send to numeric_date or the date will be set to today
            numerical_date = [None, None]
        else:
            numerical_date = [numeric_date(d) for d in ambig_date]
    else:
        try:
            numerical_date = numeric_date(datetime.strptime(value, fmt))
        except (ValueError, TypeError):
            numerical_date = None

    return numerical_date
Example No. 8
    def tt_from_file(self, infile, root='none'):
        from treetime.gtr import GTR
        from treetime import io, utils
        gtr = GTR.standard()
        self.tt = io.treetime_from_newick(gtr, infile)
        io.set_seqs_to_leaves(self.tt, self.aln)
        io.set_node_dates_from_dic(self.tt, {seq.id:utils.numeric_date(seq.attributes['date'])
                                for seq in self.aln if 'date' in seq.attributes})
        self.tree = self.tt.tree
        if root=='midpoint':
            self.tt.tree.root_at_midpoint()
            self.tt.set_additional_tree_params()
        elif root=='oldest':
            tmp = self.tt.reroot_to_oldest()

        for node in self.tree.get_terminals():
            if node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                for attr in seq.attributes:
                    if attr == 'date':
                        node.date = seq.attributes['date'].strftime('%Y-%m-%d')
                    else:
                        setattr(node, attr, seq.attributes[attr])
Example No. 9
def flu_subsampling(metadata,
                    viruses_per_month,
                    time_interval,
                    priorities_fname=None):
    # Filter metadata by date using the given time interval. Using numeric dates
    # here allows users to define time intervals to the day and filter viruses
    # at that same level of precision.
    time_interval_start = numeric_date(time_interval[1])
    time_interval_end = numeric_date(time_interval[0])
    metadata = {
        strain: record
        for strain, record in metadata.items()
        if time_interval_start <= record["num_date"] <= time_interval_end
    }

    #### DEFINE THE PRIORITY
    if priorities_fname:
        strain_to_priority = defaultdict(int)
        for s, p in read_priorities(priorities_fname).items():
            strain_to_priority[s] = p

        def priority(strain):
            return float(strain_to_priority[strain]) + np.random.random()
    else:
        print("No priorities file provided - using random priorities")

        def priority(strain):
            return np.random.random()

    print("Viruses per month:", viruses_per_month)

    # Request an equal number of viruses per subcategory.
    subcat_threshold = int(np.ceil(float(viruses_per_month) / len(subcats)))
    print("Subcategory threshold:", subcat_threshold)

    virus_by_super_category, virus_by_category = populate_categories(metadata)

    def threshold_fn(x):
        #x is the subsampling category, in this case a tuple of (region, year, month)

        # if there are not enough viruses by super category, take everything
        if len(virus_by_super_category[x[1:]]) < viruses_per_month:
            return viruses_per_month

        # otherwise, sort sub categories by strain count
        sub_counts = sorted([(r, virus_by_category[(r, x[1], x[2])])
                             for r in subcats],
                            key=lambda y: len(y[1]))

        # if all (the smallest) subcat has more strains than the threshold, return threshold
        if len(sub_counts[0][1]) > subcat_threshold:
            return subcat_threshold

        # We assume no strains have been selected yet.
        strains_selected = 0

        tmp_subcat_threshold = subcat_threshold
        for ri, (r, strains) in enumerate(sub_counts):
            current_threshold = int(
                np.ceil(1.0 * (viruses_per_month - strains_selected) /
                        (len(subcats) - ri)))
            if r == x[0]:
                return current_threshold
            else:
                strains_selected += min(len(strains), current_threshold)
        return subcat_threshold

    selected_strains = []
    for cat, val in list(virus_by_category.items()):
        tmp = sorted(val, key=priority, reverse=True)
        selected_strains.extend(tmp[:threshold_fn(cat)])

    return selected_strains
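
Note the index order in the date filter above: time_interval is expected with the newest date first, so element [1] becomes the interval start. A quick sanity check (hypothetical dates; numeric_date from treetime.utils works on date/datetime objects):

from datetime import date
from treetime.utils import numeric_date

time_interval = (date(2018, 10, 1), date(2016, 10, 1))  # newest first
start = numeric_date(time_interval[1])
end = numeric_date(time_interval[0])
assert start < end  # ~2016.75 < ~2018.75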
Example No. 10
    def _read_metadata_from_file(self, infile):
        """
        @brief      Reads metadata from a file or handle.

        @param      self    The object
        @param      infile  The input file or file handle

        """

        try:
            # read the metadata file into pandas dataframe.
            df = pandas.read_csv(infile, index_col=0, sep=r'\s*,\s*',
                                 engine='python')

            # check the metadata has strain names in the first column
            if 'name' not in df.index.name.lower():
                print(
                    "Cannot read metadata: first column should contain the names of the strains"
                )
                return

            # look for the column containing sampling dates
            # We assume that the dates might be given either in human-readable format
            # (e.g. ISO dates), or be already converted to the numeric format.
            potential_date_columns = []
            potential_numdate_columns = []

            # Scan the dataframe columns and find the ones which are likely to
            # store the dates
            for ci, col in enumerate(df.columns):
                if 'date' in col.lower():
                    try:  #  avoid date parsing when can be parsed as float
                        tmp = float(df.iloc[0, ci])
                        potential_numdate_columns.append((ci, col))
                    except:  #  otherwise add as potential date column
                        potential_date_columns.append((ci, col))

            # if a potential numeric date column was found, use it
            # (use the first, if there are more than one)
            if len(potential_numdate_columns) >= 1:

                name = potential_numdate_columns[0][1]
                # Use this column as numdate_given
                dates = df[name].to_dict()

            elif len(potential_date_columns) >= 1:

                #try to parse the csv file with dates in the idx column:
                idx = potential_date_columns[0][0]
                name = potential_date_columns[0][1]
                # NOTE as the 0th column is the index, we should parse the dates
                # for the column idx + 1
                df = pandas.read_csv(infile,
                                     index_col=0,
                                     sep=r'\s*,\s*',
                                     parse_dates=[1 + idx],
                                     engine='python')
                dates = {
                    k: utils.numeric_date(df.loc[k, name])
                    for k in df.index
                }

                # list comprehension: in Python 3, assigning a bare map object
                # would not materialize the values
                df.loc[:, name] = [str(x.date()) for x in df.loc[:, name]]

            else:
                print(
                    "Metadata file has no column which looks like a sampling date!"
                )

            metadata = df.to_dict(orient='index')
            return dates, metadata

        except:

            print("Cannot read the metadata file. Exception caught")
Example No. 11
        for name in metadata[segment]:
            if name in sequence_names_by_segment[segment]:
                filtered_metadata[segment][name] = metadata[segment][name]

    # filter down to strains with sequences for all required segments
    guide_segment = args.segments[0]
    strains_with_all_segments = set.intersection(*(set(filtered_metadata[x].keys()) for x in args.segments))
    # exclude outlier strains
    strains_with_all_segments.difference_update(set(excluded_strains))
    # subsample by region, month, year
    selected_strains = flu_subsampling({x:filtered_metadata[guide_segment][x] for x in strains_with_all_segments},
                                  args.viruses_per_month, time_interval, titer_fname=args.titers)

    # add strains that need to be included
    for strain in included_strains:
        if strain in strains_with_all_segments and strain not in selected_strains:
            # Do not include strains sampled too far in the past or strains
            # sampled from the future relative to the requested build interval.
            if (filtered_metadata[guide_segment][strain]['year'] >= lower_reference_cutoff.year and
                filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff)):
                selected_strains.append(strain)

    # Confirm that none of the selected strains were sampled outside of the
    # requested interval.
    for strain in selected_strains:
        assert filtered_metadata[guide_segment][strain]['num_date'] <= numeric_date(upper_reference_cutoff)

    # write the list of selected strains to file
    with open(args.output, 'w') as ofile:
        ofile.write('\n'.join(selected_strains))
Example No. 12
        nargs=2,
        help="explicit time interval to use -- overrides resolutions"
        " expects YYYY-MM-DD YYYY-MM-DD")
    parser.add_argument('--region',
                        type=str,
                        help="region to draw sequences from")
    parser.add_argument(
        '--exclude',
        help=
        "a text file containing strains (one per line) that will be excluded")

    args = parser.parse_args()

    region = args.region
    time_interval = sorted([
        numeric_date(x)
        for x in determine_time_interval(args.time_interval, args.resolution)
    ])
    # read strains to exclude
    excluded_strains = read_strain_list(args.exclude) if args.exclude else []

    # read in meta data, parse numeric dates, and exclude outlier strains
    metadata = {
        k: val
        for k, val in parse_metadata(['segment'], [args.metadata]).items()
        if k not in excluded_strains
    }['segment']

    sequences = []
    print(time_interval)
    for seq in metadata:
Example No. 13
    )

    parser.add_argument('--metadata', type=str, required=True,
                        help="file with metadata associated with viral sequences")
    parser.add_argument('--output', type=str, help="name of file to save the age_distribution histogram to")
    parser.add_argument('-r', '--resolution', default='3y', type=str,
                        help="single resolution to include (default: 3y)")
    parser.add_argument('--time-interval', nargs=2, help="explicit time interval to use -- overrides resolutions"
                                                        " expects YYYY-MM-DD YYYY-MM-DD")
    parser.add_argument('--region', type=str, help="region to draw sequences from")
    parser.add_argument('--exclude', help="a text file containing strains (one per line) that will be excluded")

    args = parser.parse_args()

    region = args.region
    time_interval = sorted([numeric_date(x)
            for x in determine_time_interval(args.time_interval, args.resolution)])
    # read strains to exclude
    excluded_strains = read_strain_list(args.exclude) if args.exclude else []

    # read in meta data, parse numeric dates, and exclude outlier strains
    metadata = {k:val for k,val in parse_metadata(['segment'], [args.metadata]).items()
                if k not in excluded_strains}['segment']

    sequences = []
    print(time_interval)
    for seq in metadata:
        if metadata[seq]["num_date"]>=time_interval[0] and \
           metadata[seq]["num_date"]<time_interval[1]:
            sequences.append(metadata[seq])
Example No. 14
def flu_subsampling(metadata,
                    viruses_per_month,
                    time_interval,
                    titer_fnames=None,
                    priority_region=None,
                    priority_region_fraction=0.5):
    # Filter metadata by date using the given time interval. Using numeric dates
    # here allows users to define time intervals to the day and filter viruses
    # at that same level of precision.
    time_interval_start = numeric_date(time_interval[1])
    time_interval_end = numeric_date(time_interval[0])
    metadata = {
        strain: record
        for strain, record in metadata.items()
        if time_interval_start <= record["num_date"] <= time_interval_end
    }

    #### DEFINE THE PRIORITY
    if titer_fnames:
        HI_titer_count = defaultdict(int)
        for fname in titer_fnames:
            for s, k in count_titer_measurements(fname).items():
                HI_titer_count[s] += k

        def priority(strain):
            return (HI_titer_count[strain] + 0.0 * np.random.random()
                    - metadata[strain]['date'].count('X'))
    else:
        print("No titer counts provided - using random priorities")

        def priority(strain):
            return (0.0 * np.random.random()
                    - metadata[strain]['date'].count('X'))

    print("Viruses per category:", viruses_per_month)

    if priority_region is None:
        # Request an equal number of viruses per subcategory.
        subcat_threshold = int(np.ceil(
            float(viruses_per_month) / len(subcats)))
        print("Subcategory threshold:", subcat_threshold)
    else:
        # Give priority to the given region and request fewer viruses per other region.
        subcats.remove(priority_region)
        priority_region_threshold = int(
            np.ceil(priority_region_fraction * viruses_per_month))
        subcat_threshold = int(
            np.ceil((1 - priority_region_fraction) * viruses_per_month /
                    len(subcats)))
        print("Priority region threshold:", priority_region_threshold)
        print("Subcategory threshold:", subcat_threshold)

    virus_by_super_category, virus_by_category = populate_categories(metadata)

    def threshold_fn(x):
        #x is the subsampling category, in this case a tuple of (region, year, month)

        # if there are not enough viruses by super category, take everything
        if len(virus_by_super_category[x[1:]]) < viruses_per_month:
            return viruses_per_month

        # otherwise, sort sub categories by strain count
        sub_counts = sorted([(r, virus_by_category[(r, x[1], x[2])])
                             for r in subcats],
                            key=lambda y: len(y[1]))

        # If a priority region has been requested, return either the preferred
        # number of viruses for that region or the total number of viruses
        # sampled for that region during the current month and year.
        if priority_region == x[0]:
            return min(priority_region_threshold, len(virus_by_category[x]))

        # if all (the smallest) subcat has more strains than the threshold, return threshold
        if len(sub_counts[0][1]) > subcat_threshold:
            return subcat_threshold

        if priority_region is None:
            # If no region is given priority, we assume no strains have been selected yet.
            strains_selected = 0
        else:
            # If a priority region is given, we assume that region's proportion
            # of the total viruses per month have been selected given sufficient strains.
            # Otherwise, set strains_selected to the number of available viruses.
            # The remaining regions divide up the remaining viruses per month.
            strains_selected = min(
                len(virus_by_category[(priority_region, x[1], x[2])]),
                int(np.ceil(priority_region_fraction * viruses_per_month)))

        tmp_subcat_threshold = subcat_threshold
        for ri, (r, strains) in enumerate(sub_counts):
            current_threshold = int(
                np.ceil(1.0 * (viruses_per_month - strains_selected) /
                        (len(subcats) - ri)))
            if r == x[0]:
                return current_threshold
            else:
                strains_selected += min(len(strains), current_threshold)
        return subcat_threshold

    selected_strains = []
    for cat, val in list(virus_by_category.items()):
        tmp = sorted(val, key=priority, reverse=True)
        selected_strains.extend(tmp[:threshold_fn(cat)])

    return selected_strains
Example No. 15
        time_interval,
        titer_fnames=args.titers,
        priority_region=args.priority_region,
        priority_region_fraction=args.priority_region_fraction)

    # add strains that need to be included
    # these strains don't have to exist in all segments, just the guide segment
    for strain in included_strains:
        if strain not in selected_strains and strain in filtered_metadata[
                guide_segment]:
            # Do not include strains sampled too far in the past or strains
            # sampled from the future relative to the requested build interval.
            if (filtered_metadata[guide_segment][strain]['year'] >=
                    lower_reference_cutoff.year
                    and filtered_metadata[guide_segment][strain]['num_date'] <=
                    numeric_date(upper_reference_cutoff)):
                selected_strains.append(strain)

    # summary of selected strains by region
    summary(selected_strains, filtered_metadata, args.segments, ['region'])
    summary(selected_strains, filtered_metadata, args.segments,
            ['year', 'month'])

    # Confirm that none of the selected strains were sampled outside of the
    # requested interval.
    for strain in selected_strains:
        assert filtered_metadata[guide_segment][strain][
            'num_date'] <= numeric_date(upper_reference_cutoff)

    # write the list of selected strains to file
    with open(args.output, 'w') as ofile:
Example No. 16
                if not (x[0] in ['N', '-'] or x[-1] in ['N', '-'])
            ])

    tmrca = np.linspace(2019.7,
                        np.min([x['numdate'] for x in tips.values()]) - 0.001,
                        101)

    tsum = np.sum([np.mean(v['numdate']) for v in tips.values()])
    ntips = len(tips)
    L = 29000
    for mu in [3e-4, 5e-4, 1e-3]:
        logp = -mu * (tsum - ntips * tmrca) * L
        for tip in tips.values():
            logp += len(tip['mutations']) * np.log(tip['numdate'] - tmrca)
        p = np.exp(logp)
        p /= p.sum()
        plt.plot(tmrca, p, label=f"rate={mu:1.1e} per site and year", lw=2)

    plt.title(
        'TMRCA of 2019-nCov assuming a star tree\nand Poisson statistics of mutations',
        fontsize=16)
    plt.xlabel('TMRCA', fontsize=16)
    plt.ylabel('Probability density', fontsize=16)
    ticks = ['2019-10-01', '2019-11-01', '2019-12-01', '2020-01-01']
    plt.legend(loc=2, fontsize=12)
    plt.tick_params(labelsize=12)
    plt.xticks([
        numeric_date(datetime.datetime.strptime(x, '%Y-%m-%d')) for x in ticks
    ], ticks)
    plt.savefig(args.output)
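
Up to terms independent of the TMRCA (which cancel when p is normalized), the loop above evaluates the star-tree Poisson log-likelihood for tips i with n_i mutations sampled at time t_i, genome length L and rate \mu:

\log P(T) = -\mu L \sum_i (t_i - T) + \sum_i n_i \log(t_i - T) + \mathrm{const}

Here tsum = \sum_i t_i, so the first term is the -mu * (tsum - ntips * tmrca) * L in the code, and the dropped \sum_i n_i \log(\mu L) constant is absorbed by the normalization p /= p.sum().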
Example No. 17
def read_metadata_from_file(infile, log):
    """
    @brief      Reads a metadata from file or handle.
    @param      self    The object
    @param      infile  The input file or file handle
    """
    try:
        # read the metadata file into pandas dataframe.
        df = pandas.read_csv(infile,
                             index_col=0,
                             sep=r"\s*,\s*",
                             engine="python")
        # check the metadata has strain names in the first column
        # look for the column containing sampling dates
        # We assume that the dates might be given either in human-readable format
        # (e.g. ISO dates), or be already converted to the numeric format.
        if "name" not in df.index.name.lower():
            print(
                "Cannot read metadata: first column should contain the names of the strains",
                file=log,
            )
            return
        potential_date_columns = []
        potential_numdate_columns = []
        # Scan the dataframe columns and find ones which likely to store the
        # dates
        for ci, col in enumerate(df.columns):
            d = df.iloc[0, ci]
            if isinstance(d, str) and d[0] in ['"', "'"] and d[-1] in ['"', "'"]:
                for i, tmp_d in enumerate(df.iloc[:, ci]):
                    df.iloc[i, ci] = tmp_d.strip(d[0])
            if "date" in col.lower():
                try:  #  avoid date parsing when can be parsed as float
                    tmp = float(df.iloc[0, ci])
                    potential_numdate_columns.append((ci, col))
                except:  #  otherwise add as potential date column
                    potential_date_columns.append((ci, col))
        # if a potential numeric date column was found, use it
        # (use the first, if there are more than one)
        if len(potential_numdate_columns) >= 1:
            name = potential_numdate_columns[0][1]
            # Use this column as numdate_given
            dates = df[name].to_dict()
            for k, val in dates.items():
                try:
                    dates[k] = float(val)
                except:
                    dates[k] = None

        elif len(potential_date_columns) >= 1:
            # try to parse the csv file with dates in the idx column:
            idx = potential_date_columns[0][0]
            name = potential_date_columns[0][1]
            # NOTE as the 0th column is the index, we should parse the dates
            # for the column idx + 1
            df = pandas.read_csv(
                infile,
                index_col=0,
                sep=r"\s*,\s*",
                parse_dates=[1 + idx],
                engine="python",
            )
            dates = {k: numeric_date(df.loc[k, name]) for k in df.index}
            df.loc[:, name] = [str(x.date()) for x in df.loc[:, name]]
        else:
            print(
                "Metadata file has no column which looks like a sampling date!",
                file=log,
            )
        metadata = df.to_dict(orient="index")
        for k, val in metadata.items():
        if isinstance(k, str) and k[0] in ["'", '"'] and k[-1] in ["'", '"']:
                metadata[k.strip(k[0])] = val
                dates[k.strip(k[0])] = dates[k]
        return dates, metadata
    except:
        print("Cannot read the metadata file. Exception caught!", file=log)
        raise
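
A hypothetical round-trip with an in-memory CSV. Note that a file handle can only be read once, so the numeric-date branch, which does not re-read infile, is the safe one to demonstrate; the human-readable branch re-parses the file and needs a real path:

import io
import sys

csv = io.StringIO("name,date\nstrainA,2019.90\nstrainB,2020.01\n")
dates, metadata = read_metadata_from_file(csv, log=sys.stderr)
# dates == {'strainA': 2019.9, 'strainB': 2020.01}
# metadata maps each strain name to its row as a dict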
Example No. 18
def run(args):
    '''
    filter and subsample a set of sequences into an analysis set
    '''

    #Set flags if VCF
    is_vcf = False
    is_compressed = False
    if any([args.sequences.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]):
        is_vcf = True
        if args.sequences.lower().endswith('.gz'):
            is_compressed = True

    ### Check that the user has vcftools. Without it, a one-blank-line file is
    #   created, which lets the next step run but fail badly.
    if is_vcf:
        from shutil import which
        if which("vcftools") is None:
            print(
                "ERROR: 'vcftools' is not installed! This is required for VCF data. "
                "Please see the augur install instructions to install it.")
            return 1

    ####Read in files

    #If VCF, open and get sequence names
    if is_vcf:
        seq_keep, all_seq = read_vcf(args.sequences)

    #if Fasta, read in file to get sequence names and sequences
    else:
        try:
            seqs = SeqIO.to_dict(SeqIO.parse(args.sequences, 'fasta'))
        except ValueError as error:
            print("ERROR: Problem reading in {}:".format(args.sequences))
            print(error)
            return 1
        seq_keep = list(seqs.keys())
        all_seq = seq_keep.copy()

    try:
        meta_dict, meta_columns = read_metadata(args.metadata)
    except ValueError as error:
        print("ERROR: Problem reading in {}:".format(args.metadata))
        print(error)
        return 1

    #####################################
    #Filtering steps
    #####################################

    # remove sequences without meta data
    tmp = []
    for seq_name in seq_keep:
        if seq_name in meta_dict:
            tmp.append(seq_name)
        else:
            print("No meta data for %s, excluding from all further analysis." %
                  seq_name)
            with open(
                    os.path.join(os.getcwd(), "results",
                                 "QC_missing_metadata.txt"),
                    'a') as nometadata_f:
                nometadata_f.write(seq_name + "\n")
    seq_keep = tmp

    # remove strains explicitly excluded by name
    # read list of strains to exclude from file and prune seq_keep
    num_excluded_by_name = 0
    if args.exclude:
        try:
            with open(args.exclude, 'r') as ifile:
                to_exclude = set()
                for line in ifile:
                    if line[0] != comment_char:
                        # strip whitespace and remove all text following comment character
                        exclude_name = line.split(comment_char)[0].strip()
                        to_exclude.add(exclude_name)
            tmp = [
                seq_name for seq_name in seq_keep if seq_name not in to_exclude
            ]
            num_excluded_by_name = len(seq_keep) - len(tmp)
            seq_keep = tmp
        except FileNotFoundError as e:
            print("ERROR: Could not open file of excluded strains '%s'" %
                  args.exclude,
                  file=sys.stderr)
            sys.exit(1)

    # exclude strains by metadata field, like 'host=camel'
    # match using lowercase
    num_excluded_by_metadata = {}
    if args.exclude_where:
        for ex in args.exclude_where:
            try:
                col, val = re.split(r'!?=', ex)
            except (ValueError, TypeError):
                print(
                    "invalid --exclude-where clause \"%s\", should be of the form property=value or property!=value"
                    % ex)
            else:
                to_exclude = set()
                for seq_name in seq_keep:
                    if "!=" in ex:  # i.e. property!=value requested
                        if meta_dict[seq_name].get(
                                col, 'unknown').lower() != val.lower():
                            to_exclude.add(seq_name)
                    else:  # i.e. property=value requested
                        if meta_dict[seq_name].get(
                                col, 'unknown').lower() == val.lower():
                            to_exclude.add(seq_name)
                tmp = [
                    seq_name for seq_name in seq_keep
                    if seq_name not in to_exclude
                ]
                num_excluded_by_metadata[ex] = len(seq_keep) - len(tmp)
                seq_keep = tmp

    # filter by sequence length
    num_excluded_by_length = 0
    if args.min_length:
        if is_vcf:  #doesn't make sense for VCF, ignore.
            print("WARNING: Cannot use min_length for VCF files. Ignoring...")
        else:
            seq_keep_by_length = []
            for seq_name in seq_keep:
                sequence = seqs[seq_name].seq
                length = sum(
                    map(lambda x: sequence.count(x),
                        ["a", "t", "g", "c", "A", "T", "G", "C"]))

                if length >= args.min_length:
                    seq_keep_by_length.append(seq_name)
                else:
                    print(">>>>>>>>>>>>>>>>>>> " + str(seq_name) +
                          " rejected by length : " + str(length))
                    with open(
                            os.path.join(os.getcwd(), "results",
                                         "QC_short_seq.txt"),
                            'a') as shortseq_f:
                        shortseq_f.write(seq_name + " " + str(length) + "\n")
            num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length)
            seq_keep = seq_keep_by_length

    # filter by date
    num_excluded_by_date = 0
    if (args.min_date or args.max_date) and 'date' in meta_columns:

        from treetime.utils import numeric_date
        from datetime import datetime
        dates = get_numerical_dates(meta_dict, fmt="%Y-%m-%d")
        dates = {k: v for (k, v) in dates.items() if v is not None}
        dates_2 = get_numerical_dates(meta_dict, fmt="%Y-%m")
        dates_2 = {k: None for (k, v) in dates_2.items() if v is not None}
        dates_3 = get_numerical_dates(meta_dict, fmt="%Y")
        dates_3 = {k: None for (k, v) in dates_3.items() if v is not None}
        dates.update(dates_2)
        dates.update(dates_3)
        tmp = [s for s in seq_keep if dates[s] is not None]

        if args.min_date:
            my_min_date = numeric_date(
                datetime.strptime(args.min_date, "%Y-%m-%d"))
            tmp = [
                s for s in tmp if (np.isscalar(dates[s]) or all(dates[s]))
                and np.max(dates[s]) > my_min_date
            ]

        if args.max_date:
            my_max_date = numeric_date(
                datetime.strptime(args.max_date, "%Y-%m-%d"))
            tmp = [
                s for s in tmp if (np.isscalar(dates[s]) or all(dates[s]))
                and np.min(dates[s]) < my_max_date
            ]
        num_excluded_by_date = len(seq_keep) - len(tmp)
        seq_keep = tmp
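
The scalar-or-range conditions above treat exact dates (floats) and ambiguous dates ([start, end] lists) uniformly; a quick illustration with made-up values:

import numpy as np

dates = {"exact": 2020.3, "ambiguous": [2020.0, 2020.99], "unparsed": [None, None]}
my_min_date = 2020.2
keep = [s for s, d in dates.items()
        if (np.isscalar(d) or all(d)) and np.max(d) > my_min_date]
# "exact" passes (2020.3 > 2020.2); "ambiguous" passes because the latest
# possible date in its range does; "unparsed" is rejected since
# all([None, None]) is False, short-circuiting before np.max is called.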

    # exclude sequences with non-nucleotide characters
    num_excluded_by_nuc = 0
    if args.non_nucleotide:
        good_chars = {
            'A', 'C', 'G', 'T', '-', 'N', 'R', 'Y', 'S', 'W', 'K', 'M', 'D',
            'H', 'B', 'V', '?'
        }
        tmp = [
            s for s in seq_keep
            if len(set(str(seqs[s].seq).upper()).difference(good_chars)) == 0
        ]
        num_excluded_by_nuc = len(seq_keep) - len(tmp)
        seq_keep = tmp

    # subsampling. This will sort sequences into groups by meta data fields
    # specified in --group-by and then take at most --sequences-per-group
    # from each group. Within each group, sequences are optionally sorted
    # by a priority score specified in a file --priority
    # Fix seed for the RNG if specified
    if args.subsample_seed:
        random.seed(args.subsample_seed)
    num_excluded_subsamp = 0
    if args.group_by and args.sequences_per_group:
        spg = args.sequences_per_group
        seq_names_by_group = defaultdict(list)
        canadian_seq_list = []

        for seq_name in seq_keep:
            if (args.all_canadian_seq):
                if re.search(r'^Canada/', seq_name):
                    canadian_seq_list.append(seq_name)

            group = []
            m = meta_dict[seq_name]
            # collect group specifiers
            for c in args.group_by:
                if c in m:
                    group.append(m[c])
                elif c in ['month', 'year'] and 'date' in m:
                    try:
                        year = int(m["date"].split('-')[0])
                    except:
                        print("WARNING: no valid year, skipping", seq_name,
                              m["date"])
                        continue
                    if c == 'month':
                        try:
                            month = int(m["date"].split('-')[1])
                        except:
                            month = random.randint(1, 12)
                        group.append((year, month))
                    else:
                        group.append(year)
                else:
                    group.append('unknown')
            seq_names_by_group[tuple(group)].append(seq_name)

        # If we didn't find any of the specified categories, all seqs will be in 'unknown' - but don't sample this!
        if len(seq_names_by_group) == 1 and ('unknown' in seq_names_by_group or
                                             ('unknown', )
                                             in seq_names_by_group):
            print(
                "WARNING: The specified group-by categories (%s) were not found."
                % args.group_by,
                "No sequences-per-group sampling will be done.")
            if any([x in args.group_by for x in ['year', 'month']]):
                print(
                    "Note that using 'year' or 'year month' requires a column called 'date'."
                )
            print("\n")
        else:
            # Check to see if some categories are missing to warn the user
            group_by = set([
                'date' if cat in ['year', 'month'] else cat
                for cat in args.group_by
            ])
            missing_cats = [cat for cat in group_by if cat not in meta_columns]
            if missing_cats:
                print("WARNING:")
                if any([cat != 'date' for cat in missing_cats]):
                    print(
                        "\tSome of the specified group-by categories couldn't be found: ",
                        ", ".join([
                            str(cat) for cat in missing_cats if cat != 'date'
                        ]))
                if any([cat == 'date' for cat in missing_cats]):
                    print(
                        "\tA 'date' column could not be found to group-by year or month."
                    )
                print(
                    "\tFiltering by group may behave differently than expected!\n"
                )

            if args.priority:  # read priorities
                priorities = read_priority_scores(args.priority)

            # subsample each group, either by taking the spg highest-priority
            # strains or by sampling at random from the sequences in the group
            seq_subsample = []

            if (args.all_canadian_seq):
                to_delete = [
                    mykey for mykey in seq_names_by_group
                    if re.search(r'Canada', str(mykey))
                ]
                for mykey in to_delete:
                    del seq_names_by_group[mykey]

            if (args.max_dup_per_group):
                for group, sequences_in_group in seq_names_by_group.items():
                    seq_names_by_group[group] = dedup(sequences_in_group, seqs,
                                                      args.max_dup_per_group)

            for group, sequences_in_group in seq_names_by_group.items():
                if args.priority:  # sort descending by priority
                    my_seq_ids = sorted(sequences_in_group,
                                        key=lambda x: priorities[x],
                                        reverse=True)[:spg]
                    seq_subsample.extend(my_seq_ids)
                else:
                    seq_subsample.extend(
                        sequences_in_group if len(sequences_in_group) <= spg
                        else random.sample(sequences_in_group, spg))

            num_excluded_subsamp = len(seq_keep) - len(seq_subsample)
            seq_keep = seq_subsample

            seq_keep += canadian_seq_list

    if (args.sequences_per_group == 0):
        seq_keep_tmp = seq_keep.copy()
        seq_keep = []

        for myseq_id in seq_keep_tmp:
            if myseq_id in ['Wuhan-Hu-1/2019', 'Wuhan/WH01/2019']:
                seq_keep.append(myseq_id)

    if (args.priority):
        print("Total context : ", str(len(seq_keep)))


    # force include sequences specified in file.
    # Note that this might re-add previously excluded sequences
    # Note that we are also not checking for existing meta data here
    num_included_by_name = 0
    if args.include and os.path.isfile(args.include):
        with open(args.include, 'r') as ifile:
            to_include = set([
                line.strip() for line in ifile
                if line[0] != comment_char and len(line.strip()) > 0
            ])

        for s in to_include:
            if s not in seq_keep:
                seq_keep.append(s)
                num_included_by_name += 1

    # add sequences with particular meta data attributes
    num_included_by_metadata = 0
    if args.include_where:
        to_include = []
        for ex in args.include_where:
            try:
                col, val = ex.split("=")
            except (ValueError, TypeError):
                print(
                    "invalid include clause %s, should be of the form property=value"
                    % ex)
                continue

            # loop over all sequences and re-add sequences
            for seq_name in all_seq:
                if seq_name in meta_dict:
                    if meta_dict[seq_name].get(col) == val:
                        to_include.append(seq_name)
                else:
                    print("WARNING: no metadata for %s, skipping" % seq_name)
                    continue

        for s in to_include:
            if s not in seq_keep:
                seq_keep.append(s)
                num_included_by_metadata += 1

    ####Write out files

    if is_vcf:
        #get the samples to be deleted, not to keep, for VCF
        dropped_samps = list(set(all_seq) - set(seq_keep))
        if len(dropped_samps) == len(
                all_seq):  #All samples have been dropped! Stop run, warn user.
            print(
                "ERROR: All samples have been dropped! Check filter rules and metadata file format."
            )
            return 1
        write_vcf(args.sequences, args.output, dropped_samps)

    else:
        seq_to_keep = [seq for id, seq in seqs.items() if id in seq_keep]
        if len(seq_to_keep) == 0:
            # deliberately not treated as an error: an empty file is written below
            pass
        SeqIO.write(seq_to_keep, args.output, 'fasta')

    print("\n%i sequences were dropped during filtering" %
          (len(all_seq) - len(seq_keep), ))
    if args.exclude:
        print("\t%i of these were dropped because they were in %s" %
              (num_excluded_by_name, args.exclude))
    if args.exclude_where:
        for key, val in num_excluded_by_metadata.items():
            print("\t%i of these were dropped because of '%s'" % (val, key))
    if args.min_length:
        print(
            "\t%i of these were dropped because they were shorter than minimum length of %sbp"
            % (num_excluded_by_length, args.min_length))
    if (args.min_date or args.max_date) and 'date' in meta_columns:
        print(
            "\t%i of these were dropped because of their date (or lack of date)"
            % (num_excluded_by_date))
    if args.non_nucleotide:
        print(
            "\t%i of these were dropped because they had non-nucleotide characters"
            % (num_excluded_by_nuc))
    if args.group_by and args.sequences_per_group:
        seed_txt = ", using seed {}".format(
            args.subsample_seed) if args.subsample_seed else ""
        print("\t%i of these were dropped because of subsampling criteria%s" %
              (num_excluded_subsamp, seed_txt))

    if args.include and os.path.isfile(args.include):
        print("\n\t%i sequences were added back because they were in %s" %
              (num_included_by_name, args.include))
    if args.include_where:
        print("\t%i sequences were added back because of '%s'" %
              (num_included_by_metadata, args.include_where))

    print("%i sequences have been written out to %s" %
          (len(seq_keep), args.output))
Example No. 19
        "--additional-years-back-for-references",
        type=int,
        default=5,
        help=
        "Additional number of years prior to the given timepoint to allow reference strains"
    )
    parser.add_argument(
        "--reference-strains",
        help=
        "text file containing list of reference strains that should be included from the original strains even if they were sampled prior to the minimum date determined by the requested number of years before the given timepoint"
    )
    args = parser.parse_args()

    # Convert date string to a datetime instance.
    timepoint = pd.to_datetime(args.timepoint)
    numeric_timepoint = np.around(numeric_date(timepoint), 2)

    # Load metadata with strain names and dates.
    metadata, columns = read_metadata(args.metadata)

    # Convert string dates with potential ambiguity (e.g., 2010-05-XX) into
    # floating point dates.
    dates = get_numerical_dates(metadata, fmt="%Y-%m-%d")

    # Setup reference strains.
    if args.reference_strains:
        reference_strains = read_strain_list(args.reference_strains)
    else:
        reference_strains = []

    # If a given number of years back has been requested, determine what the
Example No. 20
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Estimate Tmrca assuming a star topology and a Poisson mutation process",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--infectious-period", type=float, default=10, help="infectious period in days")
    parser.add_argument("--start", type=str, default='2019-11-27', help="start of the outbreak")
    parser.add_argument("--population", nargs='+', type=int, default = [300,1000,3000,10000],
                         help="number of individuals now")
    parser.add_argument("--output", required=True, help="figure file for line graph")
    args = parser.parse_args()

    d = 365.0/args.infectious_period
    n_vals = np.array(args.population)

    T = numeric_date() - numeric_date(datetime.datetime.strptime(args.start, '%Y-%m-%d'))
    b_vals = d*np.linspace(0.5, 6, 111)

    weeks = int(T*365/7)
    inf_period = int(365/d)

    fs=16
    plt.figure()
    plt.title(f"Start: {weeks} weeks ago = {args.start}. Infectious period {inf_period} days")
    for n in n_vals:
        lh = LH(n, T, b_vals, d)
        lh /= lh.sum()
        lh /= (b_vals[1]-b_vals[0])/d
        plt.plot(b_vals/d, lh, lw=3, label=f'n={n}')
    plt.ylabel('Probability density', fontsize=fs)
    plt.xlim([0.9,4])