Beispiel #1
0
    def test_add_ground_truth_rt(self):
        """Check plotting, validation and key warnings of ground truth R_t.

        The added trace must reproduce the dataframe's time points (x) and
        R values (y); a non-dataframe argument must raise ``TypeError``; and
        adding data with non-default ``time_key`` / ``r_key`` column names to
        a plot that already holds a ground truth trace must emit a
        ``UserWarning``.
        """
        times = [1, 2, 3, 4, 5, 6]
        r_values = [3, 3, 0.5, 0.5, 0.5, 0.5]

        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(
            pd.DataFrame({'Time Points': times, 'R_t': r_values}))

        trace = my_plot.figure['data'][0]
        npt.assert_array_equal(trace['x'], times)
        npt.assert_array_equal(trace['y'], r_values)

        with self.assertRaises(TypeError):
            bp.ReproductionNumberPlot().add_ground_truth_rt(0)

        # Renamed time column, passed through ``time_key``.
        renamed_time = pd.DataFrame({'t': times, 'R_t': r_values})
        with self.assertWarns(UserWarning):
            my_plot.add_ground_truth_rt(renamed_time, time_key='t')

        # Renamed R column, passed through ``r_key``.
        renamed_r = pd.DataFrame({'Time Points': times, 'r': r_values})
        with self.assertWarns(UserWarning):
            my_plot.add_ground_truth_rt(renamed_r, r_key='r')
Beispiel #2
0
    def update_inference_figure(self, source=None):
        """Update the inference figure based on currently stored information.

        Builds a new ``bp.ReproductionNumberPlot`` from the stored posterior
        (passed to ``add_interval_rt``). If the stored incidence data
        contains an ``R_t`` column, it is overlaid as the ground truth.

        Parameters
        ----------
        source : str
            Dash callback source. Not used by this method.

        Returns
        -------
        plotly.Figure
            Figure with updated posterior distribution

        Raises
        ------
        dash.exceptions.PreventUpdate
            If either the incidence data or the posterior has not been
            stored yet, so Dash leaves the current figure unchanged.
        """
        data = self.session_data.get('data_storage')
        posterior = self.session_data.get('posterior_storage')

        if data is None or posterior is None:
            raise dash.exceptions.PreventUpdate()

        # Only the name of the first (time) column is needed here; the
        # incidence column is not used by this figure.
        time_label = data.columns[0]

        plot = bp.ReproductionNumberPlot()
        plot.add_interval_rt(posterior)

        if 'R_t' in data.columns:
            plot.add_ground_truth_rt(data[[time_label, 'R_t']],
                                     time_key=time_label,
                                     r_key='R_t')

        # Keeps traces visibility states fixed when changing sliders
        plot.figure['layout']['legend']['uirevision'] = True

        return plot.figure
Beispiel #3
0
    def test_show_figure(self):
        """show_figure() must call plotly's ``Figure.show`` exactly once.

        Only the ``show_figure()`` call runs under the patch, so any
        ``show`` call made during setup cannot affect the count.
        """
        df = pd.DataFrame({
            'Time Points': [1, 2, 3, 4, 5, 6],
            'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
        })
        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(df)

        with patch('plotly.graph_objs.Figure.show') as show_patch:
            my_plot.show_figure()

        # Assert show_figure is called once (``.called`` would also accept
        # repeated calls, so check the exact count).
        show_patch.assert_called_once()
Beispiel #4
0
    def test_update_labels(self):
        """Axis titles must follow the labels passed to update_labels().

        Each label is updated separately, and the matching axis title is
        checked right after the update.
        """
        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_ground_truth_rt(pd.DataFrame({
            'Time Points': [1, 2, 3, 4, 5, 6],
            'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
        }))

        # x-axis: time label
        time_title = 'Time'
        my_plot.update_labels(time_label=time_title)
        self.assertEqual(
            my_plot.figure['layout']['xaxis']['title']['text'], time_title)

        # y-axis: reproduction number label
        r_title = 'R Value'
        my_plot.update_labels(r_label=r_title)
        self.assertEqual(
            my_plot.figure['layout']['yaxis']['title']['text'], r_title)
Beispiel #5
0
    def test_add_interval_rt(self):
        """Check plotting, validation and key warnings of R_t intervals.

        After inference on all-zero incidence, the first trace must hold the
        central line (x = 4..6, y = 5.0). The second trace must hold the band
        outline: x runs forwards and then backwards, and y holds the upper
        values followed by the lower values. A non-dataframe argument must
        raise ``TypeError``. Interval frames with a renamed time column
        (``time_key``) or R column (``r_key``) must emit a ``UserWarning``
        when added to a plot that already holds a default-keyed ground truth
        trace.
        """
        incidence = pd.DataFrame({
            'Time': [1, 2, 3, 5, 6],
            'Incidence Number': [0, 0, 0, 0, 0]
        })
        inference = bp.BranchProPosterior(incidence, [1, 2], 1, 0.2)
        inference.run_inference(tau=2)

        my_plot = bp.ReproductionNumberPlot()
        my_plot.add_interval_rt(inference.get_intervals(.95))

        traces = my_plot.figure['data']
        npt.assert_array_equal(traces[0]['x'], [4, 5, 6])
        npt.assert_array_equal(traces[0]['y'], [5.0] * 3)
        npt.assert_array_equal(traces[1]['x'], [4, 5, 6, 6, 5, 4])
        npt.assert_array_almost_equal(
            traces[1]['y'], [18.444397] * 3 + [0.126589] * 3)

        with self.assertRaises(TypeError):
            bp.ReproductionNumberPlot().add_interval_rt(0)

        def plot_with_ground_truth():
            # Fresh plot holding a ground truth trace with default keys.
            plot = bp.ReproductionNumberPlot()
            plot.add_ground_truth_rt(pd.DataFrame({
                'Time Points': [1, 2, 3, 4, 5, 6],
                'R_t': [3, 3, 0.5, 0.5, 0.5, 0.5]
            }))
            return plot

        def interval_frame(time_key, r_key):
            # Constant interval table with the given time / R column names.
            return pd.DataFrame({
                time_key: [4, 5, 6],
                r_key: [5.0] * 3,
                'Lower bound CI': [5.0] * 3,
                'Upper bound CI': [5.0] * 3,
                'Central Probability': [.95] * 3
            })

        with self.assertWarns(UserWarning):
            my_plot = plot_with_ground_truth()
            my_plot.add_interval_rt(interval_frame('t', 'Mean'),
                                    time_key='t')

        with self.assertWarns(UserWarning):
            my_plot = plot_with_ground_truth()
            my_plot.add_interval_rt(interval_frame('Time Points', 'r'),
                                    r_key='r')
Beispiel #6
0
 def test__init__(self):
     """A ReproductionNumberPlot can be constructed with default arguments."""
     # Construction must not raise; the instance check is always satisfied.
     self.assertIsInstance(bp.ReproductionNumberPlot(),
                           bp.ReproductionNumberPlot)
Beispiel #7
0
    def __init__(self, long_callback_manager=None):
        """Create the Dash app and build the inference page layout.

        Parameters
        ----------
        long_callback_manager
            Optional callback manager for long callbacks.
            See https://dash.plotly.com/long-callbacks
        """
        super().__init__()

        self.app = dash.Dash(__name__,
                             external_stylesheets=self.css,
                             long_callback_manager=long_callback_manager)
        self.app.title = 'BranchproInf'

        # Storage slots for the incidence data, serial intervals and
        # posterior; all start empty (None).
        self.session_data = {
            'data_storage': None,
            'interval_storage': None,
            'posterior_storage': None
        }

        button_style = {
            'width': '100%',
            'height': '60px',
            'lineHeight': '60px',
            'borderWidth': '1px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'margin': '10px'
        }
        # Dash style dicts take camelCased CSS property names.
        # Underlined, pointer-cursor words with ids ending in "-tooltip".
        tooltip_style = {'textDecoration': 'underline', 'cursor': 'pointer'}
        # Underlined "Select Files" link inside the upload boxes.
        link_style = {'textDecoration': 'underline'}

        self.app.layout = html.Div([
            dbc.Container(
                [
                    html.H1('Branching Processes', id='page-title'),
                    html.Div([]),  # Empty div for top explanation texts
                    html.H2('Incidence Data'),
                    dbc.Row(
                        dbc.Col(
                            dcc.Graph(figure=bp.IncidenceNumberPlot().figure,
                                      id='data-fig'))),
                    dbc.Row(
                        [
                            dbc.Col(children=[
                                html.H6([
                                    'You can upload your own ',
                                    html.Span('incidence data',
                                              id='inc-tooltip',
                                              style=tooltip_style),
                                    ' here. It will appear as bars.'
                                ]),
                                dbc.Modal(
                                    self._inc_modal,
                                    id='inc_modal',
                                    size='xl',
                                ),
                                html.Div([
                                    'Data must be in the following column '
                                    'format: `Time`, `Incidence number`, '
                                    '`Imported Cases` (optional), '
                                    '`R_t` (true value of R, optional).'
                                ]),
                                dcc.Upload(
                                    id='upload-data',
                                    children=html.Div([
                                        'Drag and Drop or ',
                                        html.A('Select Files',
                                               style=link_style),
                                        ' to upload your Incidence '
                                        'Number data.'
                                    ]),
                                    style=button_style,
                                    # Allow multiple files to be uploaded
                                    multiple=True),
                                html.Div(id='incidence-data-upload')
                            ]),
                            dbc.Col(children=[
                                html.H6([
                                    'You can upload your own ',
                                    html.Span('serial interval',
                                              id='si-tooltip',
                                              style=tooltip_style),
                                    ' here.'
                                ]),
                                dbc.Modal(
                                    self._si_modal,
                                    id='si_modal',
                                    size='lg',
                                ),
                                html.Div([
                                    'Data must contain one or more serial '
                                    'intervals to be used for constructing'
                                    ' the posterior distributions each '
                                    'included as a column.'
                                ]),
                                dcc.Upload(
                                    id='upload-interval',
                                    children=html.Div([
                                        'Drag and Drop or ',
                                        html.A('Select Files',
                                               style=link_style),
                                        ' to upload your Serial '
                                        'Interval.'
                                    ]),
                                    style=button_style,
                                    # Allow multiple files to be uploaded
                                    multiple=True),
                                html.Div(id='ser-interval-upload')
                            ])
                        ],
                        align='center',
                    ),
                    html.H2('Plot of R values'),
                    html.Progress(id='progress_bar'),
                    html.Div(
                        id='first_run',  # see flip_first_run() in the app
                        children='True',
                        style={'display': 'none'}),
                    dbc.Row(
                        [
                            dbc.Col(children=dcc.Graph(
                                figure=bp.ReproductionNumberPlot().figure,
                                id='posterior-fig',
                                style={'display': 'block'})),
                            dbc.Col(self.update_sliders(), id='all-sliders')
                        ],
                        align='center',
                    ),
                    html.Div([]),  # Empty div for bottom text
                    html.Div(id='data_storage', style={'display': 'none'}),
                    html.Div(id='interval_storage', style={'display': 'none'}),
                    html.Div(id='posterior_storage', style={'display': 'none'})
                ],
                fluid=True),
            self.mathjax_script
        ])

        # Set the app index string for mathjax
        self.app.index_string = self.mathjax_html

        # Save the locations of texts from the layout. These positional
        # lookups depend on the container's child order above: index 1 is
        # the empty top-explanation div, and index -4 is the empty
        # bottom-text div (followed by the three hidden storage divs).
        self.main_text = self.app.layout.children[0].children[1].children
        self.collapsed_text = self.app.layout.children[0].children[-4].children