Esempio n. 1
0
def plot_flow_params(plotter: StreamlitPlotter, app: AppRegion):
    """
    Plot the weight of a selected flow over time and list the weights at a chosen time.

    The flow name is picked from a sidebar select box. Every flow in the model
    with that name is plotted via `plot_time_varying_input`, then the weight of
    each of those flows at the slider-selected time is written out as a dict
    keyed by a "from <source> to <dest>" description.

    Args:
        plotter: Plotter passed through to `plot_time_varying_input`.
        app: Application region used to build the model from its default params.
    """
    # Assume a COVID model
    model = app.build_model(app.params["default"])
    flow_names = sorted({f.name for f in model._flows})

    flow_name = st.sidebar.selectbox("Select flow", flow_names)
    flows = [f for f in model._flows if f.name == flow_name]
    is_logscale = st.sidebar.checkbox("Log scale")
    flow_funcs = [f.get_weight_value for f in flows]
    plots.model.plots.plot_time_varying_input(
        plotter, f"flow-weights-{flow_name}", flow_funcs, model.times, is_logscale
    )

    t = st.slider("Time", min_value=0, max_value=int(model.times[-1]))
    weights_at_t = {}
    for f in flows:
        # Not every flow has both ends, so either attribute may be missing.
        src = getattr(f, "source", None)
        dest = getattr(f, "dest", None)
        parts = []
        if src:
            parts.append(f"from {src}")
        if dest:
            parts.append(f"to {dest}")

        label = " ".join(parts).strip()
        weights_at_t[label] = f.get_weight_value(t)

    # The values are evaluated at the slider time, not necessarily the start time.
    st.write(f"Values at time {t}:")
    st.write(weights_at_t)
Esempio n. 2
0
def plot_dynamic_mixing_matrix(plotter: StreamlitPlotter, app: AppRegion):
    """
    Show the model's mixing matrix at a slider-selected time.

    The matrix is drawn as an image with the "hot" colour map, saved through
    the plotter, and also written out as raw values.
    """
    model = app.build_model(app.params["default"])
    final_time = int(model.times[-1])
    selected_time = st.slider("Time", min_value=0, max_value=final_time)
    matrix = model._get_mixing_matrix(selected_time)

    # get_figure returns six values; only the figure itself is needed here.
    fig, _, _, _, _, _ = plotter.get_figure()
    image_options = dict(cmap="hot", interpolation="none", extent=[0, 80, 80, 0])
    pyplot.imshow(matrix, **image_options)
    plotter.save_figure(fig, filename="mixing-matrix", title_text="Mixing matrix")
    st.write(matrix)
Esempio n. 3
0
def plot_dynamic_inputs(plotter: StreamlitPlotter, app: AppRegion):
    """
    Plot one of the model's time-variant functions, chosen from the sidebar.

    The sidebar offers the time-variant keys in sorted order plus a log-scale
    toggle; the selected function is plotted over the model's times.
    """
    # Assume a COVID model
    model = app.build_model(app.params["default"])
    time_variants = model.time_variants
    selected_key = st.sidebar.selectbox("Select function", sorted(time_variants.keys()))
    use_logscale = st.sidebar.checkbox("Log scale")
    plots.model.plots.plot_time_varying_input(
        plotter, selected_key, time_variants[selected_key], model.times, use_logscale
    )
Esempio n. 4
0
def plot_flow_graph(plotter: StreamlitPlotter, app: AppRegion):
    """
    Plot a graph of the model's compartments and flows
    See NetworkX documentation: https://networkx.org/documentation/stable/index.html

    Streamlit widgets choose which flow types, compartments and strata are
    included, which NetworkX layout positions the nodes, and whether node
    labels are shown. The figure is rendered with `st.pyplot`.

    Args:
        plotter: Not used by this function; the figure is built and rendered directly.
        app: Application region used to build the model from its default params.
    """
    model = app.build_model(app.params["default"])

    flow_types = st.multiselect(
        "Flow types", ["Transition", "Entry", "Exit"], default=["Transition"]
    )
    layout_lookup = {
        "Spring": nx.spring_layout,
        "Spectral": nx.spectral_layout,
        "Kamada Kawai": nx.kamada_kawai_layout,
        "Random": nx.random_layout,
    }
    layout_key = st.selectbox("Layout", list(layout_lookup.keys()))
    layout_func = layout_lookup[layout_key]

    is_node_labels_visible = st.checkbox("Show node labels")
    include_connected_nodes = st.checkbox("Include connected nodes")

    # Compartment selector: "All" stands for every original compartment name.
    original_compartment_names = model._original_compartment_names
    stratifications = model._stratifications
    compartment_names = model.compartments

    orig_comps = ["All"] + original_compartment_names
    # Pass the default as a list, like the other multiselects in this function.
    chosen_comp_names = st.multiselect("Compartments", orig_comps, default=["All"])
    if "All" in chosen_comp_names:
        chosen_comp_names = original_compartment_names

    # Strata selectors: a stratification only filters when "All" is deselected.
    chosen_strata = {}
    for strat in stratifications:
        options = ["All"] + strat.strata
        choices = st.multiselect(strat.name, options, default=["All"])
        if "All" not in choices:
            chosen_strata[strat.name] = choices

    # Collect the nodes to graph. "ENTRY" / "EXIT" are pseudo-nodes standing in
    # for the missing end of entry and exit flows.
    comps_to_graph = []
    if "Entry" in flow_types:
        comps_to_graph.append("ENTRY")
    if "Exit" in flow_types:
        comps_to_graph.append("EXIT")

    for comp in compartment_names:
        if comp.name not in chosen_comp_names:
            continue

        # A compartment is excluded when it carries a filtered stratification
        # but none of the strata chosen for it.
        is_excluded = any(
            strat_name in comp._strat_names
            and not any(comp.has_stratum(strat_name, s) for s in strata)
            for strat_name, strata in chosen_strata.items()
        )
        if not is_excluded:
            comps_to_graph.append(comp)

    if not comps_to_graph:
        st.write("Nothing to plot")
        return

    # Build the graph.
    graph = nx.DiGraph()
    graph.add_nodes_from(comps_to_graph)

    flow_lookup = {}
    for flow in model._flows:
        if "Entry" in flow_types and is_flow_type(flow, BaseEntryFlow):
            edge = ("ENTRY", flow.dest)
        elif "Exit" in flow_types and is_flow_type(flow, BaseExitFlow):
            edge = (flow.source, "EXIT")
        elif "Transition" in flow_types and is_flow_type(flow, BaseTransitionFlow):
            edge = (flow.source, flow.dest)
        else:
            continue

        if include_connected_nodes:
            # One selected real endpoint is enough; the pseudo-nodes do not count.
            src_is_valid = edge[0] in comps_to_graph and edge[0] != "ENTRY"
            dst_is_valid = edge[1] in comps_to_graph and edge[1] != "EXIT"
            should_add_edge = src_is_valid or dst_is_valid
        else:
            src_is_valid = edge[0] in comps_to_graph
            dst_is_valid = edge[1] in comps_to_graph
            should_add_edge = src_is_valid and dst_is_valid

        if should_add_edge:
            # add_edge also adds any endpoint that is not yet a node.
            graph.add_edge(*edge)
            flow_lookup[edge_to_str(edge)] = flow

    # Draw the graph.
    pyplot.style.use("ggplot")
    fig = pyplot.figure(figsize=[14 / 1.5, 9 / 1.5], dpi=300)
    axis = fig.add_axes([0, 0, 1, 1])

    # Specify a layout technique.
    pos = layout_func(graph)

    # Draw the nodes. All nodes share one style, so a single call suffices.
    node_size = 12
    nx.draw_networkx_nodes(
        graph,
        pos,
        ax=axis,
        node_size=node_size,
        node_color="#333",
        alpha=1,
    )

    if is_node_labels_visible:
        labels = {node: get_label(node) for node in graph.nodes}
        nx.draw_networkx_labels(
            graph,
            pos,
            labels,
            font_size=8,
            verticalalignment="top",
            alpha=0.8,
            ax=axis,
        )

    # Draw the edges between nodes, one call per edge so each can be styled
    # according to its flow type.
    for edge in graph.edges:
        flow = flow_lookup[edge_to_str(edge)]
        color = "#666"
        style = "solid"
        alpha = 0.8

        # No break: when several styles match, the last one in FLOW_STYLES wins.
        for flow_style in FLOW_STYLES:
            if is_flow_type(flow, flow_style["class"]):
                color = flow_style["color"]
                style = flow_style["style"]
                alpha = flow_style["alpha"]

        nx.draw_networkx_edges(
            graph,
            pos,
            edgelist=[edge],
            width=1,
            node_size=node_size,
            alpha=alpha,
            edge_color=color,
            style=style,
            arrowsize=6,
            ax=axis,
        )

    st.pyplot(fig)
    # pyplot keeps every figure it creates until closed; release this one so
    # repeated calls do not accumulate figures.
    pyplot.close(fig)