Ejemplo n.º 1
0
def load_data(
    resource_name: Union[Text, "TrainingDataImporter"],
    domain: "Domain",
    remove_duplicates: bool = True,
    unique_last_num_states: Optional[int] = None,
    augmentation_factor: int = 50,
    tracker_limit: Optional[int] = None,
    use_story_concatenation: bool = True,
    debug_plots: bool = False,
    exclusion_percentage: Optional[int] = None,
) -> List["TrackerWithCachedStates"]:
    """Load training trackers from a resource.

    Args:
        resource_name: Where the stories come from; either a path or a
            `TrainingDataImporter` instance. A falsy value yields no trackers.
        domain: Domain handed to the story-graph extraction (path case) and
            to the tracker generator.
        remove_duplicates: Whether duplicated training examples should be
            removed.
        unique_last_num_states: Number of states in a conversation that make
            a tracker unique (used to identify duplicates).
        augmentation_factor: By how much the story training data should be
            augmented.
        tracker_limit: Maximum number of trackers to generate during
            augmentation.
        use_story_concatenation: Whether stories should be concatenated when
            doing data augmentation.
        debug_plots: Whether debug plots are generated during loading.
        exclusion_percentage: How much data to exclude when reading stories.

    Returns:
        The trackers produced by the generator, or an empty list when
        `resource_name` is falsy.
    """
    from rasa.shared.core.generator import TrainingDataGenerator
    from rasa.shared.importers.importer import TrainingDataImporter

    # Nothing to load from: short-circuit with an empty result.
    if not resource_name:
        return []

    if isinstance(resource_name, TrainingDataImporter):
        # An importer is asked directly for its story graph.
        story_graph = resource_name.get_stories(
            exclusion_percentage=exclusion_percentage
        )
    else:
        # Anything else is treated as a location to extract stories from.
        story_graph = extract_story_graph(
            resource_name, domain, exclusion_percentage=exclusion_percentage
        )

    generator = TrainingDataGenerator(
        story_graph,
        domain,
        remove_duplicates,
        unique_last_num_states,
        augmentation_factor,
        tracker_limit,
        use_story_concatenation,
        debug_plots,
    )
    return generator.generate()
Ejemplo n.º 2
0
def visualize_stories(
    story_steps: List[StoryStep],
    domain: Domain,
    output_file: Optional[Text],
    max_history: int,
    nlu_training_data: Optional["TrainingData"] = None,
    should_merge_nodes: bool = True,
    fontsize: int = 12,
) -> "networkx.MultiDiGraph":
    """Build a graph that visualizes the conversation flows of a set of stories.

    Visualization is a trade-off between keeping the graph as small as
    possible and not changing its meaning too much. Similar nodes are merged
    to compress the graph, so the result may contain paths that are not
    actually specified in the stories; the aim is to keep those to a minimum.

    Args:
        story_steps: The story steps to visualize.
        domain: Domain used when generating trackers from the stories.
        output_file: Defines if and where a file containing the plotted graph
            should be stored.
        max_history: How much 'memory' the graph has. Nodes are only merged
            if they are equal within this history, so a larger value makes
            merging less likely.
        nlu_training_data: Optional Rasa NLU training data; used to replace
            the user messages from the story file with actual messages from
            the training data.
        should_merge_nodes: Forwarded to `visualize_neighborhood`.
        fontsize: Forwarded to `visualize_neighborhood`.

    Returns:
        The graph returned by `visualize_neighborhood`.
    """
    # Generate trackers without concatenation or augmentation, capped at 100.
    generator = TrainingDataGenerator(
        StoryGraph(story_steps),
        domain,
        use_story_concatenation=False,
        tracker_limit=100,
        augmentation_factor=0,
    )
    event_sequences = [tracker.events for tracker in generator.generate()]

    return visualize_neighborhood(
        None,
        event_sequences,
        output_file,
        max_history,
        nlu_training_data,
        should_merge_nodes,
        max_distance=1,
        fontsize=fontsize,
    )
Ejemplo n.º 3
0
    def provide(
        self, story_graph: StoryGraph, domain: Domain
    ) -> List[TrackerWithCachedStates]:
        """Turn a story graph into training trackers.

        Args:
            story_graph: The story graph to generate trackers from.
            domain: The domain of the model.

        Returns:
            The trackers generated by a `TrainingDataGenerator` configured
            with this component's `_config`.
        """
        # The stored config is forwarded as keyword arguments to the generator.
        return TrainingDataGenerator(story_graph, domain, **self._config).generate()
Ejemplo n.º 4
0
    def verify_story_structure(
        self, ignore_warnings: bool = True, max_history: Optional[int] = None
    ) -> bool:
        """Check that the bot behaviour described by the stories is deterministic.

        Args:
            ignore_warnings: When `True`, return `True` even if conflicts
                were found.
            max_history: Maximal number of events to take into account for
                conflict identification.

        Returns:
            `False` if a conflict was found and `ignore_warnings` is `False`,
            `True` otherwise.
        """
        logger.info("Story structure validation...")

        # Duplicates are kept and no augmentation is applied.
        generator = TrainingDataGenerator(
            self.story_graph,
            domain=self.domain,
            remove_duplicates=False,
            augmentation_factor=0,
        )
        trackers = generator.generate()

        # Collect the `StoryConflict` objects found among the trackers.
        conflicts = rasa.core.training.story_conflict.find_story_conflicts(
            trackers, self.domain, max_history
        )

        if conflicts:
            # Each conflict is reported as a warning of its own.
            for conflict in conflicts:
                logger.warning(conflict)
        else:
            logger.info("No story structure conflicts found.")

        return ignore_warnings or not conflicts
Ejemplo n.º 5
0
    def verify_story_structure(self, raise_exception: bool = True, max_history: Optional[int] = None):
        """
        Validates whether the bot behaviour in stories is deterministic.

        Also refreshes the 'stories' and 'rules' entries of
        `self.component_count` from the steps of the story graph.

        @param raise_exception: When truthy, an `AppException` carrying the
            first conflict's text is raised as soon as a conflict is seen;
            set it to false to record all conflicts in
            `self.summary['stories']` instead.
        @param max_history: Passed through to `find_story_conflicts`.
        @return: None.
        """
        self.component_count['stories'] = 0
        self.component_count['rules'] = 0
        for step in self.story_graph.story_steps:
            # RuleStep is tested first, so a step matching both checks is
            # counted as a rule.
            if isinstance(step, RuleStep):
                bucket = 'rules'
            elif isinstance(step, StoryStep):
                bucket = 'stories'
            else:
                continue
            self.component_count[bucket] += 1

        # Duplicates are kept and no augmentation is applied.
        generator = TrainingDataGenerator(
            self.story_graph,
            domain=self.domain,
            remove_duplicates=False,
            augmentation_factor=0,
        )
        conflicts = find_story_conflicts(
            generator.generate_story_trackers(), self.domain, max_history
        )

        if not conflicts:
            return

        messages = []
        for conflict in conflicts:
            text = str(conflict)
            messages.append(text)
            if raise_exception:
                # Abort on the first conflict; the summary is left untouched.
                raise AppException(text)
        self.summary['stories'] = messages
Ejemplo n.º 6
0
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    """Create a tracker generator for the test stories in `resource_name`.

    Args:
        resource_name: Path passed to the importer as its only training data
            path.
        agent: Agent whose domain is given to the generator and whose model
            directory, if set, is searched for a domain file.
        max_stories: Used as the generator's `tracker_limit`.
        use_conversation_test_files: When true, the story graph is read via
            `get_conversation_tests()` instead of `get_stories()`.

    Returns:
        A `TrainingDataGenerator` with story concatenation and augmentation
        disabled.
    """
    from rasa.shared.core.generator import TrainingDataGenerator
    from rasa.shared.constants import DEFAULT_DOMAIN_PATH
    from rasa.model import get_model_subdirectories

    # Use the domain file of the core model only when it exists on disk;
    # otherwise the importer is created without a domain path.
    domain_path = None
    if agent.model_directory:
        core_model, _ = get_model_subdirectories(agent.model_directory)
        if core_model:
            candidate = os.path.join(core_model, DEFAULT_DOMAIN_PATH)
            if os.path.exists(candidate):
                domain_path = candidate

    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[resource_name], domain_path=domain_path
    )
    story_graph = (
        importer.get_conversation_tests()
        if use_conversation_test_files
        else importer.get_stories()
    )

    return TrainingDataGenerator(
        story_graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
Ejemplo n.º 7
0
Archivo: test.py Proyecto: zoovu/rasa
def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_conversation_test_files: bool = False,
) -> "TrainingDataGenerator":
    """Create a tracker generator for the test stories in `resource_name`.

    The agent's domain is persisted to a temporary `domain.yaml` so that it
    can be handed to the importer as a path. The temporary directory is
    removed again once the story graph has been loaded.

    Args:
        resource_name: Path passed to the importer as its only training data
            path.
        agent: Agent whose domain is persisted for the importer and given to
            the generator.
        max_stories: Used as the generator's `tracker_limit`.
        use_conversation_test_files: When true, the story graph is read via
            `get_conversation_tests()` instead of `get_stories()`.

    Returns:
        A `TrainingDataGenerator` with story concatenation and augmentation
        disabled.
    """
    from rasa.shared.core.generator import TrainingDataGenerator

    # `mkdtemp()` would leave one directory behind per call; the context
    # manager cleans the scratch directory up on exit, including on errors.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_domain_path = Path(tmp_dir) / "domain.yaml"
        agent.domain.persist(tmp_domain_path)
        test_data_importer = TrainingDataImporter.load_from_dict(
            training_data_paths=[resource_name], domain_path=str(tmp_domain_path)
        )
        # Load the story graph while the temporary domain file still exists.
        # NOTE(review): assumes nothing reads the domain file again after the
        # story graph has been loaded — confirm against the importer.
        if use_conversation_test_files:
            story_graph = test_data_importer.get_conversation_tests()
        else:
            story_graph = test_data_importer.get_stories()

    return TrainingDataGenerator(
        story_graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
Ejemplo n.º 8
0
async def _generate_trackers(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_e2e: bool = False,
) -> List[Any]:
    """Generate story trackers for the stories found in `resource_name`.

    Args:
        resource_name: Resource handed to `training.extract_story_graph`.
        agent: Agent whose domain is used for extraction and generation.
        max_stories: Used as the generator's `tracker_limit`.
        use_e2e: Passed as the third positional argument to
            `training.extract_story_graph`.

    Returns:
        The result of the generator's `generate_story_trackers()`.
    """
    from rasa.shared.core.generator import TrainingDataGenerator

    from rasa.core import training

    graph = await training.extract_story_graph(resource_name, agent.domain, use_e2e)

    # Story concatenation and augmentation are both switched off.
    generator = TrainingDataGenerator(
        graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
    return generator.generate_story_trackers()
Ejemplo n.º 9
0
async def _setup_trackers_for_testing(
    domain_path: Text, training_data_file: Text
) -> Tuple[List[TrackerWithCachedStates], Domain]:
    """Build trackers and the matching domain for a test.

    Args:
        domain_path: Path of the domain file given to the importer.
        training_data_file: The single training data path given to the
            importer.

    Returns:
        A tuple of the generated trackers and the validator's domain.
    """
    file_importer = RasaFileImporter(
        domain_path=domain_path, training_data_paths=[training_data_file]
    )
    validator = await Validator.from_importer(file_importer)

    # Duplicates are kept and no augmentation is applied.
    generator = TrainingDataGenerator(
        validator.story_graph,
        domain=validator.domain,
        remove_duplicates=False,
        augmentation_factor=0,
    )
    return generator.generate(), validator.domain
Ejemplo n.º 10
0
Archivo: test.py Proyecto: tanbui/rasa
async def _create_data_generator(
    resource_name: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    use_e2e: bool = False,
) -> "TrainingDataGenerator":
    """Create a tracker generator for the stories in `resource_name`.

    Args:
        resource_name: Path passed to the importer as its only training data
            path.
        agent: Agent whose domain is given to the generator.
        max_stories: Used as the generator's `tracker_limit`.
        use_e2e: Forwarded to the importer's `get_stories`.

    Returns:
        A `TrainingDataGenerator` with story concatenation and augmentation
        disabled.
    """
    from rasa.shared.core.generator import TrainingDataGenerator

    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[resource_name]
    )
    graph = await importer.get_stories(use_e2e=use_e2e)

    generator = TrainingDataGenerator(
        graph,
        agent.domain,
        use_story_concatenation=False,
        augmentation_factor=0,
        tracker_limit=max_stories,
    )
    return generator