def visualize(self, output_file=None):
    """Render the story steps as a networkx ``MultiDiGraph``.

    Each story step becomes a light-blue rectangular node. Every checkpoint
    referenced by a step becomes a node too. Edges run from each start
    checkpoint to the step, and from the step to each end checkpoint. A step
    with no end checkpoints is linked to the special END node. Checkpoints
    whose name starts with ``GENERATED_CHECKPOINT_PREFIX`` are filled with a
    colour derived from the last ``GENERATED_HASH_LENGTH`` characters of their
    name.

    Args:
        output_file: optional destination; when truthy the graph is also
            written out via ``visualization.persist_graph``.

    Returns:
        The constructed ``networkx.MultiDiGraph``.
    """
    import networkx as nx
    from rasa.core.training import visualization
    from colorhash import ColorHash

    graph = nx.MultiDiGraph()
    # Node ids: START is 0, END is -1; steps and checkpoints get 1, 2, 3, ...
    node_ids = {"STORY_START": 0, "STORY_END": -1}
    last_id = 0

    def allocate_id():
        # Hand out the next free positive node id.
        nonlocal last_id
        last_id += 1
        return last_id

    def draw_checkpoint(checkpoint):
        # Add a node for ``checkpoint`` the first time its name is seen.
        name = checkpoint.name
        if name in node_ids:
            return
        node_id = allocate_id()
        node_ids[name] = node_id
        if not name.startswith(GENERATED_CHECKPOINT_PREFIX):
            graph.add_node(node_id, label=utils.cap_length(name))
            return
        # Generated checkpoints are coloured by the hash suffix of their name
        # so that related checkpoints are visually distinguishable.
        fill = ColorHash(name[-GENERATED_HASH_LENGTH:]).hex
        graph.add_node(
            node_id,
            label=utils.cap_length(name),
            style="filled",
            fillcolor=fill,
        )

    graph.add_node(
        node_ids["STORY_START"], label="START", fillcolor="green", style="filled"
    )
    graph.add_node(
        node_ids["STORY_END"], label="END", fillcolor="red", style="filled"
    )

    for step in self.story_steps:
        step_id = allocate_id()
        graph.add_node(
            step_id,
            label=utils.cap_length(step.block_name),
            style="filled",
            fillcolor="lightblue",
            shape="rect",
        )

        for checkpoint in step.start_checkpoints:
            draw_checkpoint(checkpoint)
            graph.add_edge(node_ids[checkpoint.name], step_id)

        for checkpoint in step.end_checkpoints:
            draw_checkpoint(checkpoint)
            graph.add_edge(step_id, node_ids[checkpoint.name])

        # A step that leads nowhere terminates the story.
        if not step.end_checkpoints:
            graph.add_edge(step_id, node_ids["STORY_END"])

    if output_file:
        visualization.persist_graph(graph, output_file)
    return graph
if __name__ == "__main__":
    # Command-line entry point: read a stories-with-edges file, build the
    # graph from it, pickle the resulting graph list to ``--config-graph``
    # and render the graphviz graph to ``--output-html``.
    import argparse
    import os
    import pickle

    from rasa.core.training.visualization import persist_graph

    parser = argparse.ArgumentParser()
    parser.add_argument("infile", type=str, help="Input stories with edges")
    parser.add_argument(
        "--config-graph",
        type=str,
        help="Output pickle graph config file",
        default=PATH_CONFIG_GRAPH,
    )
    parser.add_argument(
        "--output-html",
        type=str,
        help="Output html graph",
        default=PATH_OUTPUT_GRAPH_HTML,
    )
    parser.add_argument(
        "--unmerge-nodes",
        type=str,
        help="Unmerge action node in graph, separate by comma",
        default='action_sorry',
    )
    args = parser.parse_args()

    # Validate with parser.error rather than ``assert``: asserts are removed
    # under ``python -O``, which would let a missing file fail later and
    # more obscurely.
    if not os.path.exists(args.infile):
        parser.error("story-infile-edges {} not found!".format(args.infile))

    # Tolerate spaces around commas and drop empty entries (e.g. from a
    # trailing comma or an empty ``--unmerge-nodes ""``).
    unmerge_nodes = [
        name.strip() for name in args.unmerge_nodes.split(",") if name.strip()
    ]

    print_log("Reading edges from {} - \n".format(args.infile))
    edges_init = convert_edges_to_stories_with_checkpoint(args.infile)

    print_log("Building graph edges - ")
    graph_list, edges, nodes, graphviz = build_graph_from_edges(
        edges_init, unmerge_nodes=unmerge_nodes
    )
    print_log("{} edges - {} nodes\n".format(len(edges), len(nodes)))

    # Use a context manager so the pickle file is flushed and closed
    # deterministically instead of leaking the handle.
    with open(args.config_graph, "wb") as config_file:
        pickle.dump(graph_list, config_file)
    print_log("Dump graph pickle in {}\n".format(args.config_graph))

    persist_graph(graphviz, args.output_html)
    print_log("Print graph to {}\n".format(args.output_html))