def plot_change_dists(plots,
                      drift_data_paths,
                      constant_data_paths,
                      both_data_paths,
                      pt_data_paths,
                      params_path,
                      original_phen_color="black",
                      original_phen_lw=4,
                      original_phen_linestyle="dashed",
                      **kwargs):
    """
    Plot simulated antigenic-change distributions, one panel per condition.

    For each condition ('drift', 'constant', 'both', 'pt') the matching
    entry of ``get_data(...)`` is drawn on
    ``plots['change-dist-<condition>-sim']`` with
    ``plot_antigenic_changes``. A vertical reference line is drawn at 0,
    and each panel gets a title.

    Parameters
    ----------
    plots : dict
        Mapping from panel name to matplotlib axis. It must contain the four
        ``'change-dist-<condition>-sim'`` keys.
    drift_data_paths, constant_data_paths, both_data_paths, pt_data_paths
        Passed straight through to ``get_data``.
    params_path : str
        Parameter file read with ``pp.get_params``. It must define
        ``DEFAULT_CHAIN_MUT_MEAN`` and ``DEFAULT_CHAIN_MUT_SD``.
    original_phen_color, original_phen_lw, original_phen_linestyle
        Style of the vertical line drawn at x = 0.
    **kwargs
        Forwarded to ``plot_antigenic_changes``.
    """
    data = get_data(drift_data_paths, constant_data_paths, both_data_paths,
                    pt_data_paths)

    params = pp.get_params(params_path)
    mut_mean = float(params.get("DEFAULT_CHAIN_MUT_MEAN"))
    mut_sd = float(params.get("DEFAULT_CHAIN_MUT_SD"))

    for condition in ['drift', 'constant', 'both', 'pt']:
        change_axis = plots['change-dist-' + condition + '-sim']
        # x-range passed to the plotting helper and used as the axis limits
        phen_bounds = [-0.35, 0.35]
        plot_antigenic_changes(data=data[condition],
                               axis=change_axis,
                               mut_mean=mut_mean,
                               mut_sd=mut_sd,
                               phen_bounds=phen_bounds,
                               **kwargs)
        change_axis.set_xlabel("antigenic change")
        if _is_first_col(change_axis):
            change_axis.set_ylabel("frequency")
        change_axis.set_xlim(phen_bounds)
        # positional argument: the `b=` keyword was removed in matplotlib 3.7
        change_axis.grid(True)
        change_axis.axvline(0,
                            lw=original_phen_lw,
                            color=original_phen_color,
                            linestyle=original_phen_linestyle)

    plots["change-dist-drift-sim"].set_title("naive population\n")
    plots["change-dist-constant-sim"].set_title("immediate recall response\n")
    plots["change-dist-both-sim"].set_title("mucosal antibodies and\n"
                                            "immediate recall response")
    plots["change-dist-pt-sim"].set_title("mucosal antibodies and\n"
                                          "realistic (48h) recall response")


def _is_first_col(axis):
    """Return whether ``axis`` sits in the first column of its grid.

    ``Axes.is_first_col`` was deprecated in matplotlib 3.4 and removed in
    3.6 in favour of ``axis.get_subplotspec().is_first_col()``. The
    SubplotSpec method only exists from 3.4 on, so fall back to the old
    axis method when it is missing.
    """
    get_spec = getattr(axis, "get_subplotspec", None)
    spec = get_spec() if get_spec is not None else None
    if spec is not None and hasattr(spec, "is_first_col"):
        return spec.is_first_col()
    return axis.is_first_col()
def plot_comparison(param_file,
                    model_name,
                    results_dir,
                    bottleneck=5,
                    axis=None,
                    killing=True,
                    label=False):
    """
    Compare simulated emergence times with the analytical prediction.

    Simulated emergence times below 2 are drawn as a cumulative,
    normalized histogram with 0.05-wide bins. The analytical
    ``emergence_time_cdf``, evaluated on 1000 points in [0, 2], is
    overlaid as a black line.

    Parameters
    ----------
    param_file : str
        Parameter file read with ``pp.get_params``.
    model_name : str
        Model whose parameters (R0_wh, d_v, mu, k, t_M) are looked up.
    results_dir : str
        Passed to ``get_data`` together with ``bottleneck``.
    bottleneck : int
        Bottleneck size used for both the data lookup and the analytical CDF.
    axis : matplotlib axis, optional
        Axis to draw on. A new figure is created if omitted.
    killing : bool
        Currently has no effect (see note below).
    label : bool
        If true, label the curves 'simulated' and 'analytical' for a legend.

    Returns
    -------
    The axis that was drawn on.
    """
    if axis is None:
        _, axis = plt.subplots()

    params = pp.get_params(param_file)
    R0 = pp.get_param("R0_wh", model_name, params)
    d_v = pp.get_param("d_v", model_name, params)
    mu = pp.get_param("mu", model_name, params)
    k = pp.get_param("k", model_name, params)
    sim_t_M = pp.get_param("t_M", model_name, params)

    data = get_data(results_dir, bottleneck)

    # NOTE(review): the original code derived an effective bottleneck from
    # `killing` (1 when killing, else `bottleneck`) but never used it. It
    # passed `bottleneck` to emergence_time_cdf, which is preserved here.
    # Confirm which bottleneck is intended before wiring `killing` back in.
    times = np.linspace(0, 2, 1000)
    t_max = times.max()
    c_w, c_m = 1, 0

    probs = [
        emergence_time_cdf(t, mu, sim_t_M, R0, d_v, k, bottleneck, c_w, c_m)
        for t in times
    ]

    sim_label, ana_label = (('simulated', 'analytical') if label
                            else (None, None))

    # only simulations that emerged inside the plotted time window
    emerged = data.loc[data['emergence_time'] < t_max, 'emergence_time']
    sns.distplot(emerged,
                 hist_kws=dict(cumulative=True),
                 kde=False,
                 hist=True,
                 norm_hist=True,
                 bins=np.arange(0, 2, 0.05),
                 ax=axis,
                 label=sim_label)
    axis.plot(times,
              probs,
              color="k",
              lw=ps.standard_lineweight,
              label=ana_label)
    return axis
def plot_wh_timecourse(timecourse,
                       bottleneck,
                       results_dir,
                       wt_col,
                       mut_col,
                       cell_col = None,
                       col_colors = None,
                       detection_threshold = None,
                       detection_limit = None,
                       axis = None,
                       detection_color = None,
                       detection_linestyle = "dotted",
                       E_w = None,
                       t_M = None,
                       transmission_threshold = None,
                       non_trans_alpha = 1,
                       gen_param_file = None,
                       frequencies = True,
                       analytical_frequencies = True,
                       infer_t_M = False):
    """
    Plot one simulated within-host timecourse on ``axis``.

    The timecourse is loaded with ``get_timecourse`` from ``results_dir``.
    The file pattern is ``"<model_name><bottleneck>_<timecourse>"``, where
    ``model_name`` is the basename of ``results_dir``.

    If ``frequencies`` is true, variant frequencies are drawn with
    ``plot_df_freqs``. A horizontal detection-threshold line is added if
    requested. If ``analytical_frequencies`` is also true, the analytical
    mutant frequency ``f_m`` is overlaid as dotted black lines.

    Otherwise absolute abundances are drawn with ``plot_df_timecourse``.
    If ``t_M`` is given, the region from ``t_M`` onward is shaded gray.

    Parameters
    ----------
    timecourse : str
        Suffix identifying which saved timecourse to load.
    bottleneck : int
        Bottleneck size used in the file pattern.
    results_dir : str
        Directory of simulation output; its basename is the model name.
    wt_col, mut_col : str
        Column names of the wild-type and mutant virus counts.
    cell_col : str, optional
        Column of cell counts, forwarded to ``plot_df_timecourse``.
    col_colors : list, optional
        Colors forwarded to the plotting helpers.
    detection_threshold : float, optional
        y-value of a horizontal line on the frequency plot.
    detection_limit : optional
        Currently unused. The frequency plot always passes
        ``detection_limit=1``, and the absolute plot does not receive it.
    axis : matplotlib axis, optional
        Axis to draw on. A new figure is created if omitted.
    detection_color, detection_linestyle
        Style of the detection-threshold line.
    E_w : number
        Multiplied by parameter ``k`` to give ``delta`` for ``f_m``.
    t_M : float, optional
        Passed to ``f_m``; also the start of the gray shaded region.
    transmission_threshold : float
        Total virus (wild-type + mutant) above which a time point counts as
        transmissible.
    non_trans_alpha : float
        Alpha for non-transmissible parts of the curves.
    gen_param_file : str
        Parameter file read with ``pp.get_params``.
    frequencies : bool
        Plot frequencies (True) or absolute abundances (False).
    analytical_frequencies : bool
        Overlay analytical mutant frequencies on the frequency plot.
    infer_t_M : bool
        Currently unused.

    Returns
    -------
    The axis that was drawn on.
    """
    if axis is None:
        _, axis = plt.subplots()

    model_name = os.path.basename(results_dir)
    pattern = "{}{}_{}".format(model_name, bottleneck, timecourse)

    print("Getting timecourse", pattern)
    data = get_timecourse(results_dir, pattern)

    param_dict = pp.get_params(gen_param_file)
    param_names = ["mu", "d_v", "R0_wh", "r_w", "r_m", "C_max", "k",
                   "cross_imm", "z_mut", "t_M", "t_N",
                   "emergence_threshold"]
    # each parameter falls back to None if the model does not define it
    params = {name: pp.get_param(name, model_name, param_dict, default=None)
              for name in param_names}

    print("plotting timecourse", pattern)

    if frequencies:
        plot_df_freqs(
            data,
            wt_col,
            mut_col,
            transmission_threshold,
            detection_limit=1,
            ax=axis,
            col_colors=col_colors,
            non_trans_alpha=non_trans_alpha,
            label=False)

        if detection_threshold is not None:
            axis.axhline(detection_threshold,
                         color=detection_color,
                         linestyle=detection_linestyle)

        if analytical_frequencies:
            # Use numpy masks so rows are matched by position, not by the
            # DataFrame's index labels.
            transmissible = ((data[wt_col] + data[mut_col])
                             > transmission_threshold).to_numpy()
            any_mutant = (data[mut_col]
                          > params["emergence_threshold"]).to_numpy()

            # Without an emergence event there is no emergence time or
            # initial mutant frequency to feed f_m. The original code raised
            # NameError on the unbound f_0 in that case, so skip the curves.
            if np.any(any_mutant):
                row_emerge = data.iloc[int(np.argmax(any_mutant))]
                f_0 = (row_emerge[mut_col] /
                       (row_emerge[wt_col] + row_emerge[mut_col]))
                t_emerge = row_emerge["time"]

                # position of the wild-type peak (NaNs skipped, as in idxmax)
                peak_pos = data[wt_col].reset_index(drop=True).idxmax()
                t_peak = data["time"].iloc[peak_pos]

                f_m_kwargs = dict(t_M=t_M,
                                  delta=params["k"] * E_w,
                                  t_emerge=t_emerge,
                                  R0=params["R0_wh"],
                                  d_v=params["d_v"],
                                  t_peak=t_peak,
                                  f_0=f_0,
                                  mu=params["mu"],
                                  ongoing_mutate=True)
                times = data["time"].to_numpy()

                # None entries leave gaps, so each curve covers only its segment
                trans_f_m = [
                    f_m(t_final=t, **f_m_kwargs) if trans else None
                    for t, trans in zip(times, transmissible)]
                non_trans_f_m = [
                    f_m(t_final=t, **f_m_kwargs)
                    if (not trans and mutant) else None
                    for t, trans, mutant in zip(times, transmissible,
                                                any_mutant)]

                f_m_color = "black"
                analytic_lw = mpl.rcParams['lines.linewidth'] * 0.7
                axis.plot(data.time,
                          non_trans_f_m,
                          color=f_m_color,
                          linestyle="dotted",
                          alpha=non_trans_alpha * 0.5,
                          lw=analytic_lw)
                axis.plot(data.time,
                          trans_f_m,
                          color=f_m_color,
                          linestyle="dotted",
                          lw=analytic_lw)

    else:
        plot_df_timecourse(
            data,
            wt_col,
            mut_col,
            transmission_threshold,
            cell_col=cell_col,
            ax=axis,
            col_colors=col_colors,
            non_trans_alpha=non_trans_alpha,
            label=False)

    if t_M is not None:
        # shade from t_M to the end of the data (at least to t = 50)
        axis.axvspan(
            t_M,
            max(np.max(data.time), 50),
            alpha=0.25,
            color="gray")

    # positional argument: the `b=` keyword was removed in matplotlib 3.7
    axis.grid(True, which="major")
    return axis
# Example no. 4
# 0
def main(inoculum_mults=(1, 10, 25),
         bottleneck_sizes=(1, 5, 10),
         outpath='../../ms/main/figures/figure-escape-tradeoff.pdf',
         param_path=None,
         fineness=200,
         implication_z_m=None,
         inoculum_model="poisson"):
    """
    Draw the escape-tradeoff grid and save it to ``outpath``.

    Rows correspond to final bottleneck sizes and columns to inoculum
    multipliers (inoculum size = bottleneck * multiplier). Each panel is
    drawn by ``plot_escape_tradeoff`` for several sIgA cross-immunity levels.

    Parameters
    ----------
    inoculum_mults : sequence of int
        Inoculum-size multipliers, one per column.
    bottleneck_sizes : sequence of int
        Final bottleneck sizes, one per row.
    outpath : str
        Where the figure is saved.
    param_path : str, optional
        Parameter file; defaults to ``"../../dat/RunParameters.mk"``.
    fineness : int
        Forwarded to ``plot_escape_tradeoff``.
    implication_z_m : optional
        Currently unused.
    inoculum_model : str
        Forwarded to ``plot_escape_tradeoff``.
    """
    nrows = len(bottleneck_sizes)
    ncols = len(inoculum_mults)
    width = 22
    height = (nrows / ncols) * width

    if param_path is None:
        param_path = "../../dat/RunParameters.mk"
    params = pp.get_params(param_path)

    f_mut = pp.get_param("f_mut", "", params)
    print("fmt: ", f_mut)

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single
    # row or column.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(width, height),
        squeeze=False)

    # The bounding subplot only carries the shared axis labels, so blank
    # out its tick labels.
    full_plot = add_bounding_subplot(fig)
    full_plot.set_xticklabels(["" for _ in full_plot.get_xticks()])
    full_plot.set_yticklabels(["" for _ in full_plot.get_yticks()])
    full_plot.tick_params(pad=70)

    prob_ticks = [0, 0.25, 0.5, 0.75, 1]

    for i, bn in enumerate(bottleneck_sizes):
        for j, mult in enumerate(inoculum_mults):
            v = bn * mult
            print("plotting final bottleneck b: {}, v: {}..."
                  "".format(bn, v))
            axis = axes[i, j]

            # Raw strings: "\%" is an invalid escape sequence in a normal
            # literal (SyntaxWarning on Python 3.12+). The text is unchanged.
            plot_escape_tradeoff(
                f_mut=f_mut,
                axis=axis,
                bottleneck=bn,
                inoculum_size=v,
                cross_immunities=[0, 0.5, 0.9, 0.99],
                legend=False,
                label_ax=False,
                cross_imm_labels=[r"0\%", r"50\%", r"90\%", r"99\%"],
                fineness=fineness,
                inoculum_model=inoculum_model)
            axis.label_outer()
            axis.set_xticks(prob_ticks)
            axis.set_yscale('log')
            axis.set_yticks([1e-6, 1e-5, 1e-4, 1e-3, 1e-2])

    leg = axes[0, 0].legend(ncol=2, handlelength=1)
    # keep "\n" as a real newline; only the mathtext part is a raw string
    leg.set_title("sIgA cross immunity\n"
                  r"($\sigma = \kappa_m / \kappa_w$)")
    full_plot.set_ylabel("probability mutant present\n"
                         "after final bottleneck")
    full_plot.set_xlabel("wild-type neutralization "
                         r"probability $\kappa_{w}$")
    fig.tight_layout()
    fig.savefig(outpath)
def main(drift_data_paths=None,
         constant_data_paths=None,
         pt_data_paths=None,
         both_data_paths=None,
         param_file=None,
         output_path=None):
    """
    Build the antigenic kernel-shift figure and save it to ``output_path``.

    The figure is a 2 x 4 grid. The top row shows simulated change
    distributions (``plot_change_dists``). The bottom row shows analytical
    kernel shifts (``plot_kernel_shift_repl``). There is one column per
    condition:
      - drift: no immunity
      - constant: constant recall response
      - both: constant recall with mucosal antibodies
      - pt: recall at t_M = 2 with mucosal antibodies

    Returns 0.
    """
    # -- global styling, everything scaled from the figure width --
    width = 18.3
    kernel_lw = width / 4.5
    rc_settings = (('font.size', width / 1.5),
                   ('lines.linewidth', width / 2.5),
                   ('ytick.major.pad', width / 2),
                   ('xtick.major.pad', width / 2),
                   ('legend.fontsize', width * 0.9),
                   ('legend.title_fontsize', width),
                   ('legend.handlelength', 2))
    for rc_key, rc_value in rc_settings:
        mpl.rcParams[rc_key] = rc_value

    fig = plt.figure(figsize=(width, width / 2))
    gs = gridspec.GridSpec(2, 4)

    # -- panel layout --
    # Row 0 holds simulations and row 1 analytical results; columns follow
    # the condition order. The drift panel of each row anchors axis sharing.
    conditions = ['drift', 'constant', 'both', 'pt']
    plot_positions = []
    for row, source in enumerate(['sim', 'ana']):
        anchor = 'change-dist-drift-' + source
        for col, condition in enumerate(conditions):
            if col == 0:
                # The simulated anchor shares nothing. The analytical anchor
                # shares only its x-axis with the simulated anchor.
                share_x = None if source == 'sim' else 'change-dist-drift-sim'
                share_y = None
            else:
                share_x = share_y = anchor
            plot_positions.append({
                "name": 'change-dist-{}-{}'.format(condition, source),
                "grid_position": np.s_[row, col],
                "sharex": share_x,
                "sharey": share_y,
            })

    plots = setup_multipanel(fig,
                             plot_positions,
                             letter_loc=(-0.1, 1.15),
                             gridspec=gs)

    # -- parameters --
    params = pp.get_params(param_file)

    def read_float(key):
        return float(params.get(key))

    b = int(params.get("DEFAULT_CHAIN_BOTTLENECK"))
    v = int(read_float("DEFAULT_VB_RATIO") * b)
    f_mut = read_float("DEFAULT_F_MUT")
    z_wt = read_float("DEFAULT_Z_WT")
    k = read_float("CONSTANT_CHAIN_K")
    mu = read_float("DEFAULT_MU")
    d_v = read_float("DEFAULT_D_V")
    R0 = read_float("DEFAULT_R0_WH")
    mut_sd = read_float("DEFAULT_CHAIN_MUT_SD")
    mut_mu = read_float("DEFAULT_CHAIN_MUT_MEAN")  # parsed but not used below

    print("parsed parameter file: v = {}, b = {}".format(v, b))
    print("parsed parameter file: f_mut = {}".format(f_mut))
    print("parsed parameter file: z_wt = {}".format(z_wt))

    # -- simulated kernel shifts (top row) --
    print("plotting simulated kernel shifts...")
    plot_change_dists(plots,
                      drift_data_paths,
                      constant_data_paths,
                      both_data_paths,
                      pt_data_paths,
                      param_file,
                      kernel_lw=kernel_lw,
                      kernel_color=ps.kernel_color,
                      kernel_alpha=ps.kernel_alpha,
                      kernel_linestyle=ps.kernel_linestyle,
                      dash_capstyle='round')

    # -- analytical kernel shifts (bottom row) --
    print("plotting analytical kernel shifts...")
    common_kernel_args = dict(mu=mu,
                              d_v=d_v,
                              R0=R0,
                              bottleneck=b,
                              t_transmit=2,
                              vb_ratio=v / b,
                              sd_mut=mut_sd,
                              phenotypes=np.linspace(-0.35, 0.35, 101),
                              z_homotypic=0.95,
                              susceptibility_model='linear',
                              escape=1,
                              dist_color=ps.inocs_color,
                              kernel_color=ps.kernel_color,
                              kernel_alpha=ps.kernel_alpha,
                              kernel_lw=kernel_lw,
                              kernel_linestyle=ps.kernel_linestyle,
                              dist_alpha=ps.dist_alpha,
                              dash_capstyle='round',
                              mark_original_phenotype=True)

    immune_phen = -0.8  # least immune history, taken from the sim model
    no_history = -99
    # (progress message, condition, k, t_M, generator phen., recipient phen.)
    analytical_panels = [
        ("no immunity...", 'drift', 0, 0, no_history, no_history),
        ("constant recall response...", 'constant', k, 0,
         immune_phen, no_history),
        ("constant recall response with mucosal antibodies...", 'both', k, 0,
         immune_phen, immune_phen),
        ("realistic recall response with mucosal antibodies...", 'pt', k, 2,
         immune_phen, immune_phen),
    ]
    for message, condition, panel_k, panel_t_M, gen_phen, recip_phen in \
            analytical_panels:
        print(message)
        plot_kernel_shift_repl(axis=plots['change-dist-{}-ana'.format(condition)],
                               k=panel_k,
                               t_M=panel_t_M,
                               generator_phenotype=gen_phen,
                               recipient_phenotype=recip_phen,
                               **common_kernel_args)

    # -- panel styling --
    for panel_name, panel in plots.items():
        if 'ana' in panel_name:
            panel.set_xlabel('antigenic change')
        if 'drift-sim' in panel_name:
            panel.set_ylabel('frequency')
        if 'drift-ana' in panel_name:
            panel.set_ylabel('probability density')
        panel.set_ylim(bottom=0)
        panel.label_outer()

    fig.tight_layout()
    fig.savefig(output_path)
    return 0
def main(output_path=None,
         results_dir=None,
         gen_param_file=None,
         bottleneck=200):
    """
    Build the within-host 'minimal_sterilizing' figure and save it.

    Three example timecourses are plotted, each both as absolute
    abundances and as frequencies:
      - sterilized
      - replication-selected
      - inoculation-selected

    A proportion panel compares simulated outcome frequencies across
    bottleneck sizes with analytical predictions.

    Parameters
    ----------
    output_path : str, optional
        Where the figure is saved.
    results_dir : str, optional
        Directory of within-host simulation results.
    gen_param_file : str, optional
        Parameter file read with ``pp.get_params``.
    bottleneck : int
        Bottleneck of the example timecourses.
    """
    if results_dir is None:
        results_dir = "../../out/within_host_results/minimal_sterilizing"
    if output_path is None:
        output_path = "../../ms/main/figures/figure-wh-sterilizing-timecourse.pdf"
    if gen_param_file is None:
        gen_param_file = "../../dat/RunParameters.mk"

    fig, plots = setup_figure()

    ## read in parameters
    params = pp.get_params(gen_param_file)
    detection_threshold = float(params["DEFAULT_DETECTION_THRESHOLD"])
    transmission_threshold = float(params["DEFAULT_TRANS_THRESHOLD"])
    detection_limit = float(params["DEFAULT_DETECTION_LIMIT"])
    f_mut_default = float(params["DEFAULT_F_MUT"])

    model_name = 'minimal_sterilizing'

    # NOTE(review): C_max, r_w, r_m, t_N and f_mut_default are read but not
    # used below.
    R0_wh = pp.get_param("R0_wh", model_name, params)
    C_max = pp.get_param("C_max", model_name, params)
    r_w = pp.get_param("r_w", model_name, params)
    r_m = pp.get_param("r_m", model_name, params)
    mu = pp.get_param("mu", model_name, params)
    d_v = pp.get_param("d_v", model_name, params)
    k = pp.get_param("k", model_name, params)
    cross_imm = pp.get_param("cross_imm", model_name, params)
    t_M = pp.get_param("t_M", model_name, params)
    t_N = pp.get_param("t_N", model_name, params)
    p_loss = pot.minimal_p_loss(3e-5, R0_wh, 1)
    print("p loss:", p_loss)

    # set styling
    wt_col = "Vw"
    mut_col = "Vm"
    cell_col = "C"
    col_colors = [ps.wt_color, ps.mut_color, ps.cell_color]
    detection_linestyle = "dashed"
    detection_color = "black"
    non_trans_alpha = 0.4

    sims_to_plot = {
        'sterilized': 'repl_invisible',
        'replication-selected': 'repl_visible',
        'inoculation-selected': 'inoc_visible'
    }

    for sim_type, timecourse in sims_to_plot.items():
        for plot_freq, text in zip([False, True], ["abs", "freq"]):
            plot_axis_name = "{}-{}".format(sim_type, text)
            plot_axis = plots[plot_axis_name]

            print("plotting {}...".format(sim_type))
            plot_wh_timecourse(timecourse,
                               bottleneck,
                               results_dir,
                               wt_col,
                               mut_col,
                               col_colors=col_colors,
                               cell_col=cell_col,
                               detection_threshold=detection_threshold,
                               detection_color=detection_color,
                               detection_limit=detection_limit,
                               detection_linestyle=detection_linestyle,
                               transmission_threshold=transmission_threshold,
                               axis=plot_axis,
                               t_M=t_M,
                               E_w=True,
                               gen_param_file=gen_param_file,
                               non_trans_alpha=non_trans_alpha,
                               frequencies=plot_freq,
                               analytical_frequencies=True)
            plot_axis.set_xlim([0, 8])
            plot_axis.set_xticks([0, 2, 4, 6, 8])
            if plot_freq:
                plot_axis.set_ylim([1e-6, 2])
            # only the 'sterilized' column keeps y tick labels
            if "sterilized" not in plot_axis_name:
                for tick_label in plot_axis.get_yticklabels():
                    tick_label.set_visible(False)
            # absolute-abundance panels sit above frequency panels
            if "abs" in plot_axis_name:
                for tick_label in plot_axis.get_xticklabels():
                    tick_label.set_visible(False)

    dat = asc.get_bottleneck_data(results_dir)
    f_muts = asc.get_mean_fmut(dat)
    bottlenecks = np.sort(pd.unique(dat.bottleneck))

    analytic_p_inocs = np.array(
        [calc_p_inocs(float(f_muts.at[bn]), bn, R0_wh) for bn in bottlenecks])

    analytic_p_repls = np.array([
        whp.p_repl_declining(bn, mu, R0_wh, d_v, k, 1, cross_imm)
        for bn in bottlenecks
    ])

    # replication selection can only happen when inoculation selection didn't
    analytic_p_repls = analytic_p_repls * (1 - analytic_p_inocs)

    analytic_p_wts = np.zeros_like(analytic_p_repls)

    analytic_p_elses = 1 - analytic_p_inocs - analytic_p_repls

    ## extract needed simulation info
    ## and get to tidy format for plotting

    print('analytical p repl:')
    print(analytic_p_repls)

    print('raw simulated repl events:')
    print(dat.groupby('bottleneck').apply(calc_raw_repl))

    print('simulated repl probs:')
    print(dat.groupby('bottleneck').apply(calc_repl))

    results = dat.groupby('bottleneck').apply(calc_apply)

    results = pd.melt(results.reset_index(),
                      id_vars='bottleneck',
                      var_name='outcome',
                      value_name='frequency')

    my_palette = ['grey', ps.repl_color, ps.inocs_color, ps.wt_color]

    analyticals = [
        analytic_p_elses, analytic_p_repls, analytic_p_inocs, analytic_p_wts
    ]

    proportion_plot = plots['proportion-plot']
    for outcome, color, analytical in zip(possible_outcomes, my_palette,
                                          analyticals):
        df = results[results.outcome == outcome]
        # simulated frequencies: filled circles
        proportion_plot.plot(df.bottleneck,
                             df.frequency * denom,
                             marker='o',
                             markersize=20,
                             linestyle="",
                             markeredgecolor='k',
                             color=color,
                             alpha=prob_alpha,
                             label=outcome)

        # analytical predictions: crosses
        proportion_plot.plot(bottlenecks,
                             analytical * denom,
                             marker='+',
                             markeredgewidth=5,
                             markersize=20,
                             linestyle="",
                             alpha=1,
                             color=color)

    # style resulting plots
    proportion_plot.set_yscale('symlog')
    proportion_plot.set_title('distribution of outcomes\n')
    proportion_plot.legend(loc='center left', bbox_to_anchor=(0, 0.7))
    proportion_plot.set_xscale('log')
    proportion_plot.set_xlabel('bottleneck')
    # Braces around the exponent so mathtext superscripts every digit
    # (e.g. 10^{10}), not just the first.
    proportion_plot.set_ylabel('frequency per $10^{{{}}}$ inoculations'
                               ''.format(denom_exponent))

    plots['sterilized-abs'].set_ylabel('virions, cells')
    plots['sterilized-freq'].set_ylabel('frequency')

    plots['sterilized-abs'].set_title('no detectable infection\n')

    plots['replication-selected-abs'].set_title(
        'detectable infection\nwith de novo new variant')

    plots['inoculation-selected-abs'].set_title(
        'detectable infection\nwith inoculated new variant')

    plots['all-timecourses'].set_xlabel('time (days)')
    proportion_plot.set_xticks(bottlenecks)
    proportion_plot.set_xticklabels(bottlenecks)

    fig.tight_layout()
    fig.savefig(output_path)
# Example no. 7
# 0
def main(results_dir=None, param_path=None, output_path=None):
    """
    Draw the sensitivity-analysis scatter figure and save it.

    There is one panel per varied parameter. Each panel plots the 'mut-inf'
    outcome against the parameter value: a scatter for continuous
    parameters and a jittered strip plot for 'bottleneck'. One extra grid
    slot holds the legend, and any remaining slots are deleted.

    Parameters
    ----------
    results_dir : str, optional
        Sensitivity-analysis output; its basename is the model name used
        for ``<parm>_MIN`` / ``<parm>_MAX`` lookups.
    param_path : str, optional
        Parameter file read with ``pp.get_params``.
    output_path : str, optional
        Where the figure is saved.
    """
    width = 18.3
    aspect = 1
    height = width * aspect
    figsize = (width, height)
    legend_markersize = width

    if results_dir is None:
        results_dir = "../../out/sensitivity_analysis_results/minimal_visvar"

    if param_path is None:
        param_path = "../../dat/RunParameters.mk"

    if output_path is None:
        output_path = ("../../ms/supp/"
                       "figs/figure-sensitivity-scatter-visvar.pdf")

    parm_dict = pp.get_params(param_path)
    model_name = os.path.basename(results_dir)

    indiv = get_tidy_data(results_dir)
    # every remaining column is treated as a varied parameter
    extra_columns = {'paramset_id', 'niter', 'outcome', 'frequency', 'ratio'}
    parms = [col for col in indiv.columns if col not in extra_columns]
    n_parms = len(parms)

    n_cols = 4
    # Reserve one slot beyond the parameter panels for the legend. Without
    # the +1, a parameter count divisible by n_cols leaves no free axis and
    # the legend lookup below raises IndexError.
    n_rows = int(np.ceil((n_parms + 1) / n_cols))
    n_axes = n_rows * n_cols

    # squeeze=False keeps `ax` 2-D even for a single row, so the ax[:, 0]
    # and ax[:, 1:] indexing below always works.
    fig, ax = plt.subplots(nrows=n_rows,
                           ncols=n_cols,
                           figsize=figsize,
                           sharey=True,
                           squeeze=False)
    flat_ax = ax.flatten()

    for ind, parm in enumerate(parms):
        axis = flat_ax[ind]

        if parm != 'bottleneck':
            xlim = [
                pp.get_param(param_name=parm + "_MIN",
                             model_name=model_name,
                             param_dict=parm_dict),
                pp.get_param(param_name=parm + "_MAX",
                             model_name=model_name,
                             param_dict=parm_dict)
            ]
            parm_scatter(indiv, parm, 'mut-inf', axis=axis, xlims=xlim)

        else:
            xlim = None
            parm_striplot(indiv,
                          parm,
                          'mut-inf',
                          axis=axis,
                          xlims=xlim,
                          jitter=0.2)

        axis.set_xlabel(ps.parm_display_names.get(parm, parm))
        axis.set_yscale('symlog')
        axis.set_ylim([-0.1, 110])

    # legend in the first unused axis
    leg_ax = flat_ax[n_parms]
    labels = [
        'de novo new variants\nmore common',
        'inoculated new variants\nmore common'
    ]

    # `palette` and `observation_alpha` presumably come from module scope
    legend_elements = [
        mpl.lines.Line2D([0], [0],
                         marker='o',
                         color=color,
                         alpha=observation_alpha,
                         lw=0,
                         markersize=legend_markersize,
                         markeredgecolor='k',
                         label=lab) for color, lab in zip(palette, labels)
    ]
    leg_ax.legend(handles=legend_elements)
    # hide everything about the legend axis except the legend itself
    for side in ('top', 'bottom', 'left', 'right'):
        leg_ax.spines[side].set_color('none')
    # positional argument: the `b=` keyword was removed in matplotlib 3.7
    leg_ax.grid(False)
    leg_ax.tick_params(labelcolor='w',
                       grid_alpha=0,
                       top=False,
                       bottom=False,
                       left=False,
                       right=False)

    # delete unused axes
    for i_ax in range(n_parms + 1, n_axes):
        fig.delaxes(flat_ax[i_ax])

    # style axes: y label only on the first column
    for axis in ax[:, 0].flatten():
        axis.set_ylabel('new variant infections\n'
                        'per 100 detectable infections')
        axis.yaxis.set_major_formatter(ScalarFormatter())
    for axis in ax[:, 1:].flatten():
        plt.setp(axis.get_yticklabels(), visible=False)

    fig.tight_layout()
    fig.savefig(output_path)
def main(param_file=None, output_path=None):
    """Draw the three-panel population-immunity figure and save it.

    Panels, left to right, each plotted against the population fraction
    immune to wild-type:

    * ``emergence-rate``: new variant infections per inoculation,
    * ``reinoculations``: per capita reinoculations,
    * ``epi-escape``: per capita probability of new variant infection.

    The emergence-rate and epi-escape curves are drawn twice: once with the
    mucus bottleneck ``v`` derived from the parameter file, and once
    (dotted, purple) with ``v * v_mult``.

    Note: this mutates the global ``mpl.rcParams``.

    Parameters
    ----------
    param_file : str
        Run-parameter file parsed with ``pp.get_params``.
    output_path : str
        Destination passed to ``fig.savefig``.

    Returns
    -------
    int
        Always 0.
    """

    ## figure styling / setup
    width = 18.3
    mpl.rcParams['font.size'] = width / 1.5
    mpl.rcParams['lines.linewidth'] = width / 2.5
    mpl.rcParams['ytick.major.pad'] = width / 2
    mpl.rcParams['xtick.major.pad'] = width / 2
    mpl.rcParams['legend.fontsize'] = width * 0.9
    mpl.rcParams['legend.title_fontsize'] = width
    mpl.rcParams['legend.handlelength'] = 2

    height = width / 3
    fig = plt.figure(figsize=(width, height))

    nrows = 1
    ncols = 3
    gs = gridspec.GridSpec(nrows, ncols)

    ## multipanel setup: a single row of three panels
    pop_row = 0
    e_rate_col, reinoc_col, epi_escape_col = 0, 1, 2

    plot_positions = [{
        "name": "emergence-rate",
        "grid_position": np.s_[pop_row, e_rate_col],
        "sharex": None,
        "sharey": None
    }, {
        "name": "reinoculations",
        "grid_position": np.s_[pop_row, reinoc_col],
        "sharex": None,
        "sharey": None
    }, {
        "name": "epi-escape",
        "grid_position": np.s_[pop_row, epi_escape_col],
        "sharex": None,
        "sharey": None
    }]

    letter_loc = (-0.1, 1.15)
    plots = setup_multipanel(fig,
                             plot_positions,
                             letter_loc=letter_loc,
                             gridspec=gs)

    ## parametrization
    params = pp.get_params(param_file)
    b = int(params.get("DEFAULT_BOTTLENECK"))
    # v is the v/b ratio from the parameter file scaled by the bottleneck b
    v = int(float(params.get("DEFAULT_VB_RATIO")) * b)
    f_mut = float(params.get("DEFAULT_F_MUT"))
    z_wt = float(params.get("DEFAULT_Z_WT"))

    # multiplier on v for the dotted comparison curves
    v_mult = 5
    epi_sigma = 0.75
    escape = 1

    print("parsed parameter file: v = {}, b = {}".format(v, b))
    print("parsed parameter file: f_mut = {}".format(f_mut))
    print("parsed parameter file: z_wt = {}".format(z_wt))

    pop_R0s = [1.5, 2, 2.5]
    print("Plotting reinoculations vs population immunity...")
    plot_reinoculations(axis=plots["reinoculations"],
                        R0s=pop_R0s,
                        cmap=plt.cm.Greens,
                        leg_cmap=plt.cm.Greys)

    print("Plotting new chains vs population immunity...")
    plot_epidemic_p_inocs(axis=plots["epi-escape"],
                          f_mut=f_mut,
                          bottleneck=b,
                          cross_imm_sigma=epi_sigma,
                          v=v,
                          population_R0s=pop_R0s,
                          escape=escape,
                          leg_cmap=plt.cm.Greys,
                          z_wt=z_wt,
                          legend=True)

    print("Plotting new chains vs population immunity with v = {}..."
          "".format(v * v_mult))

    plot_epidemic_p_inocs(axis=plots["epi-escape"],
                          cross_imm_sigma=epi_sigma,
                          f_mut=f_mut,
                          bottleneck=b,
                          v=v * v_mult,
                          population_R0s=pop_R0s,
                          cmap=plt.cm.Purples,
                          linestyle='dotted',
                          escape=escape,
                          z_wt=z_wt,
                          legend=False)

    print("Plotting emergence vs population immunity...")
    mut_neut_prob_list = [0.75, 0.9, 0.99]

    plot_inoculation_selection_vs_immunity(axis=plots["emergence-rate"],
                                           f_mut=f_mut,
                                           mut_neut_probs=mut_neut_prob_list,
                                           wt_neut_prob=neutralize_prob_from_z(
                                               z_wt, v, "poisson"),
                                           final_bottleneck=b,
                                           mucus_bottleneck=v,
                                           cmap=plt.cm.Reds,
                                           linestyle="solid",
                                           legend=True,
                                           leg_cmap=plt.cm.Greys,
                                           inoculum_model="poisson",
                                           bottleneck_model="binomial")

    plot_inoculation_selection_vs_immunity(axis=plots["emergence-rate"],
                                           f_mut=f_mut,
                                           mut_neut_probs=mut_neut_prob_list,
                                           wt_neut_prob=neutralize_prob_from_z(
                                               z_wt, v * v_mult, "poisson"),
                                           final_bottleneck=b,
                                           mucus_bottleneck=v * v_mult,
                                           cmap=plt.cm.Purples,
                                           linestyle="dotted",
                                           legend=False,
                                           inoculum_model="poisson",
                                           bottleneck_model="binomial")

    #####################################
    # plot styling
    #####################################
    imm_level_text = "population fraction immune\nto wild-type"
    emergence_y_text = "new variant infections\nper inoculation"
    frac_ticks = [0, 0.25, 0.5, 0.75, 1]
    frac_tick_labs = ["$0$", "$0.25$", "$0.5$", "$0.75$", "$1$"]

    plots["reinoculations"].set_xlabel(imm_level_text)
    plots["reinoculations"].set_ylabel("per capita reinoculations")
    plots["epi-escape"].set_xlabel(imm_level_text)
    plots["epi-escape"].set_ylabel("per capita probability\n"
                                   "of new variant infection")
    plots["emergence-rate"].set_xlabel(imm_level_text)
    plots["emergence-rate"].set_ylabel(emergence_y_text)
    plots["epi-escape"].set_ylim([0, 3e-4])

    for plot in plots.values():
        plot.set_xlim(left=0, right=1)
        plot.set_xticks(frac_ticks)
        plot.set_xticklabels(frac_tick_labs)
        # A zero lower limit is invalid on a log-scaled axis (matplotlib
        # ignores it with a warning), so only pin it on non-log axes.
        if plot.get_yscale() != "log":
            plot.set_ylim(bottom=0)

    fig.tight_layout()

    # save
    fig.savefig(output_path)

    return 0
Esempio n. 9
0
def plot_presentation_figures(gen_param_file="../../dat/RunParameters.mk",
                              output_dir="../../cons/fourth-year-talk/"):
    """Save one two-panel within-host timecourse PDF per scenario.

    For each scenario ("naive", "fixed", "var", "inoc"), a figure is drawn
    with ``plot_wh_timecourse``. The left panel shows absolute values
    (virions, cells) and the right panel shows variant frequency. The
    figure is written to ``<output_dir>/timecourse_<scenario>.pdf`` and
    closed, so repeated calls do not accumulate open figures.

    Note: this sets ``mpl.rcParams['lines.linewidth']`` globally.

    Parameters
    ----------
    gen_param_file : str, optional
        Run-parameter file read with ``pp.get_params``. It is also forwarded
        to ``plot_wh_timecourse``. The default is the previously
        hard-coded path.
    output_dir : str, optional
        Directory the PDFs are written to. The default is the previously
        hard-coded directory.
    """
    non_trans_alpha = 0.4
    lineweight = 10
    mpl.rcParams['lines.linewidth'] = lineweight
    params = pp.get_params(gen_param_file)
    detection_threshold = float(params["DEFAULT_DETECTION_THRESHOLD"])
    transmission_threshold = float(params["DEFAULT_TRANS_THRESHOLD"])
    detection_limit = float(params["DEFAULT_DETECTION_LIMIT"])

    # activation times used by the scenario table below
    t_M = pp.get_param("t_M", "minimal_visvar", params)
    t_N = pp.get_param("t_N", "minimal_visvar", params)
    bottleneck = 1

    fixed_results_dir = "../../out/within_host_results/minimal_fixed"
    var_results_dir = "../../out/within_host_results/minimal_visvar"
    wt_col = "Vw"
    mut_col = "Vm"
    cell_col = "C"
    col_colors = [ps.wt_color, ps.mut_color, ps.cell_color]
    detection_linestyle = "dashed"
    detection_color = "black"

    sims_to_plot = {
        "naive": {
            "results-dir": var_results_dir,
            "timecourse": "naive",
            "activation-time": t_N,
            "E_w": False
        },
        "fixed": {
            "results-dir": fixed_results_dir,
            "timecourse": "repl_visible",
            "activation-time": 0,
            "E_w": True
        },
        "var": {
            "results-dir": var_results_dir,
            "timecourse": "repl_visible",
            "activation-time": t_M,
            "E_w": True
        },
        "inoc": {
            "results-dir": var_results_dir,
            "timecourse": "inoc_visible",
            "activation-time": t_M,
            "E_w": True
        }
    }

    for sim_type, metadata in sims_to_plot.items():
        print("plotting {}...".format(sim_type))
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        # left panel: absolute values; right panel: frequencies
        for plot_axis, plot_freq in zip(ax, [False, True]):
            plot_wh_timecourse(metadata["timecourse"],
                               bottleneck,
                               metadata["results-dir"],
                               wt_col,
                               mut_col,
                               col_colors=col_colors,
                               cell_col=cell_col,
                               detection_threshold=detection_threshold,
                               detection_color=detection_color,
                               detection_limit=detection_limit,
                               detection_linestyle=detection_linestyle,
                               transmission_threshold=transmission_threshold,
                               axis=plot_axis,
                               t_M=metadata["activation-time"],
                               E_w=metadata["E_w"],
                               gen_param_file=gen_param_file,
                               non_trans_alpha=non_trans_alpha,
                               frequencies=plot_freq,
                               analytical_frequencies=True)
            plot_axis.set_xlim([0, 8])
            plot_axis.set_xticks([0, 2, 4, 6, 8])
            if plot_freq:
                plot_axis.set_ylim([1e-6, 2])
        ax[0].set_ylabel("virions, cells")
        ax[1].set_ylabel("variant frequency")
        ax[0].set_xlabel("time (days)")
        ax[1].set_xlabel("time (days)")
        fig.tight_layout()
        plotname = 'timecourse_{}.pdf'.format(sim_type)
        fig.savefig(os.path.join(output_dir, plotname))
        # release the figure; otherwise every call leaks four open figures
        plt.close(fig)
Esempio n. 10
0
def main(output_path=None,
         fixed_results_dir=None,
         var_results_dir=None,
         gen_param_file=None,
         empirical_data_file=None,
         bottleneck=1,
         heatmap_bottleneck=1,
         increment=0.05):
    """Draw the within-host dynamics summary figure and save it.

    Top row, left to right:

    * the observed HA polymorphism and the variant within-host frequency
      histogram, both from the empirical NGS data;
    * two heatmaps of the probability that a new variant reaches 1% and
      consensus (``f_target`` 0.01 and 0.5), plotted over the time of
      recall response and selection strength.

    Bottom row: simulated within-host timecourses for four scenarios (no
    recall response, constant recall response, delayed recall response,
    and delayed response with the variant inoculated). Each timecourse has
    an absolute-value panel above a variant-frequency panel.

    Note: this mutates the global ``mpl.rcParams``.

    Parameters
    ----------
    output_path : str, optional
        Destination passed to ``fig.savefig``.
    fixed_results_dir : str, optional
        Simulation output for the "fixed" scenario.
    var_results_dir : str, optional
        Simulation output for the "naive", "var" and "inoc" scenarios.
    gen_param_file : str, optional
        Run-parameter file read with ``pp.get_params``.
    empirical_data_file : str, optional
        CSV of cleaned within-host data, read with ``pd.read_csv``.
    bottleneck : int, optional
        Bottleneck passed to ``plot_wh_timecourse``.
    heatmap_bottleneck : int, optional
        Bottleneck passed to ``plot_heatmap``.
    increment : float, optional
        Passed to ``plot_heatmap`` as ``increment``.
    """

    if fixed_results_dir is None:
        fixed_results_dir = "../../out/within_host_results/minimal_visible"
    if var_results_dir is None:
        var_results_dir = "../../out/within_host_results/minimal_visvar"
    if output_path is None:
        output_path = "../../ms/main/figures/figure-wh-dynamics-summary.pdf"
    if gen_param_file is None:
        gen_param_file = "../../dat/RunParameters.mk"

    if empirical_data_file is None:
        empirical_data_file = "../../dat/cleaned/cleaned_wh_data.csv"

    aspect = .9
    width = 18.3
    height = width * aspect

    lineweight = width / 2.5
    mpl.rcParams['lines.linewidth'] = lineweight
    mpl.rcParams['font.size'] = width
    mpl.rcParams['legend.fontsize'] = "small"

    non_trans_alpha = 0.4

    params = pp.get_params(gen_param_file)
    detection_threshold = float(params["DEFAULT_DETECTION_THRESHOLD"])
    detection_limit = float(params["DEFAULT_DETECTION_LIMIT"])
    td_50 = float(params["DEFAULT_TD50"])
    transmission_cutoff = float(params["DEFAULT_TRANS_CUTOFF"])

    # Solves 1 - 2 ** (-V / td_50) = transmission_cutoff for V. This
    # assumes that dose-response form for transmission; confirm against
    # the model.
    transmission_threshold = (-td_50 * np.log(1 - transmission_cutoff) /
                              np.log(2))

    print(transmission_threshold)

    R0_wh = pp.get_param("R0_wh", "minimal_visvar", params)
    mu = pp.get_param("mu", "minimal_visvar", params)
    d_v = pp.get_param("d_v", "minimal_visvar", params)
    t_M = pp.get_param("t_M", "minimal_visvar", params)
    t_N = pp.get_param("t_N", "minimal_visvar", params)

    # heatmap ranges passed to plot_heatmap
    max_k = 8
    max_t_M = 3.5
    heatmap_t_final = 3

    # set up multi-panel figure
    fig = plt.figure(figsize=(width, height))

    n_cols = 4
    n_rows = 2
    height_ratios = [2, 3.5]
    # letter_loc values handed to setup_multipanel, one per row
    ly2_inc = 0.07
    ly1 = 1 + 3 * ly2_inc / 2
    ly2 = 1 + ly2_inc
    lx = -0.05
    row_1_loc = (lx, ly1)
    row_2_loc = (lx, ly2)

    gs = gridspec.GridSpec(n_rows, n_cols, height_ratios=height_ratios)
    # Bounding axes that carry the shared heatmap labels and, for the
    # timecourses, the shared x label and legend.
    all_heatmaps = add_bounding_subplot(fig, position=gs[0, 2:])
    # The result is not used. The call is kept because add_bounding_subplot
    # presumably adds an axes to fig, which would change the layout;
    # confirm before removing it.
    add_bounding_subplot(fig, position=gs[:, :])
    all_timecourses = add_bounding_subplot(fig, position=gs[1, :])

    plot_positions = [{
        "name": "empirical-detect",
        "grid_position": np.s_[0, 0],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_1_loc
    }, {
        "name": "empirical-hist",
        "grid_position": np.s_[0, 1],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_1_loc
    }, {
        "name": "one-percent-heatmap",
        "grid_position": np.s_[0, 2],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_1_loc
    }, {
        "name": "consensus-heatmap",
        "grid_position": np.s_[0, 3],
        "sharex": "one-percent-heatmap",
        "sharey": "one-percent-heatmap",
        "letter_loc": row_1_loc
    }, {
        "name": "naive",
        "grid_position": np.s_[1, 0],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_2_loc
    }, {
        "name": "fixed",
        "grid_position": np.s_[1, 1],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_2_loc
    }, {
        "name": "var",
        "grid_position": np.s_[1, 2],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_2_loc
    }, {
        "name": "inoc",
        "grid_position": np.s_[1, 3],
        "sharex": None,
        "sharey": None,
        "letter_loc": row_2_loc
    }]

    plots = setup_multipanel(fig, plot_positions, gridspec=gs)
    fig.tight_layout(rect=(0, 0.1, 1, 1), w_pad=-0.5)

    # Split each timecourse cell into an "-abs" panel (top) and a "-freq"
    # panel (bottom). Because "naive" is processed first, every later
    # frequency panel shares its axes with "naive-freq". The enclosing cell
    # axes is then made invisible.
    for timecourse in ['naive', 'fixed', 'var', 'inoc']:
        inner = gridspec.GridSpecFromSubplotSpec(
            2, 1, hspace=0.15, subplot_spec=plots[timecourse])
        plots[timecourse + '-abs'] = plt.Subplot(fig, inner[0])
        plots[timecourse + '-freq'] = plt.Subplot(
            fig,
            inner[1],
            sharex=plots.get('naive-freq', None),
            sharey=plots.get('naive-freq', None))
        fig.add_subplot(plots[timecourse + '-abs'])
        fig.add_subplot(plots[timecourse + '-freq'])
        ax = plots[timecourse]
        for spine in ['top', 'bottom', 'left', 'right']:
            ax.spines[spine].set_color('none')
        # positional argument: works with both the old `b` and the newer
        # `visible` parameter name
        ax.grid(False)
        ax.patch.set_alpha(0)
        ax.tick_params(labelcolor='w',
                       grid_alpha=0,
                       top=False,
                       bottom=False,
                       left=False,
                       right=False)
        ax.set_zorder(0)

    wt_col = "Vw"
    mut_col = "Vm"
    cell_col = "C"
    col_colors = [ps.wt_color, ps.mut_color, ps.cell_color]
    detection_linestyle = "dashed"
    detection_color = "black"

    empirical_data = pd.read_csv(empirical_data_file)
    plot_wh_ngs(empirical_data,
                axis=plots['empirical-detect'],
                min_freq=0.01,
                legend=True,
                edgecolor="k")

    plot_ngs_hist(empirical_data,
                  axis=plots['empirical-hist'],
                  min_freq=0,
                  legend=True)

    sims_to_plot = {
        "naive": {
            "results-dir": var_results_dir,
            "timecourse": "naive",
            "activation-time": t_N,
            "E_w": False
        },
        "fixed": {
            "results-dir": fixed_results_dir,
            "timecourse": "repl_visible",
            "activation-time": 0,
            "E_w": True
        },
        "var": {
            "results-dir": var_results_dir,
            "timecourse": "repl_visible",
            "activation-time": t_M,
            "E_w": True
        },
        "inoc": {
            "results-dir": var_results_dir,
            "timecourse": "inoc_visible",
            "activation-time": t_M,
            "E_w": True
        }
    }

    for sim_type, metadata in sims_to_plot.items():
        for plot_freq, text in zip([False, True], ["abs", "freq"]):
            plot_axis_name = "{}-{}".format(sim_type, text)
            plot_axis = plots[plot_axis_name]
            plot_wh_timecourse(metadata["timecourse"],
                               bottleneck,
                               metadata["results-dir"],
                               wt_col,
                               mut_col,
                               col_colors=col_colors,
                               cell_col=cell_col,
                               detection_threshold=detection_threshold,
                               detection_color=detection_color,
                               detection_limit=detection_limit,
                               detection_linestyle=detection_linestyle,
                               transmission_threshold=transmission_threshold,
                               axis=plot_axis,
                               t_M=metadata["activation-time"],
                               E_w=metadata["E_w"],
                               gen_param_file=gen_param_file,
                               non_trans_alpha=non_trans_alpha,
                               frequencies=plot_freq,
                               analytical_frequencies=True)
            plot_axis.set_xlim([0, 8])
            plot_axis.set_xticks([0, 2, 4, 6, 8])
            if plot_freq:
                plot_axis.set_ylim([1e-6, 2])

            ## remove inner labels: y tick labels only on the naive column,
            ## x tick labels only on the frequency (bottom) panels
            if "naive" not in plot_axis_name:
                for label in plot_axis.get_yticklabels():
                    label.set_visible(False)
            if "abs" in plot_axis_name:
                for label in plot_axis.get_xticklabels():
                    label.set_visible(False)

    min_prob = plot_heatmap(axis=plots["consensus-heatmap"],
                            contour_linewidth=(width / 3) * 0.5,
                            contour_fontsize="large",
                            increment=increment,
                            mu=mu,
                            f_target=0.5,
                            t_final=heatmap_t_final,
                            max_t_M=max_t_M,
                            max_k=max_k,
                            R0=R0_wh,
                            c_w=1,
                            c_m=0,
                            d_v=d_v,
                            bottleneck=heatmap_bottleneck,
                            contour_levels=[1e-5, 1e-3, 1e-1],
                            cbar=True)

    plot_heatmap(axis=plots["one-percent-heatmap"],
                 contour_linewidth=(width / 3) * 0.5,
                 contour_fontsize="large",
                 increment=increment,
                 mu=mu,
                 bottleneck=heatmap_bottleneck,
                 f_target=0.01,
                 min_prob=min_prob,
                 t_final=heatmap_t_final,
                 R0=R0_wh,
                 d_v=d_v,
                 c_w=1,
                 c_m=0,
                 max_t_M=max_t_M,
                 max_k=max_k,
                 contour_levels=[1e-3, 1e-2, 1e-1],
                 cbar=False)

    # Marker for the "influenza-like parameters" point on both heatmaps.
    star_size = 20
    star_marker = '*'
    star_facecolor = 'white'
    star_edgecolor = 'black'
    star_edgewidth = 1.5

    escape = 0.25
    star_x = 2.5
    sterilizing_k = (R0_wh - 1) * d_v
    star_y = sterilizing_k * escape

    plots['empirical-detect'].set_ylabel('number of infections')
    plots['empirical-detect'].set_title('observed HA\npolymorphism')
    plots['empirical-hist'].set_title('variant within-host\nfrequencies')
    plots['empirical-hist'].set_xlabel('variant frequency')
    plots['empirical-hist'].set_ylabel('number of variants')

    for heatmap_name in ['consensus-heatmap', 'one-percent-heatmap']:
        hm = plots[heatmap_name]
        hm.plot(star_x,
                star_y,
                marker=star_marker,
                markersize=star_size,
                markerfacecolor=star_facecolor,
                markeredgewidth=star_edgewidth,
                markeredgecolor=star_edgecolor)
        # the shared xlabel lives on the all_heatmaps bounding axes
        hm.set_xlabel("")
        hm.grid(False)
        hm.set_yticks(np.arange(0, 8.5, 2))
        hm.set_xticks(np.arange(0, 3.5, 1))

        ## need to reintroduce labels because
        ## axis sharing turns them off
        hm.xaxis.set_tick_params(labelbottom=True)

    ## need to reintroduce labels because
    ## axis sharing turns them off
    plots['one-percent-heatmap'].yaxis.set_tick_params(labelleft=True)

    all_heatmaps.set_xlabel("time of recall response $t_M$ (days)")
    # raw strings: "\d" and "\%" are not valid Python escape sequences
    all_heatmaps.set_ylabel(r'selection strength $\delta$')

    all_timecourses.set_xlabel("time (days)", fontsize="xx-large")
    plots["naive-abs"].set_title("no recall\nresponse")
    plots["fixed-abs"].set_title("constant recall\nresponse")
    plots["var-abs"].set_title("recall response at\n48h")
    plots["inoc-abs"].set_title("response at 48h,\nvariant inoculated")
    plots["consensus-heatmap"].set_title("prob. new variant\n" "at consensus")
    plots["one-percent-heatmap"].set_title("prob. new variant\n" r"at 1\%")

    plots["naive-abs"].set_ylabel("virions, cells")
    plots["naive-freq"].set_ylabel("variant frequency")

    plots['empirical-hist'].set_xlim([0, 0.5])

    # create legend
    cells = mlines.Line2D([], [],
                          color=ps.cell_color,
                          lw=lineweight,
                          label='target\ncells')
    wt = mlines.Line2D([], [],
                       color=ps.wt_color,
                       lw=lineweight,
                       label='old variant\nvirus')
    mut = mlines.Line2D([], [],
                        color=ps.mut_color,
                        lw=lineweight,
                        label='new variant\nvirus')
    ngs = mlines.Line2D([], [],
                        color='black',
                        lw=lineweight,
                        linestyle="dashed",
                        label='NGS detection\nlimit')
    analytical = mlines.Line2D([], [],
                               color='black',
                               lw=lineweight,
                               linestyle="dotted",
                               label='analytical new\nvariant frequency')
    antibody = mlines.Line2D([], [],
                             color=ps.immune_active_color,
                             lw=2 * lineweight,
                             alpha=0.25,
                             label='recall response\nactive')
    star_leg = mlines.Line2D([], [],
                             color='white',
                             marker=star_marker,
                             markersize=star_size,
                             markerfacecolor=star_facecolor,
                             markeredgewidth=star_edgewidth,
                             markeredgecolor=star_edgecolor,
                             label="influenza-like\nparameters")

    handles = [cells, wt, mut, antibody, ngs, analytical, star_leg]
    labels = [h.get_label() for h in handles]

    all_timecourses.legend(handles=handles,
                           labels=labels,
                           fontsize='x-large',
                           loc="center",
                           bbox_to_anchor=(0.5, -.3),
                           frameon=False,
                           ncol=int(np.ceil(len(handles) / 2)))

    fig.savefig(output_path)
Esempio n. 11
0
def main(output_path=None,
         fixed_results_dir=None,
         var_results_dir=None,
         gen_param_file=None,
         bottleneck=5,
         increment=0.5,
         landscape=True):
    """Save talk versions of the within-host timecourse and heatmap plots.

    Writes one two-panel figure per scenario ("naive", "fixed", "var",
    "inoc") to ``../../out/wh-plot-<scenario>.pdf``. The left panel shows
    absolute values and the right panel shows variant frequency. It also
    writes a two-panel heatmap figure (``f_target`` 0.5 and 0.01) to
    ``../../out/talk-heatmaps.pdf``. Each figure is closed after saving.

    Note: this sets ``mpl.rcParams['lines.linewidth']`` globally.

    Parameters
    ----------
    output_path : str, optional
        Ignored. Every figure is saved to a hard-coded path under
        ``../../out/``. The parameter is kept for interface compatibility.
    fixed_results_dir : str, optional
        Simulation output for the "fixed" scenario.
    var_results_dir : str, optional
        Simulation output for the "naive", "var" and "inoc" scenarios.
    gen_param_file : str, optional
        Run-parameter file read with ``pp.get_params``.
    bottleneck : int, optional
        Passed to both ``plot_wh_timecourse`` and ``plot_heatmap``.
    increment : float, optional
        Passed to ``plot_heatmap`` as ``increment``.
    landscape : bool, optional
        Ignored. It previously selected a figure height that no figure
        used; all figures are 16 x 8. Kept for interface compatibility.
    """
    if fixed_results_dir is None:
        fixed_results_dir = "../../out/within_host_results/minimal_visible"
    if var_results_dir is None:
        var_results_dir = "../../out/within_host_results/minimal_visvar"
    if gen_param_file is None:
        gen_param_file = "../../dat/RunParameters.mk"
    width = 18.3

    heatmap_bottleneck = bottleneck
    lineweight = width / 2.5
    mpl.rcParams['lines.linewidth'] = lineweight
    non_trans_alpha = 0.4

    params = pp.get_params(gen_param_file)
    detection_threshold = float(params["DEFAULT_DETECTION_THRESHOLD"])
    transmission_threshold = float(params["DEFAULT_TRANS_THRESHOLD"])
    detection_limit = float(params["DEFAULT_DETECTION_LIMIT"])

    mu = pp.get_param("mu", "minimal_visvar", params)
    t_M = pp.get_param("t_M", "minimal_visvar", params)
    t_N = pp.get_param("t_N", "minimal_visvar", params)

    wt_col = "Vw"
    mut_col = "Vm"
    cell_col = "C"
    col_colors = [ps.wt_color, ps.mut_color, ps.cell_color]
    detection_linestyle = "dashed"
    detection_color = "black"

    # NOTE(review): the "title" entries are not used by this function.
    sims_to_plot = {
        "naive": {
            "results-dir": var_results_dir,
            "timecourse": "naive",
            "activation-time": t_N,
            "title": "no recall\nresponse",
            "E_w": False
        },
        "fixed": {
            "results-dir": fixed_results_dir,
            "timecourse": "repl_visible",
            "title": "constant recall\nresponse",
            "activation-time": 0,
            "E_w": True
        },
        "var": {
            "results-dir": var_results_dir,
            "timecourse": "repl_visible",
            "title": "delayed recall\nresponse",
            "activation-time": t_M,
            "E_w": True
        },
        "inoc": {
            "results-dir": var_results_dir,
            "timecourse": "inoc_visible",
            "activation-time": t_M,
            "title": "delayed, mutant\ninoculated",
            "E_w": True
        }
    }

    for sim_type, metadata in sims_to_plot.items():
        fig, plot_axes = plt.subplots(1, 2, figsize=(16, 8))
        # index 0: absolute values, index 1: frequencies
        for plot_freq in (0, 1):
            axis = plot_axes[plot_freq]
            plot_wh_timecourse(metadata["timecourse"],
                               bottleneck,
                               metadata["results-dir"],
                               wt_col,
                               mut_col,
                               col_colors=col_colors,
                               cell_col=cell_col,
                               detection_threshold=detection_threshold,
                               detection_color=detection_color,
                               detection_limit=detection_limit,
                               detection_linestyle=detection_linestyle,
                               transmission_threshold=transmission_threshold,
                               axis=axis,
                               t_M=metadata["activation-time"],
                               E_w=metadata["E_w"],
                               gen_param_file=gen_param_file,
                               non_trans_alpha=non_trans_alpha,
                               frequencies=plot_freq,
                               analytical_frequencies=True)
            axis.set_xlim([0, 8])
            axis.set_xticks([0, 2, 4, 6, 8])
        plot_axes[0].set_ylim(top=1e10)
        plot_axes[1].set_ylim(top=5)

        plot_axes[0].set_ylabel("virions, cells")
        plot_axes[0].set_xlabel("time (days)")
        plot_axes[1].set_xlabel("time (days)")
        plot_axes[1].set_ylabel("variant frequency")
        fig.tight_layout()
        fig.savefig("../../out/wh-plot-{}.pdf".format(sim_type))
        # release each figure so repeated calls do not accumulate them
        plt.close(fig)

    heatmap_fig, heatmaps = plt.subplots(1, 2, figsize=(16, 8))
    min_prob = plot_heatmap(axis=heatmaps[0],
                            axlabel_fontsize=20,
                            contour_linewidth=(width / 3) * 0.75,
                            contour_fontsize="x-large",
                            increment=increment,
                            mu=mu,
                            f_target=0.5,
                            bottleneck=heatmap_bottleneck,
                            params=params,
                            cbar=True)

    plot_heatmap(axis=heatmaps[1],
                 axlabel_fontsize=20,
                 contour_linewidth=(width / 3) * 0.75,
                 contour_fontsize="x-large",
                 increment=increment,
                 mu=mu,
                 bottleneck=heatmap_bottleneck,
                 f_target=0.01,
                 min_prob=min_prob,
                 params=params,
                 cbar=True)
    heatmaps[0].set_xlabel("time of recall response")
    heatmaps[1].set_xlabel("time of recall response")

    heatmaps[0].set_ylabel("selection strength")
    heatmaps[1].set_title("probability of selection\n to 1%")
    heatmap_fig.tight_layout()
    heatmap_fig.savefig("../../out/talk-heatmaps.pdf")
    plt.close(heatmap_fig)
Esempio n. 12
0
def _tick_midpoints(ticks):
    """Return the point halfway between each tick and the tick after it.

    The final tick has no successor, so it is shifted forward by half of the
    gap to its predecessor instead (this equals the midpoint for evenly
    spaced ticks). Fewer than two ticks are returned unchanged, as a list.
    """
    ticks = list(ticks)
    if len(ticks) < 2:
        return ticks
    midpoints = [(left + right) / 2 for left, right in zip(ticks, ticks[1:])]
    midpoints.append(ticks[-1] + (ticks[-1] - ticks[-2]) / 2)
    return midpoints


def plot_inoculation_summary(wh_sim_data_dir=None,
                             schematic_path=None,
                             run_parameter_path=None,
                             file_output_path=None,
                             escape=0.75,
                             v=10,
                             b=1,
                             sigma=0.75,
                             inoculum_model='poisson',
                             z_homotypic=None):
    """Draw the multi-panel inoculation / IgA bottleneck summary figure.

    Panels are looked up by name in the ``plots`` mapping returned by
    ``setup_figure()``; the finished figure is written to
    ``file_output_path``.

    Parameters
    ----------
    wh_sim_data_dir : str
        Directory of within-host simulation output, read with
        ``asc.get_vb_data``. Its basename is used as the model name when
        looking up model-specific parameters (``z_wt``,
        ``mut_wt_neut_ratio``).
    schematic_path : str
        Image drawn into the ``'bottleneck-schematic'`` panel.
    run_parameter_path : str
        Parameter file read with ``pp.get_params``.
    file_output_path : str
        Destination passed to ``fig.savefig``.
    escape : float
        Passed as ``escape`` to the susceptibility-model and
        optimal-selector panels.
    v : int or float
        Passed as ``inoculum_size`` to the optimal-selector panel; ``v / b``
        is the ``inoculum_scaleup`` of the virion-survival panel.
    b : int or float
        Passed as ``bottleneck`` to the virion-survival and
        optimal-selector panels.
    sigma : float
        ``cross_imm_sigma`` of the second optimal-selector cluster; also
        shown in that panel's legend.
    inoculum_model : str
        Passed to ``plot_optimal_selector_cluster``.
    z_homotypic : float, optional
        Passed to the susceptibility and optimal-selector panels; defaults
        to the model's ``z_wt`` parameter.
    """
    width = 18.3
    lineweight = width / 3
    span_fontsize = 'x-large'
    schem_fontsize = 'xx-large'

    fig, plots = setup_figure()

    # get within-host data
    dataframe = asc.get_vb_data(wh_sim_data_dir)

    # get model name
    model_name = os.path.basename(wh_sim_data_dir)
    print("model name: ", model_name)

    # get parameters
    params = pp.get_params(run_parameter_path)

    inoculum_sizes = [1, 3, 10, 50, 100, 200]

    ## plot bottleneck schematic
    schematic_axis = plots['bottleneck-schematic']
    plot_image_from_path(schematic_path,
                         axis=schematic_axis,
                         retain_axis=True)

    schem_labs = [
        "", "excretion\nbottleneck", "inter-host\nbottleneck",
        "mucus\nbottleneck", "sIgA\nbottleneck", "cell infection\nbottleneck"
    ]

    schematic_axis.set_xticklabels(schem_labs, fontsize=schem_fontsize)

    # Move each tick to the midpoint of its interval (one tick per label).
    # Labels are set before the ticks are moved, as in the original layout
    # code, so they are carried over to the new positions by index.
    schem_ticks = _tick_midpoints(schematic_axis.get_xticks())
    schem_ticks = schem_ticks[:len(schem_labs)]
    print(schem_ticks)
    schematic_axis.set_xticks(schem_ticks)
    schematic_axis.set_xlim(left=(schem_ticks[0] + schem_ticks[1]) / 2)
    # positional flag: works with both the old ``b`` and new ``visible`` name
    schematic_axis.grid(False)

    schematic_axis.set_yticklabels([])

    ## get parameters and plot filtered inocula
    z_wt = pp.get_param("z_wt", model_name, params)

    if z_homotypic is None:
        z_homotypic = z_wt

    f_mut = pp.get_param("f_mut", "", params)
    print("f_mut: ", f_mut)

    mut_wt_neut_ratio = pp.get_param("mut_wt_neut_ratio", model_name, params)

    # NOTE(review): this always uses the "poisson" inoculum model, regardless
    # of ``inoculum_model`` -- confirm that is intended.
    kappa_ws = pot.neutralize_prob_from_z(z_wt, np.array(inoculum_sizes),
                                          "poisson")
    # mean simulated 'p_mut_inoc', grouped by 'n_encounters_iga' and
    # selected at the chosen inoculum sizes
    sim_p_mut = dataframe.groupby(
        'n_encounters_iga')['p_mut_inoc'].mean()[inoculum_sizes]
    plot_filtered_inocula(inoculum_sizes,
                          sim_p_mut,
                          axis=plots["pre-filter-dist"])
    plots['pre-filter-dist'].legend(
        ['neither', 'old variant', 'new variant', 'both'], loc='lower right')

    plot_filtered_inocula(inoculum_sizes,
                          sim_p_mut,
                          axis=plots["post-filter-dist"],
                          kappa_w=kappa_ws,
                          kappa_m=kappa_ws * mut_wt_neut_ratio)
    print("sim_p_mut:", sim_p_mut)

    plot_sim_inocula(dataframe,
                     inoculum_sizes=inoculum_sizes,
                     axis=plots["mutant-creator-dist"],
                     f_mut_threshold=0.5)

    print("plotting virion survival...")

    plot_survive_bottleneck(axis=plots["virion-survival"],
                            f_mut=f_mut,
                            bottleneck=b,
                            lw=lineweight * 2.5,
                            drift=True,
                            inoculum_scaleup=v / b,
                            cmap=plt.cm.Greens,
                            fineness=10)
    plots['virion-survival'].set_ylim(bottom=0)
    plots['virion-survival'].set_xlim(left=0, right=1)

    print("Plotting susceptibility model examples...")
    sus_func_cmaps = [plt.cm.Purples, plt.cm.Greens]

    sus_line_alpha = 0.8

    plot_sus_funcs(escape=escape,
                   cmaps=sus_func_cmaps,
                   axis=plots["susceptibility-models"],
                   z_homotypic=z_homotypic,
                   line_alpha=sus_line_alpha)

    print("Plotting optimal selector by memory age...")
    c_darks = [0.7, 0.4]
    drift_style = "dashed"

    # first cluster: default cross-immunity, drawn with the drift curve
    plot_optimal_selector_cluster(f_mut=f_mut,
                                  inoculum_size=v,
                                  bottleneck=b,
                                  escape=escape,
                                  axis=plots["optimal-selectors"],
                                  cmaps=sus_func_cmaps,
                                  darkness=c_darks[0],
                                  line_alpha=sus_line_alpha,
                                  inoculum_model=inoculum_model,
                                  z_homotypic=z_homotypic,
                                  drift_style=drift_style,
                                  legend=False)

    # second cluster: explicit cross-immunity ``sigma``, no drift curve
    plot_optimal_selector_cluster(f_mut=f_mut,
                                  inoculum_size=v,
                                  bottleneck=b,
                                  escape=escape,
                                  axis=plots["optimal-selectors"],
                                  cross_imm_sigma=sigma,
                                  cmaps=sus_func_cmaps,
                                  darkness=c_darks[1],
                                  line_alpha=sus_line_alpha,
                                  inoculum_model=inoculum_model,
                                  z_homotypic=z_homotypic,
                                  plot_drift=False,
                                  legend=False)

    # manual legend for optimal selector panel; raw strings keep the LaTeX
    # backslashes without relying on invalid-escape passthrough
    sel_cluster_handles = [
        mlines.Line2D([], [],
                      color=plt.cm.Greys(c_darks[0]),
                      label="implied\nby $z_{m}$"),
        mlines.Line2D([], [],
                      color=plt.cm.Greys(c_darks[1]),
                      label=("{} ".format(sigma) + r"$\kappa_{w}$"))
    ]
    plots['optimal-selectors'].legend(handles=sel_cluster_handles,
                                      title=r"$\kappa_m$",
                                      fancybox=True,
                                      frameon=True)

    print("plotting cutdown at iga bottleneck...")
    plot_cutdown(axis=plots['cutdown'], f_mut=f_mut)
    plots['cutdown'].set_xticks([1, 100, 200])

    # plot styling
    cluster_ticks = [0, 1, 2, 3, 4, 5]
    for plotname in ["susceptibility-models", "optimal-selectors"]:
        plots[plotname].set_xticks(cluster_ticks)
        plots[plotname].set_xlim(left=cluster_ticks[0],
                                 right=cluster_ticks[-1])
        plots[plotname].set_ylim(bottom=0)
        plots[plotname].set_xlabel("distance between\n"
                                   "host memory and old variant")

    for plotname in [
            "pre-filter-dist", "post-filter-dist", "mutant-creator-dist"
    ]:
        plots[plotname].label_outer()
        plots[plotname].set_yticks([0, 0.25, 0.5, 0.75, 1])

    plots['pre-filter-dist'].set_ylabel('proportion of inocula')
    plots['pre-filter-dist'].set_title('before IgA\n')
    plots['post-filter-dist'].set_title('after IgA\n')
    plots['mutant-creator-dist'].set_title('new variant emerged\n')

    emergence_y_text = "new variant infections\nper inoculation"
    plots["optimal-selectors"].set_ylabel(emergence_y_text)

    plots["virion-survival"].set_ylabel(
        "prob. new variant survives bottleneck")

    plots["all_dists"].set_xlabel('virions encountering IgA\n($v$)',
                                  fontsize=span_fontsize)

    fig.tight_layout(h_pad=-0.1)
    print("saving figure to {}".format(file_output_path))
    fig.savefig(file_output_path)