def test_add_ground_truth_rt(self):
    """Check the ground-truth R_t trace and the input validation paths.

    Covers: the x/y values of the first trace after adding a ground-truth
    data frame, the ``TypeError`` raised for a non-data-frame argument, and
    the ``UserWarning`` emitted when a custom time or R key is supplied on
    a plot that already holds a trace.
    """
    times = [1, 2, 3, 4, 5, 6]
    r_values = [3, 3, 0.5, 0.5, 0.5, 0.5]

    truth = pd.DataFrame({'Time Points': times, 'R_t': r_values})
    my_plot = bp.ReproductionNumberPlot()
    my_plot.add_ground_truth_rt(truth)

    # Both sides are wrapped in an extra dimension, as a (1, 6) comparison
    npt.assert_array_equal(
        np.array([my_plot.figure['data'][0]['x']]),
        np.array([np.array(times)]))
    npt.assert_array_equal(
        np.array([my_plot.figure['data'][0]['y']]),
        np.array([np.array(r_values)]))

    # Anything that is not a data frame must be rejected
    with self.assertRaises(TypeError):
        bp.ReproductionNumberPlot().add_ground_truth_rt(0)

    # Non-default time column name
    with self.assertWarns(UserWarning):
        renamed_time = pd.DataFrame({'t': times, 'R_t': r_values})
        my_plot.add_ground_truth_rt(renamed_time, time_key='t')

    # Non-default R column name
    with self.assertWarns(UserWarning):
        renamed_r = pd.DataFrame({'Time Points': times, 'r': r_values})
        my_plot.add_ground_truth_rt(renamed_r, r_key='r')
def update_inference_figure(self, source=None):
    """Rebuild the reproduction number figure from the stored session data.

    Parameters
    ----------
    source : str
        Dash callback source. Not used inside this method.

    Returns
    -------
    plotly.Figure
        Figure with updated posterior distribution

    Raises
    ------
    dash.exceptions.PreventUpdate
        If either the data or the posterior has not been stored yet.
    """
    stored = self.session_data
    data = stored.get('data_storage')
    posterior = stored.get('posterior_storage')

    # Nothing to draw until both the data and the posterior are available
    nothing_to_plot = data is None or posterior is None
    if nothing_to_plot:
        raise dash.exceptions.PreventUpdate()

    # The first two columns are taken as time and incidence, in that order;
    # only the time label is needed below
    time_label, _ = data.columns[:2]

    plot = bp.ReproductionNumberPlot()
    plot.add_interval_rt(posterior)

    # Overlay the true R_t trajectory when the uploaded data provides one
    has_ground_truth = 'R_t' in data.columns
    if has_ground_truth:
        ground_truth = data[[time_label, 'R_t']]
        plot.add_ground_truth_rt(
            ground_truth, time_key=time_label, r_key='R_t')

    # Keeps traces visibility states fixed when changing sliders
    figure = plot.figure
    figure['layout']['legend']['uirevision'] = True

    return figure
def test_show_figure(self):
    """Check that ``show_figure`` delegates to plotly's ``Figure.show``."""
    ground_truth = {
        'Time Points': [1, 2, 3, 4, 5, 6],
        'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
    }

    # Patch the plotly call so that no figure window or browser is opened
    with patch('plotly.graph_objs.Figure.show') as show_patch:
        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(pd.DataFrame(ground_truth))
        my_plot.show_figure()

        # The patched Figure.show must have been invoked by show_figure
        assert show_patch.called
def test_update_labels(self):
    """Check that ``update_labels`` rewrites the x and y axis titles."""
    new_time_label = 'Time'
    new_r_label = 'R Value'

    my_plot = bp.ReproductionNumberPlot()
    my_plot.add_ground_truth_rt(
        pd.DataFrame({
            'Time Points': [1, 2, 3, 4, 5, 6],
            'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
        }))

    # Changing only the time label updates the x axis title
    my_plot.update_labels(time_label=new_time_label)
    x_title = my_plot.figure['layout']['xaxis']['title']['text']
    self.assertEqual(x_title, 'Time')

    # Changing only the R label updates the y axis title
    my_plot.update_labels(r_label=new_r_label)
    y_title = my_plot.figure['layout']['yaxis']['title']['text']
    self.assertEqual(y_title, 'R Value')
def test_add_interval_rt(self):
    """Check the interval traces and the input validation paths.

    Runs an inference on all-zero incidences, plots the resulting
    intervals and verifies the mean trace (index 0) and the credible
    interval band trace (index 1). Then checks the ``TypeError`` for a
    non-data-frame argument and the ``UserWarning`` emitted when a custom
    time or R key is used on a plot that already holds a ground-truth
    trace.
    """
    incidences = pd.DataFrame({
        'Time': [1, 2, 3, 5, 6],
        'Incidence Number': [0] * 5
    })
    serial_interval = [1, 2]

    inference = bp.BranchProPosterior(incidences, serial_interval, 1, 0.2)
    inference.run_inference(tau=2)
    intervals_df = inference.get_intervals(.95)

    my_plot = bp.ReproductionNumberPlot()
    my_plot.add_interval_rt(intervals_df)

    # Trace 0: mean of the posterior over the inferred time points
    expected_mean_x = np.array([np.array([4, 5, 6])])
    expected_mean_y = np.array([np.array([5.0] * 3)])
    npt.assert_array_equal(
        np.array([my_plot.figure['data'][0]['x']]), expected_mean_x)
    npt.assert_array_equal(
        np.array([my_plot.figure['data'][0]['y']]), expected_mean_y)

    # Trace 1: closed band, upper bounds forward then lower bounds backward
    expected_band_x = np.array([np.array([4, 5, 6, 6, 5, 4])])
    expected_band_y = np.array(
        [np.array([18.444397] * 3 + [0.126589] * 3)])
    npt.assert_array_equal(
        np.array([my_plot.figure['data'][1]['x']]), expected_band_x)
    npt.assert_array_almost_equal(
        np.array([my_plot.figure['data'][1]['y']]), expected_band_y)

    # Anything that is not a data frame must be rejected
    with self.assertRaises(TypeError):
        bp.ReproductionNumberPlot().add_interval_rt(0)

    ground_truth_columns = {
        'Time Points': [1, 2, 3, 4, 5, 6],
        'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
    }

    # Non-default time column name on a plot that already has a trace
    with self.assertWarns(UserWarning):
        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(pd.DataFrame(ground_truth_columns))

        dfs1 = pd.DataFrame({
            't': [4, 5, 6],
            'Mean': [5.0] * 3,
            'Lower bound CI': [5.0] * 3,
            'Upper bound CI': [5.0] * 3,
            'Central Probability': [.95] * 3
        })
        my_plot.add_interval_rt(dfs1, time_key='t')

    # Non-default R column name on a plot that already has a trace
    with self.assertWarns(UserWarning):
        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(pd.DataFrame(ground_truth_columns))

        dfs2 = pd.DataFrame({
            'Time Points': [4, 5, 6],
            'r': [5.0] * 3,
            'Lower bound CI': [5.0] * 3,
            'Upper bound CI': [5.0] * 3,
            'Central Probability': [.95] * 3
        })
        my_plot.add_interval_rt(dfs2, r_key='r')
def test__init__(self):
    """Smoke test: constructing the plot with no arguments must not raise."""
    plot_class = bp.ReproductionNumberPlot
    plot_class()
def __init__(self, long_callback_manager=None):
    """Create the Dash application and build its static page layout.

    The layout consists of the incidence data figure, two upload columns
    (incidence data and serial intervals), the reproduction number figure
    next to the sliders, and three hidden storage divs.

    Reads ``self.css``, ``self.mathjax_script``, ``self.mathjax_html``,
    ``self._inc_modal`` and ``self._si_modal``, and calls
    ``self.update_sliders()``; these are not defined here and are
    presumably provided by the class or its parent.

    Parameters
    ----------
    long_callback_manager
        Optional callback manager for long callbacks. See
        https://dash.plotly.com/long-callbacks
    """
    super(BranchProInferenceApp, self).__init__()

    self.app = dash.Dash(__name__,
                         external_stylesheets=self.css,
                         long_callback_manager=long_callback_manager)
    self.app.title = 'BranchproInf'

    # Storage populated later; empty until data has been provided
    self.session_data = {
        'data_storage': None,
        'interval_storage': None,
        'posterior_storage': None
    }

    # Shared look of the two drag-and-drop upload areas
    button_style = {
        'width': '100%',
        'height': '60px',
        'lineHeight': '60px',
        'borderWidth': '1px',
        'borderStyle': 'dashed',
        'borderRadius': '5px',
        'textAlign': 'center',
        'margin': '10px'
    }

    # NOTE: the order of the container's children matters; the two text
    # locations saved at the end of this method are looked up by index.
    # Style keys are camelCase throughout, as expected by Dash components.
    self.app.layout = html.Div([
        dbc.Container(
            [
                html.H1('Branching Processes', id='page-title'),
                html.Div([]),  # Empty div for top explanation texts
                html.H2('Incidence Data'),
                dbc.Row(
                    dbc.Col(
                        dcc.Graph(figure=bp.IncidenceNumberPlot().figure,
                                  id='data-fig'))),
                dbc.Row(
                    [
                        dbc.Col(children=[
                            html.H6([
                                'You can upload your own ',
                                html.Span(
                                    'incidence data',
                                    id='inc-tooltip',
                                    style={
                                        'textDecoration': 'underline',
                                        'cursor': 'pointer'
                                    },
                                ),
                                ' here. It will appear as bars.'
                            ]),
                            dbc.Modal(
                                self._inc_modal,
                                id='inc_modal',
                                size='xl',
                            ),
                            html.Div([
                                'Data must be in the following column '
                                'format: `Time`, `Incidence number`, '
                                '`Imported Cases` (optional), '
                                '`R_t` (true value of R, optional).'
                            ]),
                            dcc.Upload(
                                id='upload-data',
                                children=html.Div([
                                    'Drag and Drop or ',
                                    html.A('Select Files',
                                           style={
                                               'textDecoration':
                                                   'underline'
                                           }),
                                    # Plain literal: a backslash
                                    # continuation inside the string would
                                    # leak stray characters into the text
                                    ' to upload your Incidence '
                                    'Number data.'
                                ]),
                                style=button_style,
                                # Allow multiple files to be uploaded
                                multiple=True),
                            html.Div(id='incidence-data-upload')
                        ]),
                        dbc.Col(children=[
                            html.H6([
                                'You can upload your own ',
                                html.Span('serial interval',
                                          id='si-tooltip',
                                          style={
                                              'textDecoration':
                                                  'underline',
                                              'cursor': 'pointer'
                                          }),
                                ' here.'
                            ]),
                            dbc.Modal(
                                self._si_modal,
                                id='si_modal',
                                size='lg',
                            ),
                            html.Div([
                                'Data must contain one or more serial '
                                'intervals to be used for constructing'
                                ' the posterior distributions each '
                                'included as a column.'
                            ]),
                            dcc.Upload(
                                id='upload-interval',
                                children=html.Div([
                                    'Drag and Drop or ',
                                    html.A('Select Files',
                                           style={
                                               'textDecoration':
                                                   'underline'
                                           }),
                                    ' to upload your Serial '
                                    'Interval.'
                                ]),
                                style=button_style,
                                # Allow multiple files to be uploaded
                                multiple=True),
                            html.Div(id='ser-interval-upload')
                        ])
                    ],
                    align='center',
                ),
                html.H2('Plot of R values'),
                html.Progress(id='progress_bar'),
                html.Div(
                    id='first_run',  # see flip_first_run() in the app
                    children='True',
                    style={'display': 'none'}),
                dbc.Row(
                    [
                        dbc.Col(children=dcc.Graph(
                            figure=bp.ReproductionNumberPlot().figure,
                            id='posterior-fig',
                            style={'display': 'block'})),
                        dbc.Col(self.update_sliders(), id='all-sliders')
                    ],
                    align='center',
                ),
                html.Div([]),  # Empty div for bottom text
                html.Div(id='data_storage', style={'display': 'none'}),
                html.Div(id='interval_storage',
                         style={'display': 'none'}),
                html.Div(id='posterior_storage',
                         style={'display': 'none'})
            ],
            fluid=True),
        self.mathjax_script
    ])

    # Set the app index string for mathjax
    self.app.index_string = self.mathjax_html

    # Save the locations of texts from the layout: index 1 is the empty
    # top div, index -4 is the empty bottom div that precedes the three
    # hidden storage divs
    container = self.app.layout.children[0]
    self.main_text = container.children[1].children
    self.collapsed_text = container.children[-4].children