示例#1
0
def GeneratePercentiles(sols):
    """Summarise an ensemble of model runs as per-time-point percentile bands.

    Parameters
    ----------
    sols : list of dict
        Model solutions (e.g. as returned by ``run_model``). Each must provide
        ``'t'`` (time vector) and ``'y_plot'`` (rows indexed by
        ``categories[name]['index']``). All solutions are assumed to share the
        time grid of ``sols[0]`` — TODO confirm with callers.

    Returns
    -------
    list of numpy.ndarray
        ``[y_U95, y_UQ, y_LQ, y_L95, y_median]``: the 97.5th, 75th, 25th and
        2.5th percentiles and the median across solutions, each of shape
        ``(len(categories), n_time_points)``.

    Raises
    ------
    ValueError
        If ``sols`` is empty.
    """
    if not sols:
        raise ValueError('GeneratePercentiles needs at least one solution')

    n_categories = len(categories.keys())
    n_time_points = len(sols[0]['t'])

    # Stack all runs: y_plot[category, run, time].
    y_plot = np.zeros((n_categories, len(sols), n_time_points))
    for k, sol in enumerate(sols):
        sol_plot = sol['y_plot']
        for name in categories.keys():
            idx = categories[name]['index']
            y_plot[idx, k, :] = sol_plot[idx]

    y_L95, y_U95, y_LQ, y_UQ, y_median = [
        np.zeros((n_categories, n_time_points)) for _ in range(5)
    ]

    for name in categories.keys():
        idx = categories[name]['index']
        runs = y_plot[idx]  # shape (n_runs, n_time_points)
        # One vectorised call per category instead of one call per time point.
        y_L95[idx], y_LQ[idx], y_UQ[idx], y_U95[idx] = np.percentile(
            runs, [2.5, 25, 75, 97.5], axis=0)
        y_median[idx] = np.median(runs, axis=0)

    return [y_U95, y_UQ, y_LQ, y_L95, y_median]
示例#2
0
def stacked_bar_plot(sols,cats_to_plot,population_plot,population_frame):
    """Build a Plotly stacked-bar figure of one category split by age group.

    For every solution in ``sols`` and every age group (row of
    ``population_frame``) one bar trace is added, showing the percentage of
    the total population in compartment ``cats_to_plot`` for that age group.
    Only every other time point (indices 1, 3, 5, ...) is plotted.

    A secondary right-hand axis ("Population") labels the same scale in
    absolute numbers (via ``population_format``), and an update menu switches
    both axes between linear and logarithmic scale.

    Parameters
    ----------
    sols : list of dict
        Solutions with ``'t'`` and age-structured ``'y'``; row
        ``categories[name]['index'] + i*params.number_compartments`` of ``'y'``
        is read as age group ``i``.
    cats_to_plot : str
        Key of ``categories`` to plot. If no key matches, no bars are drawn.
    population_plot : number
        Total population used to convert percentages into head counts.
    population_frame : pandas.DataFrame
        One row per age group; its ``Age`` column supplies legend labels.

    Returns
    -------
    dict
        ``{'data': traces, 'layout': go.Layout}`` (a Plotly figure spec).
    """
    font_size = 13
    lines_to_plot = []
    # Total height of the stack at each x position: with barmode='stack',
    # every bar trace sharing an x value is piled on top of the others.
    stack_height = {}

    matching = [name for name in categories.keys() if name == cats_to_plot]
    age_labels = np.asarray(population_frame.Age)

    for sol in sols:
        for name in matching:
            y_all = np.asarray(sol['y'])  # local array; caller's dict untouched
            t = sol['t']
            xx = [t[i] for i in range(1, len(t), 2)]  # every other day only

            for age in range(population_frame.shape[0]):
                row = 100*y_all[categories[name]['index'] + age*params.number_compartments, :]
                y_plot = [row[i] for i in range(1, len(row), 2)]
                legend_name = categories[name]['longname'] + ': ' + str(age_labels[age])

                for x_val, y_val in zip(xx, y_plot):
                    stack_height[x_val] = stack_height.get(x_val, 0) + y_val

                lines_to_plot.append({
                    'x': xx,
                    'y': y_plot,
                    'hovertemplate': '%{y:.2f}%, ' + '%{text} <br>',
                    'text': [population_format(v*population_plot/100) for v in y_plot],
                    'type': 'bar',
                    'name': legend_name,
                })

    ymax = max(stack_height.values(), default=0)
    y_top = min(1.1*ymax, 100)
    if y_top <= 0:
        # An all-zero (or negative) series would give a degenerate [0, 0]
        # axis and log10(<=0) in the log-scale range; use the smallest tick
        # step instead so both axes stay finite.
        y_top = 1e-5
    yax = dict(range=[0, y_top])

    # Invisible trace plotted against the secondary axis (y2), presumably so
    # that the "Population" axis is rendered alongside the percentage axis.
    lines_to_plot.append(dict(
        type='scatter',
        x=[0, sols[-1]['t'][-1]],
        y=[0, population_plot],
        yaxis="y2",
        opacity=0,
        hoverinfo='skip',
        showlegend=False,
    ))

    # Candidate "nice" linear-axis limits: 0, 1e-5, 2e-5, 5e-5, 1e-4, ..., 5e2.
    nice_limits = [0]
    for exponent in range(-5, 3):
        nice_limits.extend([10**exponent, 2*10**exponent, 5*10**exponent])

    # Eleven ticks from 0 up to the smallest nice limit >= the axis top.
    pop_vec_lin = np.linspace(0, nice_limits[1], 11)
    for lower, upper in zip(nice_limits[:-1], nice_limits[1:]):
        if lower < y_top <= upper:
            pop_vec_lin = np.linspace(0, upper, 11)
    vec = [v*population_plot for v in pop_vec_lin]

    # Log scale: integer powers of ten from 10**log_bottom to the top tick.
    log_bottom = -8
    log_range = [log_bottom, np.log10(y_top)]
    pop_vec_log_intermediate = np.linspace(
        log_range[0],
        ceil(np.log10(pop_vec_lin[-1])),
        1 + ceil(np.log10(pop_vec_lin[-1]) - log_range[0]))
    pop_log_vec = [10**v for v in pop_vec_log_intermediate]
    vec2 = [v*population_plot for v in pop_log_vec]

    lin_ticktext = [population_format(0.01*v) for v in vec]
    log_ticktext = [population_format(0.01*v) for v in vec2]

    linear_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'linear',
                  'range': yax['range'], 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'linear', 'overlaying': 'y1',
                   'range': yax['range'], 'ticktext': lin_ticktext,
                   'tickvals': list(pop_vec_lin), 'automargin': True, 'side': 'right'},
    }
    log_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'log',
                  'range': log_range, 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'log', 'overlaying': 'y1',
                   'range': log_range, 'ticktext': log_ticktext,
                   'tickvals': list(pop_log_vec), 'automargin': True, 'side': 'right'},
    }

    layout = go.Layout(
        template="simple_white",
        font=dict(size=font_size),
        margin=dict(t=5, b=5, l=10, r=10, pad=15),
        hovermode='x',
        xaxis=dict(
            title='Days',
            automargin=True,
            hoverformat='.0f',
        ),
        yaxis=dict(
            mirror=True,
            title='Percentage of Total Population',
            range=yax['range'],
            automargin=True,
            type='linear',
        ),
        barmode='stack',
        updatemenus=[dict(
            buttons=[
                dict(args=[linear_axes], label="Linear", method="relayout"),
                dict(args=[log_axes], label="Logarithmic", method="relayout"),
            ],
            x=0.5,
            xanchor="right",
            pad={"r": 5, "t": 30, "b": 10, "l": 5},
            active=0,
            y=-0.13,
            showactive=True,
            direction='up',
            yanchor="top",
        )],
        legend=dict(
            font=dict(size=font_size*(20/24)),
            x=0.5,
            y=1.03,
            xanchor='center',
            yanchor='bottom',
        ),
        legend_orientation='h',
        legend_title='<b> Key </b>',
        yaxis2=dict(
            title='Population',
            overlaying='y1',
            range=yax['range'],
            side='right',
            ticktext=lin_ticktext,
            tickvals=list(pop_vec_lin),
            automargin=True,
        ),
    )

    return {'data': lines_to_plot, 'layout': layout}
示例#3
0
def age_structure_plot(sols,cats_to_plot,population_plot,population_frame): # ,confidence_range=None
    """Build a Plotly line figure of one category, one trace per run and age group.

    Each trace shows the percentage of the total population in compartment
    ``cats_to_plot`` for one age group over time. A secondary right-hand axis
    ("Population") labels the same scale in absolute numbers (via
    ``population_format``), and an update menu switches both axes between
    linear and logarithmic scale.

    Parameters
    ----------
    sols : list of dict
        Solutions with ``'t'`` and age-structured ``'y'``; row
        ``categories[name]['index'] + i*params.number_compartments`` of ``'y'``
        is read as age group ``i``. Must be non-empty.
    cats_to_plot : str
        Key of ``categories`` to plot. If no key matches, no lines are drawn.
    population_plot : number
        Total population used to convert percentages into head counts.
    population_frame : pandas.DataFrame
        One row per age group; its ``Age`` column supplies legend labels.

    Returns
    -------
    dict
        ``{'data': traces, 'layout': go.Layout}`` (a Plotly figure spec).
    """
    font_size = 13
    lines_to_plot = []

    matching = [name for name in categories.keys() if name == cats_to_plot]
    age_labels = np.asarray(population_frame.Age)

    for sol in sols:
        for name in matching:
            y_all = np.asarray(sol['y'])  # local array; caller's dict untouched
            xx = sol['t']
            for age in range(population_frame.shape[0]):
                y_plot = 100*y_all[categories[name]['index'] + age*params.number_compartments, :]
                lines_to_plot.append({
                    'x': xx,
                    'y': y_plot,
                    'hovertemplate': '%{y:.2f}%, ' + '%{text} <br>',
                    'text': [population_format(v*population_plot/100) for v in y_plot],
                    'opacity': 0.5,
                    'name': categories[name]['longname'] + ': ' + str(age_labels[age]),
                })

    # Largest plotted value, never below zero.
    ymax = max([0] + [max(line['y']) for line in lines_to_plot])
    y_top = min(1.1*ymax, 100)
    if y_top <= 0:
        # An all-zero series would give a degenerate [0, 0] axis and
        # log10(0) = -inf in the log-scale range; use the smallest tick step.
        y_top = 1e-5
    yax = dict(range=[0, y_top])

    # Invisible trace plotted against the secondary axis (y2), presumably so
    # that the "Population" axis is rendered alongside the percentage axis.
    lines_to_plot.append(dict(
        type='scatter',
        x=[0, sols[-1]['t'][-1]],
        y=[0, population_plot],
        yaxis="y2",
        opacity=0,
        hoverinfo='skip',
        showlegend=False,
    ))

    # No shading or annotations are currently drawn.
    shapes = []
    annots = []

    # Candidate "nice" linear-axis limits: 0, 1e-5, 2e-5, 5e-5, 1e-4, ..., 5e2.
    nice_limits = [0]
    for exponent in range(-5, 3):
        nice_limits.extend([10**exponent, 2*10**exponent, 5*10**exponent])

    # Eleven ticks from 0 up to the smallest nice limit >= the axis top.
    pop_vec_lin = np.linspace(0, nice_limits[1], 11)
    for lower, upper in zip(nice_limits[:-1], nice_limits[1:]):
        if lower < y_top <= upper:
            pop_vec_lin = np.linspace(0, upper, 11)
    vec = [v*population_plot for v in pop_vec_lin]

    # Log scale: integer powers of ten from 10**log_bottom to the top tick.
    log_bottom = -8
    log_range = [log_bottom, np.log10(y_top)]
    pop_vec_log_intermediate = np.linspace(
        log_range[0],
        ceil(np.log10(pop_vec_lin[-1])),
        1 + ceil(np.log10(pop_vec_lin[-1]) - log_range[0]))
    pop_log_vec = [10**v for v in pop_vec_log_intermediate]
    vec2 = [v*population_plot for v in pop_log_vec]

    lin_ticktext = [population_format(0.01*v) for v in vec]
    log_ticktext = [population_format(0.01*v) for v in vec2]

    linear_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'linear',
                  'range': yax['range'], 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'linear', 'overlaying': 'y1',
                   'range': yax['range'], 'ticktext': lin_ticktext,
                   'tickvals': list(pop_vec_lin), 'automargin': True, 'side': 'right'},
    }
    log_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'log',
                  'range': log_range, 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'log', 'overlaying': 'y1',
                   'range': log_range, 'ticktext': log_ticktext,
                   'tickvals': list(pop_log_vec), 'automargin': True, 'side': 'right'},
    }

    layout = go.Layout(
        template="simple_white",
        shapes=shapes,
        annotations=annots,
        font=dict(size=font_size),
        margin=dict(t=5, b=5, l=10, r=10, pad=15),
        hovermode='x',
        xaxis=dict(
            title='Days',
            automargin=True,
            hoverformat='.0f',
        ),
        yaxis=dict(
            mirror=True,
            title='Percentage of Total Population',
            range=yax['range'],
            automargin=True,
            type='linear',
        ),
        updatemenus=[dict(
            buttons=[
                dict(args=[linear_axes], label="Linear", method="relayout"),
                dict(args=[log_axes], label="Logarithmic", method="relayout"),
            ],
            x=0.5,
            xanchor="right",
            pad={"r": 5, "t": 30, "b": 10, "l": 5},
            active=0,
            y=-0.13,
            showactive=True,
            direction='up',
            yanchor="top",
        )],
        legend=dict(
            font=dict(size=font_size*(20/24)),
            x=0.5,
            y=1.03,
            xanchor='center',
            yanchor='bottom',
        ),
        legend_orientation='h',
        legend_title='<b> Key </b>',
        yaxis2=dict(
            title='Population',
            overlaying='y1',
            range=yax['range'],
            side='right',
            ticktext=lin_ticktext,
            tickvals=list(pop_vec_lin),
            automargin=True,
        ),
    )

    return {'data': lines_to_plot, 'layout': layout}
示例#4
0
def uncertainty_plot(sols,cats_to_plot,population_plot,population_frame,confidence_range=None):
    """Build a Plotly figure of percentile bands and the median for one category.

    Parameters
    ----------
    sols : list of dict
        Solutions; only ``sols[0]['t']`` is used, as the shared time axis.
    cats_to_plot : str
        Key of ``categories`` to plot. An empty value defaults to ``'I'``.
    population_plot : number
        Total population used to convert percentages into head counts.
    population_frame : pandas.DataFrame
        Unused here; kept for signature compatibility with the other plots.
    confidence_range : sequence of array-like
        Band arrays followed by the median, e.g. as returned by
        ``GeneratePercentiles``: ``[97.5th, 75th, 25th, 2.5th percentile,
        median]``. Each array is indexed ``[categories[name]['index'], :]``.

    Returns
    -------
    dict
        ``{'data': traces, 'layout': go.Layout}`` (a Plotly figure spec).

    Raises
    ------
    ValueError
        If ``confidence_range`` is not supplied.
    """
    if not cats_to_plot:
        # Must be a string: it is compared with the (string) category keys.
        cats_to_plot = 'I'
    if confidence_range is None:
        raise ValueError('uncertainty_plot requires confidence_range')

    font_size = 13
    lines_to_plot = []

    percentiles = ['97.5', '75', '25', '2.5']
    labels = ['1', '75-97.5 percentile', '25-75 percentile', '2.5-25 percentile']
    showledge = [False, True, True, True]

    xx = sols[0]['t']
    matching = [name for name in categories.keys() if name == cats_to_plot]

    # Band boundaries: the first trace has no fill, each later one fills down
    # to the previous trace, giving the 75-97.5, 25-75 and 2.5-25 bands.
    for name in matching:
        idx = categories[name]['index']
        for band, yy in enumerate(confidence_range[:-1]):
            y_plot = 100*np.asarray(yy)[idx, :]
            # The central 25-75 band is drawn more opaque than the outer ones;
            # the [:-4] slice assumes fill_colour ends in a 4-char alpha like "0.x)".
            opac = '0.5)' if band == 2 else '0.2)'
            lines_to_plot.append({
                'x': xx,
                'y': y_plot,
                'hovertemplate': '%{y:.2f}%, %{text}, ' + percentiles[band] + ' percentile<extra></extra>',
                'text': [population_format(v*population_plot/100) for v in y_plot],
                'line': {'width': 0, 'color': categories[name]['colour']},
                'fillcolor': categories[name]['fill_colour'][:-4] + opac,
                'legendgroup': name + 'fill',
                'showlegend': showledge[band],
                'mode': 'lines',
                'fill': None if band == 0 else 'tonexty',
                'name': labels[band],
            })

    # Median line (last entry of confidence_range).
    for name in matching:
        y_plot = 100*np.asarray(confidence_range[-1])[categories[name]['index'], :]
        lines_to_plot.append({
            'x': xx,
            'y': y_plot,
            'hovertemplate': '%{y:.2f}%, %{text}',
            'text': [population_format(v*population_plot/100) for v in y_plot],
            'line': {'color': str(categories[name]['colour'])},
            'legendgroup': name,
            'name': categories[name]['longname'] + '; median',
        })

    # Largest plotted value, never below zero.
    ymax = max([0] + [max(line['y']) for line in lines_to_plot])
    y_top = min(1.1*ymax, 100)
    if y_top <= 0:
        # An all-zero series would give a degenerate [0, 0] axis and
        # log10(0) = -inf in the log-scale range; use the smallest tick step.
        y_top = 1e-5
    yax = dict(range=[0, y_top])

    # Invisible trace plotted against the secondary axis (y2), presumably so
    # that the "Population" axis is rendered alongside the percentage axis.
    lines_to_plot.append(dict(
        type='scatter',
        x=[0, xx[-1]],
        y=[0, population_plot],
        yaxis="y2",
        opacity=0,
        hoverinfo='skip',
        showlegend=False,
    ))

    # Candidate "nice" linear-axis limits: 0, 1e-5, 2e-5, 5e-5, 1e-4, ..., 5e2.
    nice_limits = [0]
    for exponent in range(-5, 3):
        nice_limits.extend([10**exponent, 2*10**exponent, 5*10**exponent])

    # Eleven ticks from 0 up to the smallest nice limit >= the axis top.
    pop_vec_lin = np.linspace(0, nice_limits[1], 11)
    for lower, upper in zip(nice_limits[:-1], nice_limits[1:]):
        if lower < y_top <= upper:
            pop_vec_lin = np.linspace(0, upper, 11)
    vec = [v*population_plot for v in pop_vec_lin]

    # Log scale: integer powers of ten from 10**log_bottom to the top tick.
    log_bottom = -8
    log_range = [log_bottom, np.log10(y_top)]
    pop_vec_log_intermediate = np.linspace(
        log_range[0],
        ceil(np.log10(pop_vec_lin[-1])),
        1 + ceil(np.log10(pop_vec_lin[-1]) - log_range[0]))
    pop_log_vec = [10**v for v in pop_vec_log_intermediate]
    vec2 = [v*population_plot for v in pop_log_vec]

    # No shading or annotations are currently drawn.
    shapes = []
    annots = []

    lin_ticktext = [population_format(0.01*v) for v in vec]
    log_ticktext = [population_format(0.01*v) for v in vec2]

    linear_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'linear',
                  'range': yax['range'], 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'linear', 'overlaying': 'y1',
                   'range': yax['range'], 'ticktext': lin_ticktext,
                   'tickvals': list(pop_vec_lin), 'automargin': True, 'side': 'right'},
    }
    log_axes = {
        "yaxis": {'title': 'Percentage of Total Population', 'type': 'log',
                  'range': log_range, 'automargin': True},
        "yaxis2": {'title': 'Population', 'type': 'log', 'overlaying': 'y1',
                   'range': log_range, 'ticktext': log_ticktext,
                   'tickvals': list(pop_log_vec), 'automargin': True, 'side': 'right'},
    }

    layout = go.Layout(
        template="simple_white",
        shapes=shapes,
        annotations=annots,
        font=dict(size=font_size),
        margin=dict(t=5, b=5, l=10, r=10, pad=15),
        hovermode='x',
        xaxis=dict(
            title='Days',
            automargin=True,
            hoverformat='.0f',
        ),
        yaxis=dict(
            mirror=True,
            title='Percentage of Total Population',
            range=yax['range'],
            automargin=True,
            type='linear',
        ),
        updatemenus=[dict(
            buttons=[
                dict(args=[linear_axes], label="Linear", method="relayout"),
                dict(args=[log_axes], label="Logarithmic", method="relayout"),
            ],
            x=0.5,
            xanchor="right",
            pad={"r": 5, "t": 30, "b": 10, "l": 5},
            active=0,
            y=-0.13,
            showactive=True,
            direction='up',
            yanchor="top",
        )],
        legend=dict(
            font=dict(size=font_size*(20/24)),
            x=0.5,
            y=1.03,
            xanchor='center',
            yanchor='bottom',
        ),
        legend_orientation='h',
        legend_title='<b> Key </b>',
        yaxis2=dict(
            title='Population',
            overlaying='y1',
            range=yax['range'],
            side='right',
            ticktext=lin_ticktext,
            tickvals=list(pop_vec_lin),
            automargin=True,
        ),
    )

    return {'data': lines_to_plot, 'layout': layout}
示例#5
0
def _age_structured_frame(sol, population_frame, category_map):
    """Return one solution as a DataFrame: per-age columns, 'Time', then per-category totals."""
    csv_sol = np.transpose(sol['y'])  # rows = time points, columns = compartment x age
    frame = pd.DataFrame(csv_sol)

    ages = np.asarray(population_frame.Age)
    col_names = []
    for col in range(csv_sol.shape[1]):
        # Column layout: compartment index varies fastest, then age group.
        age_group, compartment = divmod(col, params.number_compartments)
        col_names.append(categories[category_map[str(compartment)]]['longname'] +
                         ': ' + str(ages[age_group]))
    frame.columns = col_names
    frame['Time'] = sol['t']

    # Summary (non age-structured) series, one column per category.
    for j in range(len(categories.keys())):
        frame[categories[category_map[str(j)]]['longname']] = sol['y_plot'][j]
    return frame


def generate_csv(data_to_save,
                 population_frame,
                 filename,
                 input_type=None,
                 time_vec=None):
    """Write model output to ``<parent of cwd>/CSV_output/<filename>.csv``.

    Parameters
    ----------
    data_to_save :
        Depends on ``input_type``:

        * ``'percentile'`` - 2-D array with one row per category (indexed by
          ``categories[...]['index']``) and one column per time point.
        * ``'raw'`` - dict mapping ``(R0, latentRate, removalRate, hospRate,
          deathRateICU, deathRateNoIcu)`` tuples to solution dicts with
          ``'y'``, ``'t'`` and ``'y_plot'``; the parameter values are added as
          columns and all runs are concatenated.
        * ``'solution'`` - sequence whose first element is such a solution dict.
    population_frame : pandas.DataFrame
        One row per age group; its ``Age`` column labels age-structured columns.
    filename : str
        Output file name without the ``.csv`` extension.
    input_type : {'percentile', 'raw', 'solution'}
    time_vec : array-like, optional
        Time column, used only for ``'percentile'`` input.

    Raises
    ------
    ValueError
        If ``input_type`` is not one of the supported values.
    """
    # Map str(category index) -> category key, e.g. '0' -> the key with index 0.
    category_map = {str(categories[key]['index']): key for key in categories.keys()}

    if input_type == 'percentile':
        csv_sol = np.transpose(data_to_save)
        solution_csv = pd.DataFrame(csv_sol)
        solution_csv.columns = [categories[category_map[str(i)]]['longname']
                                for i in range(csv_sol.shape[1])]
        solution_csv['Time'] = time_vec

    elif input_type == 'raw':
        frames = []
        for key, value in tqdm(data_to_save.items()):
            frame = _age_structured_frame(value, population_frame, category_map)
            (R0, latentRate, removalRate, hospRate, deathRateICU,
             deathRateNoIcu) = key
            frame['R0'] = R0
            frame['latentRate'] = latentRate
            frame['removalRate'] = removalRate
            frame['hospRate'] = hospRate
            frame['deathRateICU'] = deathRateICU
            frame['deathRateNoIcu'] = deathRateNoIcu
            frames.append(frame)
        # Concatenate once rather than growing a frame inside the loop (quadratic).
        solution_csv = (pd.concat(frames, ignore_index=True)
                        if frames else pd.DataFrame())

    elif input_type == 'solution':
        solution_csv = _age_structured_frame(data_to_save[0], population_frame,
                                             category_map)

    else:
        raise ValueError(
            "input_type must be 'percentile', 'raw' or 'solution', got %r" % (input_type,))

    solution_csv.to_csv(
        os.path.join(os.path.dirname(cwd), 'CSV_output', filename + '.csv'))

    return None
示例#6
0
    def run_model(
            self,
            T_stop,
            population,
            population_frame,
            infection_matrix,
            beta,
            control_dict,  # control
            latentRate=params.latent_rate,
            removalRate=params.removal_rate,
            hospRate=params.hosp_rate,
            deathRateICU=params.death_rate_with_ICU,
            deathRateNoIcu=params.death_rate  # more params
    ):
        """Integrate the age-structured compartment model with daily output.

        Parameters
        ----------
        T_stop : int
            Last day simulated; output is on the grid ``0, 1, ..., T_stop``.
        population : number
            Total population. One symptomatic and one asymptomatic case
            (``1 / population`` each) seed the epidemic.
        population_frame : pandas.DataFrame
            One row per age group with ``Population_structure`` (divided by
            100 here, so presumably percentages), ``p_symptomatic``,
            ``p_hospitalised`` and ``p_critical``.
        infection_matrix, beta :
            Passed straight to ``self.ode_system``.
        control_dict : dict
            Must contain ``'better_hygiene'``, ``'remove_symptomatic'``,
            ``'remove_high_risk'`` and ``'ICU_capacity'``; passed to
            ``self.ode_system``.
        latentRate, removalRate, hospRate, deathRateICU, deathRateNoIcu :
            Rate parameters passed to ``self.ode_system``.

        Returns
        -------
        dict
            ``'y'``: state of shape ``(params.number_compartments * n_ages,
            T_stop + 1)``, with compartment ``c`` of age group ``i`` at row
            ``c + i * params.number_compartments``; ``'t'``: the time grid;
            ``'y_plot'``: shape ``(len(categories), T_stop + 1)``, holding
            ``calculated_categories`` summed over age groups, each
            ``change_in_categories`` entry as the day-to-day difference of its
            source category, and ``'Ninf'`` = CE + CI + CA.

        Raises
        ------
        RuntimeError
            If the ODE solver reports an unsuccessful step.
        """
        # Initial fraction of the population in each compartment.
        E0 = 0  # exposed
        I0 = 1 / population  # symptomatic
        A0 = 1 / population  # asymptomatic
        Rec0 = 0  # recovered (not the reproduction number)
        H0 = 0  # hospitalised/needing hospital care
        C0 = 0  # critical (cared)
        D0 = 0  # dead
        O0 = 0  # offsite
        Q0 = 0  # quarantined
        U0 = 0  # critical (uncared)
        # Susceptibles are the remainder; every other compartment (including
        # E and A) is subtracted so the initial fractions sum to exactly 1.
        S0 = 1 - (E0 + I0 + A0 + Rec0 + H0 + C0 + D0 + O0 + Q0 + U0)

        age_categories = int(population_frame.shape[0])
        n_comp = params.number_compartments
        population_vector = np.asarray(population_frame.Population_structure)

        initial_fractions = (
            (params.S_ind, S0), (params.E_ind, E0), (params.I_ind, I0),
            (params.A_ind, A0), (params.R_ind, Rec0), (params.H_ind, H0),
            (params.C_ind, C0), (params.D_ind, D0), (params.O_ind, O0),
            (params.Q_ind, Q0), (params.U_ind, U0),
        )

        # Initial conditions: each age group's share of every compartment.
        y0 = np.zeros(n_comp * age_categories)
        for i in range(age_categories):
            share = population_vector[i] / 100
            for compartment, fraction in initial_fractions:
                y0[compartment + i * n_comp] = share * fraction

        symptomatic_prob = np.asarray(population_frame.p_symptomatic)
        hospital_prob = np.asarray(population_frame.p_hospitalised)
        critical_prob = np.asarray(population_frame.p_critical)

        sol = ode(self.ode_system).set_f_params(
            infection_matrix,
            age_categories,
            symptomatic_prob,
            hospital_prob,
            critical_prob,
            beta,  # params
            latentRate,
            removalRate,
            hospRate,
            deathRateICU,
            deathRateNoIcu,  # more params
            control_dict['better_hygiene'],
            control_dict['remove_symptomatic'],
            control_dict['remove_high_risk'],
            control_dict['ICU_capacity']  # control params
        )

        tim = np.linspace(0, T_stop, T_stop + 1)  # 1 time value per day

        sol.set_initial_value(y0, tim[0])

        y_out = np.zeros((len(y0), len(tim)))
        y_out[:, 0] = sol.y
        for step, t in enumerate(tim[1:], start=1):
            sol.integrate(t)
            # Check after every step so a failure on the final step is caught too.
            if not sol.successful():
                raise RuntimeError('ode solver unsuccessful')
            y_out[:, step] = sol.y

        y_plot = np.zeros((len(categories.keys()), len(tim)))
        for name in calculated_categories:
            idx = categories[name]['index']
            # Same compartment in each age group: rows idx, idx + n_comp, ...
            age_rows = [idx + i * n_comp for i in range(age_categories)]
            y_plot[idx, :] = y_out[age_rows, :].sum(axis=0)

        for name in change_in_categories:  # daily change in
            # e.g. 'CI' -> 'I': the last character names the source variable.
            name_changed_var = name[-1]
            y_plot[categories[name]['index'], :] = np.concatenate(
                [[0],
                 np.diff(y_plot[categories[name_changed_var]['index'], :])])

        # Change in total number of people with active infection.
        y_plot[categories['Ninf']['index'], :] = (
            y_plot[categories['CE']['index'], :]
            + y_plot[categories['CI']['index'], :]
            + y_plot[categories['CA']['index'], :])

        return {'y': y_out, 't': tim, 'y_plot': y_plot}
示例#7
0
def simulate_R0_unmitigated(
    R_0,
    column,
    t_stop=200
):
    """Simulate the baseline camp at a single basic reproduction number and plot by age.

    The transmission rate ``beta`` is calibrated so that the dominant eigenvalue
    of the next-generation matrix, multiplied by ``beta / params.removal_rate``,
    equals ``R_0``. The model is run once with the baseline population and
    controls from ``configs.baseline``. The output is then flattened into a
    DataFrame with these columns:

    - one column per (compartment, age group);
    - a ``Time`` column;
    - one summary column per category;
    - an ``R0`` column.

    That frame is passed to ``plots.plot_by_age``.

    Args:
        R_0: target basic reproduction number.
        column: forwarded unchanged to ``plot_by_age`` (presumably the name of
            the column to plot -- confirm against ``plots``).
        t_stop: number of simulated days, passed to ``run_model`` as ``T_stop``.

    Note:
        Reads ``Parameters/Contact_matrix_<camp>.csv`` relative to the parent
        of the current working directory and adds ``../base`` to ``sys.path``
        before importing ``functions``. It must therefore be run from the
        expected directory.
    """
    from plots import plot_by_age
    import os
    import sys
    cwd = os.getcwd()
    # Make the model package importable. Only add the path once so repeated
    # calls do not keep growing sys.path.
    base_path = os.path.abspath(os.path.join('..', 'base'))
    if base_path not in sys.path:
        sys.path.append(base_path)
    from functions import simulator
    from configs.baseline import camp, population_frame, population, control_dict
    from initialise_parameters import params, categories

    infection_matrix = np.asarray(
        pd.read_csv(
            os.path.join(os.path.dirname(cwd),
                         'Parameters/Contact_matrix_' + camp + '.csv')))
    # Drop the first CSV column (presumably row labels -- confirm against the file).
    infection_matrix = infection_matrix[:, 1:]

    # Next-generation matrix: diag(population structure * 0.01) @ contact matrix.
    # The 0.01 factor suggests Population_structure is in percent -- confirm.
    next_generation_matrix = np.matmul(
        0.01 * np.diag(population_frame.Population_structure),
        infection_matrix)
    # eig may return complex eigenvalues for this non-symmetric matrix, so
    # pick the dominant one explicitly by largest real part rather than relying
    # on numpy's ordering of complex numbers.
    eigenvalues = np.linalg.eig(next_generation_matrix)[0]
    largest_eigenvalue = eigenvalues[np.argmax(np.real(eigenvalues))]
    beta = R_0 * params.removal_rate
    beta = np.real(
        (1 / largest_eigenvalue) * beta)  # in case eigenvalue imaginary
    sols_raw = {}
    result = simulator().run_model(T_stop=t_stop,
                                   infection_matrix=infection_matrix,
                                   population=population,
                                   population_frame=population_frame,
                                   beta=beta,
                                   control_dict=control_dict)
    # Key by the R0 actually achieved. Take the real part so neither the key
    # nor the 'R0' column built from it ends up complex-valued.
    sols_raw[np.real(beta * largest_eigenvalue / params.removal_rate)] = result
    final_frame = pd.DataFrame()
    # Position of a row in y_plot / of a compartment within one age group
    # -> short category name.
    category_map = {
        '0': 'S',
        '1': 'E',
        '2': 'I',
        '3': 'A',
        '4': 'R',
        '5': 'H',
        '6': 'C',
        '7': 'D',
        '8': 'O',
        '9': 'CS',  # change in S
        '10': 'CE',  # change in E
        '11': 'CI',  # change in I
        '12': 'CA',  # change in A
        '13': 'CR',  # change in R
        '14': 'CH',  # change in H
        '15': 'CC',  # change in C
        '16': 'CD',  # change in D
        '17': 'CO',  # change in O
        '18': 'Ninf',
    }
    ages = np.asarray(population_frame.Age)  # hoisted out of the column loop
    for key, value in sols_raw.items():
        # Transpose so rows are time points and columns are state variables.
        csv_sol = np.transpose(value['y'])  # age structured
        solution_csv = pd.DataFrame(csv_sol)
        # State variables are stored age group by age group. Column i is
        # compartment i % number_compartments of age group
        # i // number_compartments.
        col_names = []
        number_categories_with_age = csv_sol.shape[1]
        for i in range(number_categories_with_age):
            ii = i % params.number_compartments
            jj = i // params.number_compartments
            col_names.append(categories[category_map[str(ii)]]['longname'] +
                             ': ' + str(ages[jj]))

        solution_csv.columns = col_names
        solution_csv['Time'] = value['t']

        # Summary (non age-structured) columns. This assumes row j of y_plot
        # is the category mapped by category_map[str(j)] -- confirm it matches
        # categories[...]['index'].
        for j in range(len(categories.keys())):  # params.number_compartments
            solution_csv[categories[category_map[str(j)]]['longname']] = value[
                'y_plot'][j]

        solution_csv['R0'] = [key] * solution_csv.shape[0]
        final_frame = pd.concat([final_frame, solution_csv], ignore_index=True)
    solution_csv = final_frame
    # This is the dataframe to be saved / plotted.
    plot_by_age(column, solution_csv)