コード例 #1
0
    def __init__(self,
                 story_graph: StoryGraph,
                 domain: Domain,
                 remove_duplicates: bool = True,
                 unique_last_num_states: Optional[int] = None,
                 augmentation_factor: int = 20,
                 tracker_limit: Optional[int] = None,
                 use_story_concatenation: bool = True,
                 debug_plots: bool = False):
        """Prepare a generator that assembles complete stories from story parts.

        Story parts may begin and end with checkpoints; the generator joins
        parts whose end checkpoints match other parts' start checkpoints so
        that whole stories emerge. Afterwards duplicate stories are dropped
        and, if augmentation is enabled, the data is augmented.

        Args:
            story_graph: graph of story parts; the copy returned by
                ``with_cycles_removed()`` is kept as ``self.story_graph``.
            domain: stored unchanged as ``self.domain``.
            remove_duplicates: forwarded to ``ExtractorConfig``.
            unique_last_num_states: forwarded to ``ExtractorConfig``.
            augmentation_factor: forwarded to ``ExtractorConfig``; ten times
                this value becomes ``max_number_of_trackers``.
            tracker_limit: forwarded to ``ExtractorConfig``.
            use_story_concatenation: forwarded to ``ExtractorConfig``.
            debug_plots: when true, the cycle-free graph is rendered to
                ``story_blocks_connections.html`` via ``visualize``.
        """
        acyclic_graph = story_graph.with_cycles_removed()
        self.story_graph = acyclic_graph

        if debug_plots:
            acyclic_graph.visualize('story_blocks_connections.html')

        self.domain = domain

        self.config = ExtractorConfig(
            remove_duplicates=remove_duplicates,
            unique_last_num_states=unique_last_num_states,
            augmentation_factor=augmentation_factor,
            # heuristic: ten augmentation rounds per unit of the factor
            max_number_of_trackers=10 * augmentation_factor,
            tracker_limit=tracker_limit,
            use_story_concatenation=use_story_concatenation,
            # fixed seed so repeated runs draw the same random sequence
            rand=random.Random(42))

        # hashes of the featurizations of every finished tracker
        self.hashed_featurizations = set()
コード例 #2
0
def visualize_stories(
        story_steps,  # type: List[StoryStep]
        domain,  # type: Domain
        output_file,  # type: Optional[Text]
        max_history,  # type: int
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        nlu_training_data=None,  # type: Optional[TrainingData]
        should_merge_nodes=True,  # type: bool
        fontsize=12,  # type: int
        silent=False  # type: bool
):
    """Generate a graph visualizing the flows in the given stories.

    Visualization is always a trade-off between making the graph as small as
    possible and making sure the meaning doesn't change too much. The
    algorithm compresses the graph generated from the stories by merging
    nodes that are similar. Hence, the algorithm might create paths through
    the graph that aren't actually specified in the stories, but we try to
    minimize that.

    The output file defines if and where a file containing the plotted graph
    should be stored.

    The history defines how much 'memory' the graph has. This influences in
    which situations the algorithm will merge nodes. Nodes will only be
    merged if they are equal within the history; this means the larger the
    history we take into account, the less likely it is that we merge any
    nodes.

    The training data parameter can be used to pass in a Rasa NLU training
    data instance. It will be used to replace the user messages from the
    story file with actual messages from the training data.

    Args:
        story_steps: story steps to visualize; wrapped in a ``StoryGraph``.
        domain: domain handed to the ``TrainingDataGenerator``.
        output_file: forwarded to ``visualize_neighborhood``.
        max_history: forwarded to ``visualize_neighborhood``.
        interpreter: forwarded to ``visualize_neighborhood``; when omitted
            (``None``) a new ``RegexInterpreter`` is created for this call.
        nlu_training_data: forwarded to ``visualize_neighborhood``.
        should_merge_nodes: forwarded to ``visualize_neighborhood``.
        fontsize: forwarded to ``visualize_neighborhood``.
        silent: forwarded to ``TrainingDataGenerator.generate``.

    Returns:
        The graph returned by ``visualize_neighborhood``.
    """
    if interpreter is None:
        # A ``RegexInterpreter()`` default in the signature would be built
        # once at definition time and shared by every call; build it per call.
        interpreter = RegexInterpreter()

    story_graph = StoryGraph(story_steps)

    # Generate trackers with story concatenation and augmentation disabled,
    # capped at 100 trackers.
    g = TrainingDataGenerator(story_graph,
                              domain,
                              use_story_concatenation=False,
                              tracker_limit=100,
                              augmentation_factor=0)
    completed_trackers = g.generate(silent)
    event_sequences = [t.events for t in completed_trackers]

    # NOTE(review): the leading ``None`` and ``max_distance=1`` follow
    # visualize_neighborhood's signature, which is not visible here — confirm.
    graph = visualize_neighborhood(None,
                                   event_sequences,
                                   output_file,
                                   max_history,
                                   interpreter,
                                   nlu_training_data,
                                   should_merge_nodes,
                                   max_distance=1,
                                   fontsize=fontsize)
    return graph
コード例 #3
0
ファイル: __init__.py プロジェクト: zh2010/rasa_core
def extract_story_graph_from_file(
        filename,  # type: Text
        domain,  # type: Domain
        interpreter=None  # type: Optional[NaturalLanguageInterpreter]
):
    # type: (...) -> StoryGraph
    """Read the stories in a single file and build a ``StoryGraph``.

    Args:
        filename: story file passed to ``StoryFileReader.read_from_file``.
        domain: domain passed to the reader.
        interpreter: passed to the reader; when omitted (``None``) a new
            ``RegexInterpreter`` is created for this call.

    Returns:
        A ``StoryGraph`` built from the story steps read from ``filename``.
    """
    if interpreter is None:
        # Avoid a ``RegexInterpreter()`` default argument: it would be built
        # once at definition time and shared by every call.
        interpreter = RegexInterpreter()
    story_steps = StoryFileReader.read_from_file(filename, domain, interpreter)
    return StoryGraph(story_steps)
コード例 #4
0
ファイル: test_graph.py プロジェクト: githubclj/rasa_core
def test_node_ordering_with_cycle():
    """Topological sorting of a graph with cycles still passes the check."""
    # a -> d -> a and e -> f -> e are cycles in this graph.
    graph = dict(
        a=["b", "c", "d"],
        b=[],
        c=["d"],
        d=["a"],
        e=["f"],
        f=["e"],
    )
    order, dropped_edges = StoryGraph.topological_sort(graph)

    check_graph_is_sorted(graph, order, dropped_edges)
コード例 #5
0
ファイル: test_graph.py プロジェクト: prenigma/testfou
def test_node_ordering_with_cycle():
    """A cyclic graph can still be topologically sorted and verified."""
    successors = dict([
        ("a", ["b", "c", "d"]),
        ("b", []),
        ("c", ["d"]),
        ("d", ["a"]),  # closes the cycle a -> d -> a
        ("e", ["f"]),
        ("f", ["e"]),  # closes the cycle e -> f -> e
    ])
    ordered, removed = StoryGraph.topological_sort(successors)
    check_graph_is_sorted(successors, ordered, removed)
コード例 #6
0
ファイル: test_graph.py プロジェクト: githubclj/rasa_core
def test_node_ordering():
    """An acyclic graph is sorted without removing any edge."""
    dag = dict(
        a=["b", "c", "d"],
        b=[],
        c=["d"],
        d=[],
        e=["f"],
        f=[],
    )
    ordered_nodes, cut_edges = StoryGraph.topological_sort(dag)

    assert cut_edges == set()
    check_graph_is_sorted(dag, ordered_nodes, cut_edges)
コード例 #7
0
ファイル: test_graph.py プロジェクト: prenigma/testfou
def test_node_ordering():
    """Sorting a graph without cycles must not drop any edges."""
    successors = dict([
        ("a", ["b", "c", "d"]),
        ("b", []),
        ("c", ["d"]),
        ("d", []),
        ("e", ["f"]),
        ("f", []),
    ])
    ordered, removed = StoryGraph.topological_sort(successors)

    assert removed == set()
    check_graph_is_sorted(successors, ordered, removed)
コード例 #8
0
def extract_story_graph(
    resource_name,  # type: Text
    domain,  # type: Domain
    interpreter=None  # type: Optional[NaturalLanguageInterpreter]
):
    # type: (...) -> StoryGraph
    """Read the stories found under ``resource_name`` into a ``StoryGraph``.

    Args:
        resource_name: path handed to ``StoryFileReader.read_from_folder``.
        domain: domain passed to the reader.
        interpreter: passed to the reader; any falsy value (including the
            default ``None``) is replaced by a fresh ``RegexInterpreter``.

    Returns:
        A ``StoryGraph`` built from the story steps that were read.
    """
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    steps = StoryFileReader.read_from_folder(
        resource_name, domain, interpreter or RegexInterpreter())
    return StoryGraph(steps)
コード例 #9
0
ファイル: __init__.py プロジェクト: hanish2760/ChatBot
def extract_story_graph(
    resource_name: Text,
    domain: 'Domain',
    interpreter: Optional['NaturalLanguageInterpreter'] = None,
    use_e2e: bool = False,
    exclusion_percentage: Optional[int] = None
) -> 'StoryGraph':
    """Read the stories found under ``resource_name`` into a ``StoryGraph``.

    Args:
        resource_name: path handed to ``StoryFileReader.read_from_folder``.
        domain: domain passed to the reader.
        interpreter: passed to the reader; any falsy value (including the
            default ``None``) is replaced by a fresh ``RegexInterpreter``.
        use_e2e: forwarded to the reader as ``use_e2e``.
        exclusion_percentage: forwarded to the reader as
            ``exclusion_percentage``; ``None`` by default.

    Returns:
        A ``StoryGraph`` built from the story steps that were read.
    """
    # Function-local imports; presumably kept here to avoid import cycles
    # or import-time cost — confirm before hoisting them.
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    if not interpreter:
        interpreter = RegexInterpreter()
    story_steps = StoryFileReader.read_from_folder(
        resource_name,
        domain,
        interpreter,
        use_e2e=use_e2e,
        exclusion_percentage=exclusion_percentage)
    return StoryGraph(story_steps)
コード例 #10
0
def extract_story_graph(
        resource_name,  # type: Text
        domain,  # type: Domain
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        use_e2e=False,  # type: bool
        exclusion_percentage=None  # type: Optional[int]
):
    # type: (...) -> StoryGraph
    """Read the stories found under ``resource_name`` into a ``StoryGraph``.

    Args:
        resource_name: path handed to ``StoryFileReader.read_from_folder``.
        domain: domain passed to the reader.
        interpreter: passed to the reader; any falsy value (including the
            default ``None``) is replaced by a fresh ``RegexInterpreter``.
        use_e2e: forwarded to the reader as ``use_e2e``.
        exclusion_percentage: forwarded to the reader as
            ``exclusion_percentage``; ``None`` by default.

    Returns:
        A ``StoryGraph`` built from the story steps that were read.
    """
    # Function-local imports; presumably kept here to avoid import cycles
    # or import-time cost — confirm before hoisting them.
    from rasa_core.interpreter import RegexInterpreter
    from rasa_core.training.dsl import StoryFileReader
    from rasa_core.training.structures import StoryGraph

    if not interpreter:
        interpreter = RegexInterpreter()
    story_steps = StoryFileReader.read_from_folder(
        resource_name,
        domain,
        interpreter,
        use_e2e=use_e2e,
        exclusion_percentage=exclusion_percentage)
    return StoryGraph(story_steps)
コード例 #11
0
def visualize_stories(
        story_steps,  # type: List[StoryStep]
        domain,  # type: Domain
        output_file,  # type: Optional[Text]
        max_history,  # type: int
        interpreter=None,  # type: Optional[NaturalLanguageInterpreter]
        nlu_training_data=None,  # type: Optional[TrainingData]
        should_merge_nodes=True,  # type: bool
        fontsize=12  # type: int
):
    """Generate a graph visualizing the flows in the given stories.

    Visualization is always a trade-off between making the graph as small as
    possible and making sure the meaning doesn't change too much. The
    algorithm compresses the graph generated from the stories by merging
    nodes that are similar. Hence, the algorithm might create paths through
    the graph that aren't actually specified in the stories, but we try to
    minimize that.

    The output file defines if and where a file containing the plotted graph
    should be stored (nothing is written when it is falsy).

    The history defines how much 'memory' the graph has. This influences in
    which situations the algorithm will merge nodes. Nodes will only be
    merged if they are equal within the history; this means the larger the
    history we take into account, the less likely it is that we merge any
    nodes.

    The training data parameter can be used to pass in a Rasa NLU training
    data instance. It will be used to replace the user messages from the
    story file with actual messages from the training data.

    Args:
        story_steps: story steps to visualize; wrapped in a ``StoryGraph``.
        domain: domain handed to the ``TrainingDataGenerator``.
        output_file: target passed to ``persist_graph`` when truthy.
        max_history: history used by ``_merge_equivalent_nodes``.
        interpreter: parses user messages; when omitted (``None``) a new
            ``RegexInterpreter`` is created for this call.
        nlu_training_data: forwarded to ``_replace_edge_labels_with_nodes``.
        should_merge_nodes: whether ``_merge_equivalent_nodes`` is applied.
        fontsize: font size set on the START/END/action nodes and forwarded
            to ``_replace_edge_labels_with_nodes``.

    Returns:
        The ``networkx.MultiDiGraph`` that was built.
    """
    import networkx as nx

    if interpreter is None:
        # A ``RegexInterpreter()`` default in the signature would be built
        # once at definition time and shared by every call; build it per call.
        interpreter = RegexInterpreter()

    story_graph = StoryGraph(story_steps)
    graph = nx.MultiDiGraph()
    # Node 0 is START and node -1 is END; action nodes get ids 1, 2, ...
    next_node_idx = 0
    graph.add_node(0,
                   label="START",
                   fillcolor="green",
                   style="filled",
                   fontsize=fontsize)
    graph.add_node(-1,
                   label="END",
                   fillcolor="red",
                   style="filled",
                   fontsize=fontsize)

    # Generate trackers with story concatenation and augmentation disabled,
    # capped at 100 trackers.
    g = TrainingDataGenerator(story_graph,
                              domain,
                              use_story_concatenation=False,
                              tracker_limit=100,
                              augmentation_factor=0)
    completed_trackers = g.generate()

    for tracker in completed_trackers:
        message = None
        current_node = 0
        for el in tracker.events:
            if isinstance(el, UserUttered):
                # Keep the latest user message; it labels the edge into the
                # next (non-listen) action node.
                message = interpreter.parse(el.text)
            elif (isinstance(el, ActionExecuted)
                  and el.action_name != ACTION_LISTEN_NAME):
                next_node_idx += 1
                graph.add_node(next_node_idx,
                               label=el.action_name,
                               fontsize=fontsize)

                if message:
                    message_key = message.get("intent", {}).get("name", None)
                    message_label = message.get("text", None)
                else:
                    message_key = None
                    message_label = None

                _add_edge(graph, current_node, next_node_idx, message_key,
                          message_label)
                current_node = next_node_idx

                # The message has been consumed by this edge.
                message = None

        # Connect the last node of the story to END.
        if message:
            # NOTE(review): ``label`` is the whole parse result here, unlike
            # the intent-name/text pair used for the other edges — confirm
            # this is intended.
            graph.add_edge(current_node,
                           -1,
                           key=EDGE_NONE_LABEL,
                           label=message)
        else:
            graph.add_edge(current_node, -1, key=EDGE_NONE_LABEL)

    if should_merge_nodes:
        _merge_equivalent_nodes(graph, max_history)
    # next_node_idx is the highest node id assigned above; presumably new
    # message nodes are numbered after it — confirm in the helper.
    _replace_edge_labels_with_nodes(graph, next_node_idx, interpreter,
                                    nlu_training_data, fontsize)

    if output_file:
        persist_graph(graph, output_file)
    return graph