def plot_transition_matrix(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    interactive=True,
    **kwargs,
) -> matplotlib.image.AxesImage:
    """Plot transition matrix.

    Args:
        data (trajectory or square transition matrix): a DataFrame, from
            which the matrix is computed with :func:`traja.transitions`, or
            an already computed square 2-D ``np.ndarray``.
        interactive (bool): show plot
        kwargs: forwarded to :func:`traja.transitions`; only used when
            ``data`` is a DataFrame.

    Returns:
        axesimage (matplotlib.image.AxesImage)

    Raises:
        ValueError: if ``data`` is an ndarray that is not a square 2-D matrix.
        TypeError: if ``data`` is neither an ndarray nor a DataFrame.

    """
    if isinstance(data, np.ndarray):
        # Check ndim first so 0-D/1-D input gets the same clear error instead
        # of an IndexError from shape[1].
        if data.ndim != 2 or data.shape[0] != data.shape[1]:
            raise ValueError(
                f"Ndarray input must be square transition matrix, shape is {data.shape}"
            )
        transition_matrix = data
    elif isinstance(data, (pd.DataFrame, traja.TrajaDataFrame)):
        transition_matrix = traja.transitions(data, **kwargs)
    else:
        # Fail explicitly rather than with an UnboundLocalError further down.
        raise TypeError(
            "data must be a DataFrame or a square ndarray, "
            f"got {type(data).__name__}"
        )

    img = plt.imshow(transition_matrix)
    if interactive:
        plt.show()
    return img
def plot_transition_graph(
    data: Union[pd.DataFrame, traja.TrajaDataFrame, np.ndarray],
    outpath="markov.dot",
    interactive=True,
):
    """Plot transition graph with networkx.

    Args:
        data (trajectory or transition_matrix): a trajectory (a
            ``TrajaDataFrame``, or a ``DataFrame`` with an ``"x"`` column),
            from which the matrix is computed with :func:`traja.transitions`;
            any other input is used directly as the transition matrix.
        outpath (str): path the graph is written to in dot format.
        interactive (bool): if True, open the written dot file with graphviz.

    Raises:
        ImportError: if networkx, pydot or graphviz cannot be imported.

    .. note::
        Modified from http://www.blackarbs.com/blog/introduction-hidden-markov-models-python-networkx-sklearn/2/9/2017

    """
    try:
        import networkx as nx

        # pydot and graphviz are imported only to fail early if missing.
        import pydot  # noqa: F401
        import graphviz  # noqa: F401
    except ImportError as e:
        raise ImportError(f"{e} - please install it with pip") from e

    is_trajectory = isinstance(data, traja.TrajaDataFrame) or (
        isinstance(data, pd.DataFrame) and "x" in data
    )
    if is_trajectory:
        transition_matrix = traja.transitions(data)
    else:
        # Input is already a transition matrix (see docstring); previously
        # this path left `transition_matrix` unbound.
        transition_matrix = data

    edges_wts = _get_markov_edges(pd.DataFrame(transition_matrix))
    states_ = list(range(transition_matrix.shape[0]))

    # create graph object
    G = nx.MultiDiGraph()

    # nodes correspond to states
    G.add_nodes_from(states_)

    # edges represent transition probabilities
    for (tmp_origin, tmp_destination, *_), v in edges_wts.items():
        weight = v.round(4)
        G.add_edge(tmp_origin, tmp_destination, weight=weight, label=weight)

    pos = nx.drawing.nx_pydot.graphviz_layout(G, prog="dot")
    nx.draw_networkx(G, pos)

    # create edge labels for jupyter plot but is not necessary
    edge_labels = {(n1, n2): d["label"] for n1, n2, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

    if os.path.exists(outpath):
        logging.info(f"Overwriting {outpath}")

    nx.drawing.nx_pydot.write_dot(G, outpath)

    if interactive:
        # Plot
        from graphviz import Source

        s = Source.from_file(outpath)
        s.view()
def test_transitions():
    """Check that ``traja.transitions`` returns an ndarray.

    The call is exercised twice on a copy of the module-level ``df``: once
    as-is, and once after ``xbin``/``ybin`` columns have been set on it.
    """
    trj = df.copy()

    result = traja.transitions(trj)
    assert isinstance(result, np.ndarray)

    # Check when bins set
    n_bins = traja._bins_to_tuple(trj, bins=None)
    x_lo, x_hi = trj.x.min(), trj.x.max()
    y_lo, y_hi = trj.y.min(), trj.y.max()
    x_edges = np.linspace(x_lo, x_hi, n_bins[0])
    y_edges = np.linspace(y_lo, y_hi, n_bins[1])
    x_idx = np.digitize(trj.x, x_edges)
    y_idx = np.digitize(trj.y, y_edges)
    for column, values in (("xbin", x_idx), ("ybin", y_idx)):
        trj.set(column, values)

    result = traja.transitions(trj)
    assert isinstance(result, np.ndarray)
def transitions(self, *args, **kwargs):
    """Calculate transition matrix.

    Delegates to :func:`traja.transitions`, passing the wrapped object
    ``self._obj`` first, followed by any positional and keyword arguments.
    """
    wrapped = self._obj
    matrix = traja.transitions(wrapped, *args, **kwargs)
    return matrix