Beispiel #1
0
def plot_transition_matrix(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    interactive=True,
    **kwargs,
) -> matplotlib.image.AxesImage:
    """Plot transition matrix.

    Args:
        data (trajectory or square transition matrix): either a trajectory
            (:class:`pandas.DataFrame` / :class:`traja.TrajaDataFrame`), whose
            transition matrix is computed with :func:`traja.transitions`, or a
            precomputed square 2-D :class:`numpy.ndarray`.
        interactive (bool): show plot
        kwargs: passed to :func:`traja.transitions` (only used for trajectory
            input; presumably forwarded on to :func:`traja.grid_coordinates`).

    Returns:
        axesimage (matplotlib.image.AxesImage)

    Raises:
        ValueError: if ``data`` is an ndarray that is not a square 2-D matrix.
        TypeError: if ``data`` is neither an ndarray nor a DataFrame.

    """
    if isinstance(data, np.ndarray):
        # Check rank first: a 1-D array has no shape[1] and would otherwise
        # surface as an IndexError instead of a clear ValueError.
        if data.ndim != 2 or data.shape[0] != data.shape[1]:
            raise ValueError(
                f"Ndarray input must be square transition matrix, shape is {data.shape}"
            )
        transition_matrix = data
    elif isinstance(data, (pd.DataFrame, traja.TrajaDataFrame)):
        transition_matrix = traja.transitions(data, **kwargs)
    else:
        # Without this branch transition_matrix would be unbound below.
        raise TypeError(
            "data must be a trajectory DataFrame or a square ndarray, "
            f"got {type(data).__name__}"
        )
    img = plt.imshow(transition_matrix)
    if interactive:
        plt.show()
    return img
Beispiel #2
0
def plot_transition_graph(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    outpath="markov.dot",
    interactive=True,
):
    """Plot transition graph with networkx.

    Args:
        data (trajectory or transition_matrix): either a trajectory (a
            :class:`traja.TrajaDataFrame`, or a :class:`pandas.DataFrame` with
            an ``"x"`` column), whose transition matrix is computed with
            :func:`traja.transitions`; or a precomputed square transition
            matrix given as a 2-D :class:`numpy.ndarray` or
            :class:`pandas.DataFrame`.
        outpath (str): path of the DOT file the graph is written to; an
            existing file is overwritten (an info message is logged).
        interactive (bool): if True, load ``outpath`` with graphviz and open
            the rendered view.

    Raises:
        ImportError: if networkx, pydot or graphviz is not installed.
        ValueError: if a matrix input is not square and 2-D.
        TypeError: if ``data`` is not a DataFrame or ndarray.

    .. note::
        Modified from http://www.blackarbs.com/blog/introduction-hidden-markov-models-python-networkx-sklearn/2/9/2017

    """
    # Local stdlib imports so this function does not depend on module-level ones.
    import logging
    import os

    # Check the optional plotting dependencies up front so a missing one
    # fails early with an actionable message.
    try:
        import networkx as nx
        import pydot  # noqa: F401
        import graphviz  # noqa: F401
    except ImportError as e:
        raise ImportError(f"{e} - please install it with pip") from e

    if isinstance(data, traja.TrajaDataFrame) or (
        isinstance(data, pd.DataFrame) and "x" in data
    ):
        # Trajectory input: derive the transition matrix from positions.
        transition_matrix = traja.transitions(data)
    elif isinstance(data, (np.ndarray, pd.DataFrame)):
        # Precomputed transition matrix; states are labelled 0..n-1 below.
        transition_matrix = np.asarray(data)
        if (
            transition_matrix.ndim != 2
            or transition_matrix.shape[0] != transition_matrix.shape[1]
        ):
            raise ValueError(
                "Transition matrix must be square, "
                f"shape is {transition_matrix.shape}"
            )
    else:
        raise TypeError(
            "data must be a trajectory DataFrame or a square transition matrix, "
            f"got {type(data).__name__}"
        )

    edges_wts = _get_markov_edges(pd.DataFrame(transition_matrix))
    states_ = list(range(transition_matrix.shape[0]))

    # create graph object
    G = nx.MultiDiGraph()

    # nodes correspond to states
    G.add_nodes_from(states_)

    # edges represent transition probabilities
    for (origin, destination), weight in edges_wts.items():
        rounded = weight.round(4)
        G.add_edge(origin, destination, weight=rounded, label=rounded)

    pos = nx.drawing.nx_pydot.graphviz_layout(G, prog="dot")
    nx.draw_networkx(G, pos)

    # create edge labels for jupyter plot but is not necessary
    edge_labels = {(n1, n2): d["label"] for n1, n2, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    if os.path.exists(outpath):
        logging.info(f"Overwriting {outpath}")
    nx.drawing.nx_pydot.write_dot(G, outpath)

    if interactive:
        # Plot
        from graphviz import Source

        s = Source.from_file(outpath)
        s.view()
Beispiel #3
0
def test_transitions():
    """``traja.transitions`` yields an ndarray both before and after bins are set."""
    trj = df.copy()
    assert isinstance(traja.transitions(trj), np.ndarray)

    # Check when bins set: digitize x/y against evenly spaced edges spanning
    # each axis, using traja's default bin counts for the number of edges.
    n_bins = traja._bins_to_tuple(trj, bins=None)
    x_edges = np.linspace(trj.x.min(), trj.x.max(), n_bins[0])
    y_edges = np.linspace(trj.y.min(), trj.y.max(), n_bins[1])
    binned = {
        "xbin": np.digitize(trj.x, x_edges),
        "ybin": np.digitize(trj.y, y_edges),
    }
    # NOTE(review): confirm TrajaDataFrame.set adds these as *columns*; if it
    # only stores metadata, traja.transitions may not see the precomputed bins.
    for name, indices in binned.items():
        trj.set(name, indices)

    assert isinstance(traja.transitions(trj), np.ndarray)
Beispiel #4
0
 def transitions(self, *args, **kwargs):
     """Calculate transition matrix.

     Delegates to :func:`traja.transitions` on the wrapped object
     (``self._obj``), forwarding all positional and keyword arguments.
     """
     wrapped = self._obj
     return traja.transitions(wrapped, *args, **kwargs)