def nonsyn_chances():
    """Return the relative number of nonsynonymous opportunities per mutation type.

    Every codon of every feature of type 'gene' in the NL4-3 reference is
    scanned. At each of the three codon positions, every alternative
    nucleotide from ``alpha[:4]`` is substituted in. The substitution is
    counted under its mutation type (e.g. 'A->G') when it changes the
    translated amino acid. The counts are finally divided by their total.

    Returns
    -------
    pandas.Series
        Indexed by the mutation types 'X->Y' (X != Y, both in ``alpha[:4]``),
        normalized by the total count.
    """
    from Bio.Seq import translate
    nucs = alpha[:4]
    chances = {a + '->' + b: 0 for a in nucs for b in nucs if a != b}
    ref = HIVreference(refname='NL4-3', load_alignment=True)
    seq = ref.seq
    for anno in ref.annotation.values():
        if anno.type != 'gene':
            continue
        # Work on plain strings: != between Bio Seq objects has not always
        # compared sequence content across Biopython versions.
        seq_reg = str(anno.extract(seq).seq)
        for icod in range(len(seq_reg) // 3):
            cod_anc = seq_reg[icod * 3:(icod + 1) * 3]
            aa_anc = translate(cod_anc)
            for rf in range(3):
                for nuc in nucs:
                    if nuc == cod_anc[rf]:
                        continue
                    cod_der = cod_anc[:rf] + nuc + cod_anc[rf + 1:]
                    if aa_anc != translate(cod_der):
                        chances[cod_anc[rf] + '->' + nuc] += 1

    chances = pd.Series(chances)
    chances /= chances.sum()
    return chances
Ejemplo n.º 2
0
def get_non_consensus_mask(patient, region, aft, ref=HIVreference(subtype="any")):
    """
    Boolean 1D vector of length aft.shape[-1]; True marks positions whose
    initial (founder) allele does not correspond to the consensus sequence.

    Positions rejected by trajectory.get_reference_filter (not mapped to the
    reference, or gapped too often) are always False.
    """
    in_reference = trajectory.get_reference_filter(patient, region, aft, ref)
    reversion_map = trajectory.get_reversion_map(patient, region, aft, ref)
    founder = patient.get_initial_indices(region)
    sites = np.arange(aft.shape[-1])
    # Pick, for every site, the entry of the initial allele.
    founder_is_consensus = reversion_map[founder, sites]

    return np.logical_and(in_reference, ~founder_is_consensus)
Ejemplo n.º 3
0
def load_patient_data(patient_names='all',
                      q=4,
                      timescale='years',
                      filepath=None,
                      fromHIVEVO=False):
    """Load allele-frequency trajectories and sample metadata for patients.

    Data come either directly from HIVEVO (``fromHIVEVO=True``) or from
    ``.npy`` files previously cached under ``filepath``. When both are given,
    data loaded from HIVEVO are also written to the cache.

    Parameters
    ----------
    patient_names : list of str or 'all'
        Patient codes; 'all' selects p1..p11.
    q : int
        Number of allele rows kept from the HIVEVO frequency arrays.
    timescale : str
        Time unit passed to ``Patient.times`` (HIVEVO branch only).
    filepath : str or None
        Prefix prepended to the cache file names (include the trailing
        path separator).
    fromHIVEVO : bool
        Load from HIVEVO instead of from the cache.

    Returns
    -------
    dict
        ``{pat_name: (tt, freqs, vload, dilutions), 'Lref': int,
        'pat_names': list}``. ``freqs`` is a masked array indexed as
        (time, allele, reference position); 'Lref' is its last dimension
        (``None`` if no patient was loaded from the cache).

    Raises
    ------
    ValueError
        If neither ``fromHIVEVO`` nor ``filepath`` is given.
    """
    if patient_names == 'all':
        patient_names = ['p{}'.format(j + 1) for j in range(11)]

    def cache_file(pat_name, what):
        # Cache layout: <filepath><patient>_<what>.npy
        return filepath + '{}_{}.npy'.format(pat_name, what)

    data_all = {}
    if fromHIVEVO:
        # Alternative location: /ebio/ag-neher/share/users/vpuller/HIVEVO/HIVEVO_access
        hivevo_path = '/home/vadim/ebio/users/vpuller/HIVEVO/HIVEVO_access'
        if hivevo_path not in sys.path:  # do not grow sys.path on every call
            sys.path.append(hivevo_path)
        from hivevo.patients import Patient
        from hivevo.HIVreference import HIVreference

        ref = HIVreference(load_alignment=False)
        Lref = len(ref.seq)
        for pat_name in patient_names:
            PAT = Patient.load(pat_name)
            tt = PAT.times(unit=timescale)
            vload = PAT.n_templates_viral_load
            dilutions = PAT.n_templates_dilutions
            # NOTE(review): ``err`` is not defined in this function and must
            # come from module scope -- confirm it exists there.
            freqs_raw = PAT.get_allele_frequency_trajectories(
                'genomewide', error_rate=err)[:, :q, :]
            map_to_ref = PAT.map_to_external_reference('genomewide')
            # Start fully masked; only positions that map onto the reference
            # are filled in (which unmasks them).
            freqs = np.ma.zeros((tt.shape[0], q, Lref))
            freqs.mask = True
            freqs[:, :, map_to_ref[:, 0]] = freqs_raw[:, :, map_to_ref[:, 1]]
            data_all[pat_name] = (tt, freqs, vload, dilutions)
            if filepath is not None:
                np.save(cache_file(pat_name, 'data'), freqs.data)
                np.save(cache_file(pat_name, 'mask'), freqs.mask)
                np.save(cache_file(pat_name, 'tt'), tt)
                np.save(cache_file(pat_name, 'viral_load'), vload)
                np.save(cache_file(pat_name, 'dilutions'), dilutions.data)
                np.save(cache_file(pat_name, 'dilutions_mask'), dilutions.mask)
        data_all['Lref'] = Lref
        data_all['pat_names'] = patient_names

    elif filepath is not None:
        Lref = None
        for pat_name in patient_names:
            tt = np.load(cache_file(pat_name, 'tt'))
            freqs = np.ma.masked_array(np.load(cache_file(pat_name, 'data')),
                                       mask=np.load(cache_file(pat_name, 'mask')))
            vload = np.load(cache_file(pat_name, 'viral_load'))
            dilutions = np.ma.masked_array(
                np.load(cache_file(pat_name, 'dilutions')),
                mask=np.load(cache_file(pat_name, 'dilutions_mask')))
            data_all[pat_name] = (tt, freqs, vload, dilutions)
            Lref = freqs.shape[2]
        data_all['Lref'] = Lref
        data_all['pat_names'] = patient_names
    else:
        # Previously printed a message and then crashed on an unbound name.
        raise ValueError('Path to data is not specified')
    return data_all
Ejemplo n.º 4
0
def create_all_patient_trajectories(region, patient_names=None):
    """Build the trajectory list for ``region`` pooled over several patients.

    Parameters
    ----------
    region : str
        Genomic region passed to ``get_allele_frequency_trajectories``.
    patient_names : list of str, optional
        Patient codes to pool. ``None`` or an empty list selects the default
        set p1-p6, p8, p9, p11.

    Returns
    -------
    list
        Concatenation of the ``create_trajectory_list`` results, patient by
        patient, in the order of ``patient_names``.
    """
    # None instead of a mutable [] default; [] is still accepted for callers
    # that pass it explicitly.
    if patient_names is None or patient_names == []:
        patient_names = ["p1", "p2", "p3", "p4", "p5", "p6", "p8", "p9", "p11"]

    trajectories = []
    ref = HIVreference(subtype="any")
    for patient_name in patient_names:
        patient = Patient.load(patient_name)
        aft = patient.get_allele_frequency_trajectories(region)
        # extend in place instead of rebuilding the list on every patient
        trajectories.extend(create_trajectory_list(patient, region, aft, ref))

    return trajectories
Ejemplo n.º 5
0
def get_divergence_in_time(region, patient, ref=HIVreference(subtype="any")):
    """
    Divergence at every position of ``region`` through time, as a 2D masked
    array (time x position). Positions rejected by the reference filter
    (e.g. too often gapped) are masked at all times.
    """
    aft = patient.get_allele_frequency_trajectories(region)
    n_times, n_sites = aft.shape[0], aft.shape[-1]
    keep = trajectory.get_reference_filter(patient, region, aft, ref)
    div_full = divergence_matrix(patient, region, aft, False)
    founder = patient.get_initial_indices(region)
    # For every time t and site l select div_full[t, founder[l], l].
    div = div_full[:, founder, np.arange(n_sites)]
    hidden = np.tile(~keep, (n_times, 1))
    return np.ma.array(div, mask=hidden)
Ejemplo n.º 6
0
def get_sweep_sites_sum(
        region,
        patient_names=("p1", "p2", "p3", "p4", "p5", "p6", "p8", "p9", "p11")):
    """Returns a 1D vector with the sum of sweep sites over all patients

    For each patient, the sweep mask (``threshold_low=0.5``) is restricted to
    sites passing the reference filter and truncated to its first 2964
    entries. Presumably this makes the per-patient vectors line up; confirm
    against the region lengths. The per-site counts are then summed over
    patients.
    """
    # The reference does not depend on the patient: load it once.
    ref = HIVreference(subtype="any")
    sites = []
    for patient_name in patient_names:
        patient = Patient.load(patient_name)
        aft = patient.get_allele_frequency_trajectories(region)
        sweep_mask = get_sweep_mask(patient, aft, region, threshold_low=0.5)
        reference_mask = trajectory.get_reference_filter(
            patient, region, aft, ref)
        sites.append(list(sweep_mask[reference_mask][:2964]))

    return np.sum(np.array(sites), axis=0, dtype=int)
def collect_data(patients, cov_min=100, refname='HXB2'):
    '''Collect data for the fitness cost estimate

    For each patient, every genome position is scanned. A site is kept when:
    - it lies in exactly one protein and its ancestral codon has no gap;
    - at least one time point has all four nucleotide frequencies unmasked
      (other time points are dropped);
    - it maps onto the external reference ``refname``;
    - its most frequent allele at the first retained time point equals the
      reference consensus allele.
    Every derived nucleotide that exceeds frequency 0.5 at some retained time
    (a "sweep") contributes one record per retained time point.

    Parameters
    ----------
    patients : iterable of str
        Patient codes passed to ``Patient.load``.
    cov_min : int
        Minimal coverage passed to ``get_allele_frequency_trajectories``.
    refname : str
        External reference used for position mapping, entropy and consensus.

    Returns
    -------
    pandas.DataFrame
        Columns: time, af, pos, pos_ref, protein, pcode, mut, mu, muAbram,
        S (reference entropy at the site), syn (True if synonymous).
    '''
    ref = HIVreference(refname=refname, subtype='any', load_alignment=True)
    mus = load_mutation_rates()
    mu = mus.mu
    muA = mus.muA

    data = []
    for pcode in patients:
        print(pcode)

        p = Patient.load(pcode)
        comap = (pd.DataFrame(
            p.map_to_external_reference('genomewide', refname=refname)[:, :2],
            columns=[refname, 'patient']).set_index('patient',
                                                    drop=True).loc[:, refname])

        aft = p.get_allele_frequency_trajectories('genomewide',
                                                  cov_min=cov_min)
        for pos, aft_pos in enumerate(aft.swapaxes(0, 2)):
            fead = p.pos_to_feature[pos]

            # Keep only sites within ONE protein
            # Note: we could drop this, but then we cannot quite classify syn/nonsyn
            if len(fead['protein_codon']) != 1:
                continue

            # Exclude codons with gaps
            pc = fead['protein_codon'][0][-1]
            cod_anc = ''.join(p.initial_sequence[pos - pc:pos - pc + 3])
            if '-' in cod_anc:
                continue

            # Keep only times at which all four nucleotides are unmasked.
            # getmaskarray also copes with the scalar ``nomask`` of an
            # all-valid masked array.
            ind = ~np.ma.getmaskarray(aft_pos[:4]).any(axis=0)
            if not ind.any():
                continue
            times = p.dsi[ind]
            aft_pos = aft_pos[:, ind]

            # Get site entropy
            if pos not in comap.index:
                continue
            pos_ref = comap.loc[pos]
            S_pos = ref.entropy[pos_ref]

            # Keep only sites where the ancestral allele and group M agree
            if ref.consensus_indices[pos_ref] != aft_pos[:, 0].argmax():
                continue

            nuc_anc = p.initial_sequence[pos]
            protein = fead['protein_codon'][0][0]
            for ia, aft_nuc in enumerate(aft_pos[:4]):
                # Keep only derived alleles
                if alpha[ia] == nuc_anc:
                    continue

                # Keep only sweeps
                if not (aft_nuc > 0.5).any():
                    continue

                # Annotate with syn/nonsyn alleles
                cod_new = cod_anc[:pc] + alpha[ia] + cod_anc[pc + 1:]
                syn = translate(cod_anc) == translate(cod_new)

                mut = nuc_anc + '->' + alpha[ia]
                mu_pos = mu[mut]
                muA_pos = muA[mut]

                for t, af_nuc in izip(times, aft_nuc):
                    data.append({
                        'time': t,
                        'af': af_nuc,
                        'pos': pos,
                        'pos_ref': pos_ref,
                        'protein': protein,
                        'pcode': pcode,
                        'mut': mut,
                        'mu': mu_pos,
                        'muAbram': muA_pos,
                        'S': S_pos,
                        'syn': syn,
                    })

    data = pd.DataFrame(data)

    return data
def genome_fractions():
    '''Return the frequency of each nucleotide of alpha[:4] in the subtype B consensus.

    The result is a pandas Series indexed by nucleotide and normalized to
    sum to 1.
    '''
    # TODO: this should be real nonsyn opportunities rather than just genome fraction
    from collections import Counter
    consensus = HIVreference(subtype='B', load_alignment=False).consensus
    counts = pd.Series(Counter(consensus)).loc[alpha[:4]]
    total = counts.sum()
    return counts / total
Ejemplo n.º 9
0
    data.index += start

    return data



# Script
if __name__ == '__main__':

    pp = load_pairing_probability()

    shape = load_shape()


    from hivevo.HIVreference import HIVreference
    ref = HIVreference('NL4-3', load_alignment=False)
    seq = ref.seq.seq

    pairings = {'A': 'T', 'C': 'G', 'T': 'A', 'G': 'C'}
    for position, (partner, prob) in pp.iterrows():
        partner = int(partner)
        if ((seq[partner] != pairings[seq[position]]) and
            # Wobble pairs
            (frozenset([seq[position], seq[partner]]) != frozenset(['G', 'T']))):

            print position, seq[position], seq[partner], pairings[seq[position]]


    #fig, ax = plt.subplots()
    #ax.hist(shape['probability'], bins=np.linspace(0, 1, 10))
    #ax.set_xlabel('Pairing probability')
Ejemplo n.º 10
0
        ]
        non_syn = [
            traj for traj in trajectories[region] if traj.synonymous == False
        ]
        trajectories[region] = {
            "rev": rev,
            "non_rev": non_rev,
            "syn": syn,
            "non_syn": non_syn,
            "all": trajectories[region]
        }

    return trajectories


def load_trajectory_dict(path="trajectory_dict"):
    """Unpickle and return the trajectory dictionary stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use this on
    files produced by this project, never on untrusted input.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)

    return loaded


if __name__ == "__main__":
    # Collect the reversion trajectories of patient p1 in pol; the result is
    # left in ``reversions``.
    region = "pol"
    patient = Patient.load("p1")
    ref = HIVreference(subtype="any")
    aft = patient.get_allele_frequency_trajectories(region)
    trajectories = create_trajectory_list(patient, region, aft, ref)
    reversions = []
    for traj in trajectories:
        # explicit "== True" kept on purpose: only values equal to True match
        if traj.reversion == True:
            reversions.append(traj)
def _drug_mutation_rate(cod, aa_from, aas_to, murate):
    '''Summed rate of single-nucleotide changes turning codon ``cod`` into any amino acid in ``aas_to``.

    Returns 0 when ``cod`` does not translate to ``aa_from``. ``murate`` maps
    mutation labels such as 'A->G' to rates.
    '''
    from Bio.Seq import translate
    # The ancestral check does not depend on the substitution: do it once.
    if translate(cod) != aa_from:
        return 0
    mr = 0
    for pos_cod in range(3):
        for nuc in ['A', 'C', 'G', 'T']:
            codmut = cod[:pos_cod] + nuc + cod[pos_cod + 1:]
            if codmut != cod and translate(codmut) in aas_to:
                mr += murate[cod[pos_cod] + '->' + nuc]
    return mr


def plot_drug_resistance_mutations(data, aa_mutation_rates, fname=None):
    '''Plot the frequency of drug resistance mutations

    Lower panel: per-patient frequencies of the target amino acids of each
    drug resistance mutation in ``drug_muts`` (pol, classes PI, NRTI, NNRTI,
    INI). Mutations whose non-NaN frequencies sum to 0 are not plotted; they
    are printed as monomorphic. At the bottom, a marker shows how many
    patients have frequency <= 1e-4.

    Upper panel: an estimated cost per mutation. This is
    1 / (mean frequency) times the summed single-nucleotide rate from the
    HXB2 codon to the target amino acids.

    Figures are saved as svg/pdf/png under ``fname`` if given, otherwise
    shown interactively.
    '''
    import re
    import matplotlib.patches as patches

    fs = 16
    region = 'pol'
    pcodes = data['init_codon'][region].keys()
    af_region = data['af_by_pat'][region]
    codons_region = data['init_codon'][region]

    def nonnan_sum(freqs):
        '''Sum of the non-NaN values of a {patient: frequency} dict.'''
        # "~np.isnan" (rather than "not") keeps the original filter semantics.
        return np.sum([y for y in freqs.values() if ~np.isnan(y)])

    fig, axs = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 6]})
    ax = axs[1]

    drug_afs_items = []
    mut_types = []  # cumulative number of plotted mutations after each class
    drug_classes = ['PI', 'NRTI', 'NNRTI', 'INI']
    for prot in drug_classes:
        drug_afs = {}
        drug_mut_rates = {}  # NOTE(review): filled but not used below
        offset = drug_muts[prot]['offset']
        for cons_aa, pos, target_aa in drug_muts[prot]['mutations']:
            ipos = pos + offset
            codons = {pat: codons_region[pat][ipos] for pat in pcodes}
            mut_rates = {
                pat: np.sum(
                    [aa_mutation_rates[(codons[pat], aa)] for aa in target_aa])
                for pat in pcodes
            }
            # Target amino-acid frequency relative to the summed frequency
            # of the first 20 rows (presumably the 20 amino acids).
            freqs = {
                pat: np.sum([af_region[pat][alphaal.index(aa), ipos] /
                             af_region[pat][:20, ipos].sum()
                             for aa in target_aa])
                for pat in pcodes
            }

            drug_afs[(cons_aa, pos, target_aa)] = freqs
            drug_mut_rates[(cons_aa, pos, target_aa)] = mut_rates

        by_position = sorted(drug_afs.items(), key=lambda x: x[0][1])
        drug_afs_items.extend(x for x in by_position if nonnan_sum(x[1]) > 0)
        mut_types.append(len(drug_afs_items))
        # list of all mutations whose non-nan frequencies sum to 0
        mono_muts = [''.join(map(str, x[0])) for x in by_position
                     if nonnan_sum(x[1]) == 0]
        print('Monomorphic:', prot, mono_muts)

    plt.ylim([1.1e-5, 1e-1])
    for mi in mut_types[:-1]:
        plt.plot([mi - 0.5, mi - 0.5],
                 plt.ylim(),
                 c=(.3, .3, .3),
                 lw=3,
                 alpha=0.5)
    ax.axhline(4e-5, c=(.3, .3, .3), lw=3, alpha=0.5)

    for ni, prot in enumerate(drug_classes):
        plt.text(0.5 * (mut_types[ni] + (mut_types[ni - 1] if ni else 0)) -
                 0.5,
                 0.12,
                 prot,
                 fontsize=16,
                 ha='center')

    # Alternating grey background band per mutation
    for mi in range(max(mut_types)):
        c = 0.5 + 0.2 * (mi % 2)
        ax.add_patch(
            patches.Rectangle(
                (mi - 0.5, plt.ylim()[0]),
                1.0,
                plt.ylim()[1],  #(x,y), width, height
                color=(c, c, c),
                alpha=0.2))

    afdr = pd.DataFrame(
        np.array([list(x[1].values()) for x in drug_afs_items]).T,
        columns=["".join(map(str, x[0])) for x in drug_afs_items])
    afdr[afdr < 0.8e-4] = 0
    palette = sns.color_palette('husl', afdr.shape[1])
    sns.stripplot(data=afdr,
                  jitter=0.4,
                  alpha=0.8,
                  size=12,
                  lw=1,
                  edgecolor='white')

    # Add the number of missing points at the bottom of the plot, and the cost
    # at the top
    dd = afdr.iloc[[0, 1, 2, 3, 4]].copy()
    dd.index = ['x', 'freq', 'size', 'cost', 'mr']
    dd.loc['x'] = np.arange(dd.shape[1])
    dd.loc['freq'] = 2e-5
    dd.loc['n'] = afdr.shape[0] - (afdr > 1e-4).sum(axis=0)
    dd.loc['size'] = dd.loc['n']**(1.4) + 13
    dd.loc['cost'] = 1.0 / afdr.fillna(0).mean(axis=0)
    dd.loc['mr'] = 0

    reference = HIVreference(refname='HXB2', load_alignment=False)
    seq_PR = reference.annotation['PR'].extract(reference.seq)
    seq_RT = reference.annotation['RT'].extract(reference.seq)
    seq_IN = reference.annotation['IN'].extract(reference.seq)
    murate = load_mutation_rates()['mu']
    # Columns follow drug_afs_items: PI mutations (PR), then NRTI and NNRTI
    # mutations (RT), then INI mutations (IN). Use the cumulative class
    # counts instead of hard-coded column indices, so the protein assignment
    # follows the data.
    n_pr, n_rt_end = mut_types[0], mut_types[2]
    for i, mut in enumerate(dd.T.index):
        if i < n_pr:
            seq_tmp = seq_PR
        elif i < n_rt_end:
            seq_tmp = seq_RT
        else:
            seq_tmp = seq_IN
        aa_from, pos, aas_to = re.sub(r'([A-Z])(\d+)([A-Z]+)', r'\1_\2_\3',
                                      mut).split('_')
        cod = str(seq_tmp.seq[(int(pos) - 1) * 3:int(pos) * 3])
        mr = _drug_mutation_rate(cod, aa_from, aas_to, murate)

        dd.loc['cost', mut] *= mr
        dd.loc['mr', mut] = mr

    for im, (mutname, s) in enumerate(dd.T.iterrows()):
        ax.scatter(
            s['x'],
            s['freq'],
            s=s['size']**2,
            alpha=0.8,
            edgecolor='white',
            facecolor=palette[im],
            lw=2,
        )
        ax.text(s['x'],
                s['freq'],
                str(int(s['n'])),
                fontsize=fs,
                ha='center',
                va='center')

    plt.yscale('log')
    plt.xticks(rotation=50)
    plt.ylabel('minor variant frequency', fontsize=fs)
    plt.tick_params(labelsize=fs * 0.8)
    for tick in ax.xaxis.get_major_ticks():
        # label1 is the primary tick label (``tick.label`` is deprecated)
        tick.label1.set_horizontalalignment('right')

    # Fitness cost at the top
    ax1 = axs[0]
    ax1.set_xlim(*ax.get_xlim())
    ax1.set_xticks(ax.get_xticks() + 0.5)
    ax1.set_xticklabels([])
    ax1.set_ylim(1e-3, 1)
    ax1.set_yticks([1e-3, 1e-2, 1e-1, 1])
    ax1.yaxis.set_tick_params(labelsize=fs * 0.8)
    ax1.set_yscale('log')
    ax1.set_ylabel('cost', fontsize=fs)
    for im, y in enumerate(dd.loc['cost'].values):
        ax1.bar(im - 0.5, y, 1, color=palette[im])

    plt.tight_layout()

    if fname is not None:
        for ext in ['svg', 'pdf', 'png']:
            plt.savefig(fname + '.' + ext)
    else:
        plt.ion()
        plt.show()
Ejemplo n.º 12
0
                        help="Protein to study")

    args = parser.parse_args()

    # Coordinates should be AA of HXB2, but we load the sequence to make
    # sure
    fn = 'data/secondary_uniprot/'+args.protein+'.tsv'
    fn_seq = 'data/secondary_uniprot/'+args.protein+'.fasta'

    # Get sequence first, then annotate (for now)
    # TODO: we will annotate the actual DNA sequence eventually
    seq = SeqIO.read(fn_seq, 'fasta')

    # Get the DNA sequence from our reference and compare (should be HXB2)
    refname = 'HXB2'
    ref = HIVreference(refname=refname, subtype='B', load_alignment=False)

    # Extract gagpol feature
    if args.protein == 'gagpol':
        from Bio.Seq import Seq
        start = ref.annotation['gag'].location.nofuzzy_start
        end = ref.annotation['pol'].location.nofuzzy_end
        slippage_site = 434
        refdna = (ref.seq[start: start + slippage_site * 3] +
                  ref.seq[start + slippage_site * 3 - 1: end])
        refprot = refdna.seq.translate()

    elif args.protein == 'nef':
        # it comes with the stop codon...
        refdna = ref.annotation[args.protein].extract(ref.seq)
        refprot = refdna.seq.translate()[:-1]
Ejemplo n.º 13
0
def collect_data(patients,
                 cov_min=100,
                 refname='HXB2',
                 subtype='any',
                 entropy_threshold=0.1,
                 excluded_proteins=()):
    '''Collect data for the mutation rate estimate

    A site is kept when all of the following hold:
    - it lies in exactly one protein that is not in ``excluded_proteins``;
    - its ancestral codon has no gap;
    - it is not flagged as RNA structure;
    - it maps onto the reference ``refname``;
    - its reference entropy is >= ``entropy_threshold``.
    For every synonymous derived nucleotide at a kept site, one record per
    unmasked time point is produced.

    Parameters
    ----------
    patients : iterable of str
        Patient codes passed to ``Patient.load``.
    cov_min : int
        Minimal coverage passed to ``get_allele_frequency_trajectories``.
    refname, subtype : str
        Reference used for position mapping and entropy.
    entropy_threshold : float
        Minimal reference entropy of a kept site.
    excluded_proteins : collection of str
        Protein names whose sites are skipped (default: none).

    Returns
    -------
    pandas.DataFrame
        Columns: time, af, pos, refpos, protein, pcode, mut, subtype, refname.
    '''
    print('Collect data from patients')

    ref = HIVreference(refname=refname, load_alignment=True, subtype=subtype)

    data = []
    for pcode in patients:
        print(pcode)

        p = Patient.load(pcode)
        # Map onto the same reference used for the entropy lookup below.
        comap = (pd.DataFrame(
            p.map_to_external_reference('genomewide', refname=refname)[:, :2],
            columns=[refname, 'patient']).set_index('patient',
                                                    drop=True).loc[:, refname])

        aft = p.get_allele_frequency_trajectories('genomewide',
                                                  cov_min=cov_min)
        times = p.dsi

        for pos, aft_pos in enumerate(aft.swapaxes(0, 2)):
            fead = p.pos_to_feature[pos]

            # Keep only sites within ONE protein
            if len(fead['protein_codon']) != 1:
                continue
            # skip if protein is to be excluded
            protein = fead['protein_codon'][0][0]
            if protein in excluded_proteins:
                continue

            # Exclude codons with gaps
            pc = fead['protein_codon'][0][-1]
            cod_anc = ''.join(p.initial_sequence[pos - pc:pos - pc + 3])
            if '-' in cod_anc:
                continue

            # The following filters depend only on the site, so they are
            # checked once here rather than for every derived allele.
            # Keep only no RNA structures
            if fead['RNA']:
                continue

            # Keep only sites which are also in the reference
            if pos not in comap.index:
                continue
            refpos = comap.loc[pos]

            # Keep only high-entropy sites
            if ref.entropy[refpos] < entropy_threshold:
                continue

            nuc_anc = p.initial_sequence[pos]
            for ia, aft_nuc in enumerate(aft_pos[:4]):
                # Keep only derived alleles
                if alpha[ia] == nuc_anc:
                    continue

                # Keep only synonymous alleles
                cod_new = cod_anc[:pc] + alpha[ia] + cod_anc[pc + 1:]
                if translate(cod_anc) != translate(cod_new):
                    continue

                mut = nuc_anc + '->' + alpha[ia]

                # getmaskarray also handles the scalar ``nomask``
                masked = np.ma.getmaskarray(aft_nuc)
                for it, (t, af_nuc) in enumerate(izip(times, aft_nuc)):
                    # Keep only nonmasked times
                    if masked[it]:
                        continue

                    data.append({
                        'time': t,
                        'af': af_nuc,
                        'pos': pos,
                        'refpos': refpos,
                        'protein': protein,
                        'pcode': pcode,
                        'mut': mut,
                        'subtype': subtype,
                        'refname': refname,
                    })

    data = pd.DataFrame(data)

    return data
Ejemplo n.º 14
0
def plot_sweeps(data):
    '''Plot the sweeps and have a look

    Draws one panel per patient on a 3x3 grid. Every (pos_ref, mut) group of
    ``data`` becomes one frequency-vs-time line, colored and z-ordered by the
    HXB2 gene (gag, pol, env or other) containing ``pos_ref``. Times are
    shifted by ``100 * pos_ref / 9000`` so that overlapping lines remain
    distinguishable. The last panel holds the gene legend.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns 'pcode', 'pos_ref', 'mut', 'time' and 'af'.

    Returns
    -------
    matplotlib.figure.Figure
    '''
    ref = HIVreference(refname='HXB2', subtype='B', load_alignment=False)
    gene_order = ['gag', 'pol', 'env']
    # Sets make the per-trajectory membership tests O(1)
    pos_genes = {
        genename: set(ref.annotation[genename].location)
        for genename in gene_order
    }
    palette = sns.color_palette("husl", 8)
    colors = {
        'gag': palette[1],
        'pol': palette[0],
        'env': palette[4],
        'other': 'grey'
    }
    zs = {'gag': 1, 'pol': 2, 'env': 0, 'other': 3}

    def get_gene(pos_ref):
        '''Return the first gene in ``gene_order`` containing pos_ref, else 'other'.

        Genes can overlap, so a fixed search order keeps the assignment
        deterministic.
        '''
        for genename in gene_order:
            if pos_ref in pos_genes[genename]:
                return genename
        return 'other'

    fig, axs = plt.subplots(3, 3, figsize=(9, 9))
    axs = axs.ravel()
    pcodes = sorted(data['pcode'].unique(), key=lambda x: int(x[1:]))
    for (pcode, data_pat) in data.groupby('pcode'):
        ip = pcodes.index(pcode)
        ax = axs[ip]
        for (pos_ref, mut), datum in data_pat.groupby(['pos_ref', 'mut']):
            gene = get_gene(pos_ref)
            # float cast so the offset can also be added to integer times
            x = np.asarray(datum['time'], dtype=float) + 100. * pos_ref / 9000
            y = np.array(datum['af'])
            ax.plot(x,
                    y,
                    lw=2,
                    zorder=zs[gene],
                    color=colors[gene],
                    alpha=0.3)
        ax.set_ylim(0.001, 0.999)
        ax.set_title(pcode)
        if ip > 5:
            ax.set_xlabel('Time [days since EDI]')
        if (ip % 3) == 0:
            ax.set_ylabel('Frequency')
        for tick in ax.get_xticklabels():
            tick.set_rotation(30)
    ax = axs[-1]
    for genename in gene_order:
        ax.plot([0], color=colors[genename], label=genename)
    ax.legend(loc='center', fontsize=16)
    ax.set_axis_off()

    plt.tight_layout()

    plt.ion()
    plt.show()

    return fig
Ejemplo n.º 15
0
    av = process_average_allele_frequencies(data, regions, nbootstraps=0,nstates=20)
    combined_af = av['combined_af']
    combined_entropy = av['combined_entropy']
    minor_af = av['minor_af']

    # get association, calculate fitness costs
    associations = get_associations(regions)
    aa_mutation_rates, total_nonsyn_mutation_rates = calc_amino_acid_mutation_rates()
    selcoeff = {}
    for region in regions:
        s = fitness_costs_per_site(region, data, total_nonsyn_mutation_rates)
        s[s>1] = 1
        selcoeff[region] = s

    aa_ref = 'NL4-3'
    global_ref = HIVreference(refname=aa_ref, subtype=args.subtype)

    ### FIGURE 5
    fig,axs = plt.subplots(1,2, figsize=(10,5))
    #fitness_costs_in_optimal_epis(['gag', 'nef'], selcoeff, ax=axs[0])
    #add_panel_label(axs[0], 'A', x_offset=-0.15)
    plot_fraction_associated(regions, selcoeff, associations, axs=axs, slope=2.0)
    add_panel_label(axs[0], 'A', x_offset=-0.15)
    region='nef'
    reference = HIVreferenceAminoacid(region, refname=aa_ref, subtype = args.subtype)
    tmp, rho, pval =  fitness_scatter(region, selcoeff, associations, reference, ax=axs[0])
    add_panel_label(axs[1], 'B', x_offset=-0.15)
    axs[0].legend(loc=3, fontsize=fs)
    axs[0].set_ylim([0.03,3])
    plt.tight_layout()
    for fmt in ['pdf', 'png', 'svg']:
# Script
# Command-line entry point: parse options and (re)compute the pooled data
# for the fitness-cost vs. non-coding-element analysis.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='analyze relation of fitness costs and noncoding elements')
    parser.add_argument('--regenerate',
                        action='store_true',
                        help="regenerate data")
    parser.add_argument('--subtype',
                        choices=['B', 'any'],
                        default='B',
                        help='subtype to compare against')
    args = parser.parse_args()

    # NOTE: HXB2 alignment has way more sequences resulting in better correlations
    reference = HIVreference(refname='HXB2', subtype=args.subtype)
    genes = ['gag', 'nef', 'env', 'vif', 'pol', 'vpr', 'vpu']
    # Intermediate data are saved to file for faster access later on
    # (one cache file per --subtype value).
    fn = '../data/fitness_pooled_noncoding/avg_noncoding_allele_frequency_st_' + args.subtype + '.pickle.gz'
    # Recompute when the cache file is missing or --regenerate was given.
    if not os.path.isfile(fn) or args.regenerate:
        if args.subtype == 'B':
            patient_codes = ['p2', 'p3', 'p5', 'p7', 'p8', 'p9', 'p10',
                             'p11']  # subtype B only
        else:
            patient_codes = [
                'p1', 'p2', 'p3', 'p5', 'p6', 'p7', 'p8', 'p9', 'p10', 'p11'
            ]  # all subtypes

        # gag and nef are loaded since they overlap with relevant non-coding structures
        # and we need to know which positions have synonymous mutations
        # NOTE(review): this call (patients, genes, reference, synnonsyn=...)
        # does not match the collect_data signatures defined in this chunk;
        # presumably it refers to a different module's collect_data -- confirm.
        data = collect_data(patient_codes, genes, reference, synnonsyn=True)
    return xka_q, tt

# Script
if __name__=="__main__":

    parser = argparse.ArgumentParser(description='Fitness cost')
    parser.add_argument('--quantiles', type=int, default=6,
                        help="Number of quantiles")
    parser.add_argument('--subtype', type=str, default='any',
                        help="subtype to compare against")
    args = parser.parse_args()


    gen_region = 'genomewide' #'pol' #'gag' #'pol' #'gp41' #'gp120' #'vif' #'RRE'

    ref = HIVreference(subtype=args.subtype)
    tmp = ref.get_entropy_quantiles(args.quantiles)
    Squant = {}
    qi1=0
    for qi in range(len(tmp)): # prune empty quantiles.
        if len(tmp[qi]['ind']):
            Squant[qi1]=tmp[qi]
            qi1+=1

    q=len(Squant)
    Smedians = [np.median(ref.entropy[Squant[i]['ind']]) for i in range(q)]


    patient_names = ['p1','p2','p5','p6','p8','p9','p11']
    # p3, p10 - excluded since probably infected by >1 virion