Example #1
def run_concerts(serve_forever=True):
    agent = Agent.load("examples/concerts/models/policy/init",
                       interpreter=RegexInterpreter())

    if serve_forever:
        agent.handle_channel(ConsoleInputChannel())
    return agent
def extract_stories_from_file(filename,
                              domain,
                              remove_duplicates=True,
                              interpreter=RegexInterpreter(),
                              max_number_of_trackers=2000):
    graph = extract_story_graph_from_file(filename, domain)
    return graph.build_stories(domain, interpreter, remove_duplicates,
                               max_number_of_trackers)
Example #3
    def as_dialogue(self, sender, domain, interpreter=RegexInterpreter()):
        events = []
        for step in self.story_steps:
            events.extend(
                step.explicit_events(domain,
                                     interpreter,
                                     should_append_final_listen=False))

        events.append(ActionExecuted(ActionListen().name()))
        return Dialogue(sender, events)
def test_concerts_online_example():
    import conversationinsights.util
    conversationinsights.util.input = lambda _=None: "2"  # simulates cmdline input / detailed explanation above

    input_channel = FileInputChannel('examples/concerts/data/stories.md',
                                     message_line_pattern=r'^\s*\*\s(.*)$',
                                     max_messages=10)
    agent = run_concertbot_online(input_channel, RegexInterpreter())
    responses = agent.handle_message("_greet")
    assert responses[-1] in {
        "hey there!", "how can I help you?", "default message"
    }
Example #5
    def _parse_message(self, message):
        # for testing - you can short-cut the NLU part with a message
        # in the format _intent[entity1=val1,entity=val2]
        # parse_data is a dict of intent & entities
        if message.text.startswith('_'):
            parse_data = RegexInterpreter().parse(message.text)
        else:
            parse_data = self.interpreter.parse(message.text)

        logger.debug(
            "Received user message '{}' with intent '{}' and entities  '{}'".
            format(message.text, parse_data["intent"], parse_data["entities"]))
        return parse_data
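
For reference, a minimal sketch of the shortcut format the comment above describes; the intent name and entity value here are illustrative, and the parse result is only assumed to be the intent/entities dict the log statement relies on:

from conversationinsights.interpreter import RegexInterpreter

# "_greet[name=foo]" follows the _intent[entity1=val1,...] shortcut format
parse_data = RegexInterpreter().parse("_greet[name=foo]")
print(parse_data["intent"], parse_data["entities"])
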
def extract_training_data_from_file(filename,
                                    augmentation_factor=20,
                                    max_history=1,
                                    remove_duplicates=True,
                                    domain=None,
                                    featurizer=None,
                                    interpreter=RegexInterpreter(),
                                    max_number_of_trackers=2000):
    graph = extract_story_graph_from_file(filename, domain)
    extractor = TrainingsDataExtractor(graph, domain, featurizer, interpreter)
    return extractor.extract_trainings_data(remove_duplicates,
                                            augmentation_factor, max_history,
                                            max_number_of_trackers)
Example #7
    def __init__(
        self,
        story_graph,  # type: StoryGraph
        domain,  # type: Domain
        featurizer,  # type: Featurizer
        interpreter=None  # type: Optional[NaturalLanguageInterpreter]
    ):
        # type: (...) -> None

        self.story_graph = story_graph
        self.domain = domain
        self.featurizer = featurizer
        self.interpreter = interpreter if interpreter else RegexInterpreter()
Example #8
def test_controller(default_domain, capsys):
    story_filename = "data/dsl_stories/stories_defaultdomain.md"
    ensemble = SimplePolicyEnsemble([ScoringPolicy()])
    interpreter = RegexInterpreter()

    PolicyTrainer(ensemble, default_domain,
                  BinaryFeaturizer()).train(story_filename, max_history=3)

    tracker_store = InMemoryTrackerStore(default_domain)
    processor = MessageProcessor(interpreter, ensemble, default_domain,
                                 tracker_store)

    processor.handle_message(UserMessage("_greet", ConsoleOutputChannel()))
    out, _ = capsys.readouterr()
    assert "hey there!" in out
    def run_online_training(self,
                            ensemble,
                            domain,
                            interpreter=None,
                            input_channel=None):
        from conversationinsights.agent import Agent
        if interpreter is None:
            interpreter = RegexInterpreter()

        bot = Agent(domain,
                    ensemble,
                    featurizer=self.featurizer,
                    interpreter=interpreter)
        bot.toggle_memoization(False)

        bot.handle_channel(
            input_channel if input_channel else ConsoleInputChannel())
Example #10
    def build_stories(self,
                      domain,
                      interpreter=RegexInterpreter(),
                      remove_duplicates=True,
                      max_number_of_trackers=2000):
        # type: (Domain, NaturalLanguageInterpreter, bool, int) -> List[Story]
        """Build the stories of a graph."""
        from conversationinsights.training_utils.dsl import STORY_START, Story

        active_trackers = {STORY_START: [Story()]}
        rand = random.Random(42)

        for step in self.ordered_steps():
            if step.start_checkpoint in active_trackers:
                # these are the trackers that reached this story step
                # and that need to handle all events of the step
                incoming_trackers = active_trackers[step.start_checkpoint]

                if max_number_of_trackers is not None:
                    incoming_trackers = utils.subsample_array(
                        incoming_trackers, max_number_of_trackers, rand)

                events = step.explicit_events(domain, interpreter)
                # need to copy the tracker as multiple story steps might
                # start with the same checkpoint and all of them
                # will use the same set of incoming trackers
                if events:
                    trackers = [
                        Story(tracker.story_steps + [step])
                        for tracker in incoming_trackers
                    ]
                else:
                    trackers = []  # small optimization

                # update our tracker dictionary with the trackers that handled
                # the events of the step and that can now be used for further
                # story steps that start with the checkpoint this step ended on
                if step.end_checkpoint not in active_trackers:
                    active_trackers[step.end_checkpoint] = []
                active_trackers[step.end_checkpoint].extend(trackers)

        return active_trackers[None]
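
A hedged usage sketch of build_stories, mirroring the extract_stories_from_file helper shown earlier; the story file path and the already-loaded domain object are placeholders:

graph = extract_story_graph_from_file("data/dsl_stories/stories_defaultdomain.md",
                                      domain)
stories = graph.build_stories(domain,
                              interpreter=RegexInterpreter(),
                              remove_duplicates=True,
                              max_number_of_trackers=2000)
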
    def _prepare_training_data(self, filename, max_history, augmentation_factor,
                               max_training_samples=None,
                               max_number_of_trackers=2000):
        """Reads training data from file and prepares it for the training."""

        from conversationinsights.training_utils import extract_training_data_from_file

        if filename:
            X, y = extract_training_data_from_file(
                    filename,
                    augmentation_factor=augmentation_factor,
                    max_history=max_history,
                    remove_duplicates=True,
                    domain=self.domain,
                    featurizer=self.featurizer,
                    interpreter=RegexInterpreter(),
                    max_number_of_trackers=max_number_of_trackers)
            if max_training_samples is not None:
                X = X[:max_training_samples, :]
                y = y[:max_training_samples]
        else:
            X = np.zeros((0, self.domain.num_features))
            y = np.zeros(self.domain.num_actions)
        return X, y
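
For comparison, a hedged sketch of calling the module-level helper this method wraps; the file path, domain, and featurizer are placeholders reused from the other examples, and nothing beyond the X/y return pair visible above is assumed:

X, y = extract_training_data_from_file("data/dsl_stories/stories_defaultdomain.md",
                                       augmentation_factor=20,
                                       max_history=3,
                                       domain=domain,
                                       featurizer=BinaryFeaturizer(),
                                       interpreter=RegexInterpreter())
# per the zero-sample fallback above, X is sized by domain.num_features
# and y by domain.num_actions
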
def _test_stories(story_file,
                  policy_model_path,
                  nlu_model_path,
                  max_stories=None):
    """Test the stories from a file, running them through the stored model."""
    def actions_since_last_utterance(tracker):
        actions = []
        for e in reversed(tracker.events):
            if isinstance(e, UserUttered):
                break
            elif isinstance(e, ActionExecuted):
                actions.append(e.action_name)
        actions.reverse()
        return actions

    if nlu_model_path is not None:
        interpreter = RasaNLUInterpreter(model_directory=nlu_model_path)
    else:
        interpreter = RegexInterpreter()

    agent = Agent.load(policy_model_path, interpreter=interpreter)
    stories = _get_stories(story_file, agent.domain, max_stories=max_stories)
    preds = []
    actual = []

    logger.info("Evaluating {} stories\nProgress:".format(len(stories)))

    for s in tqdm(stories):
        sender = "default-" + uuid.uuid4().hex

        dialogue = s.as_dialogue(sender, agent.domain)
        actions_between_utterances = []
        last_prediction = []

        for i, event in enumerate(dialogue.events[1:]):
            if isinstance(event, UserUttered):
                p, a = _min_list_distance(last_prediction,
                                          actions_between_utterances)
                preds.extend(p)
                actual.extend(a)

                actions_between_utterances = []
                agent.handle_message(event.text, sender=sender)
                tracker = agent.tracker_store.retrieve(sender)
                last_prediction = actions_since_last_utterance(tracker)

            elif isinstance(event, ActionExecuted):
                actions_between_utterances.append(event.action_name)

        if last_prediction:
            preds.extend(last_prediction)
            preds_padding = len(actions_between_utterances) - \
                            len(last_prediction)
            preds.extend(["None"] * preds_padding)

            actual.extend(actions_between_utterances)
            actual_padding = len(last_prediction) - \
                             len(actions_between_utterances)
            actual.extend(["None"] * actual_padding)

    return actual, preds
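
A hedged sketch of driving the evaluation helper above; the story file and policy model paths are placeholders reused from the other examples, and the accuracy line is just one way to summarize the (actual, preds) lists it returns:

actual, preds = _test_stories("examples/concerts/data/stories.md",
                              policy_model_path="examples/concerts/models/policy/init",
                              nlu_model_path=None)  # None falls back to RegexInterpreter
correct = sum(1 for a, p in zip(actual, preds) if a == p)
print("action-level accuracy: {:.2f}".format(float(correct) / max(len(actual), 1)))
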
def visualize_stories(story_steps,
                      output_file=None,
                      max_history=2,
                      interpreter=RegexInterpreter(),
                      training_data=None):
    """Given a set of stories, generates a graph visualizing the flows in the
    stories.

    Visualization is always a trade off between making the graph as small as
    possible while
    at the same time making sure the meaning doesn't change to "much". The
    algorithm will
    compress the graph generated from the stories to merge nodes that are
    similar. Hence,
    the algorithm might create paths through the graph that aren't actually
    specified in the
    stories, but we try to minimize that.

    Output file defines if and where a file containing the plotted graph
    should be stored.

    The history defines how much 'memory' the graph has. This influences in
    which situations the
    algorithm will merge nodes. Nodes will only be merged if they are equal
    within the history, this
    means the larger the history is we take into account the less likely it
    is we merge any nodes.

    The training data parameter can be used to pass in a Rasa NLU training
    data instance. It will
    be used to replace the user messages from the story file with actual
    messages from the training data."""
    import networkx as nx

    story_graph = StoryGraph(story_steps)
    G = nx.MultiDiGraph()
    next_node_idx = 0
    G.add_node(0, label="START", fillcolor="green", style="filled")
    G.add_node(-1, label="END", fillcolor="red", style="filled")

    checkpoint_indices = defaultdict(list)
    checkpoint_indices[STORY_START] = [0]

    for step in story_graph.ordered_steps():
        current_nodes = checkpoint_indices[step.start_checkpoint]
        message = None
        for el in step.events:
            if isinstance(el, UserUttered):
                message = interpreter.parse(el.text)
            elif isinstance(el, ActionExecuted):
                if message:
                    message_key = message.get("intent", {}).get("name", None)
                    message_label = message.get("text", None)
                else:
                    message_key = None
                    message_label = None

                next_node_idx += 1
                G.add_node(next_node_idx, label=el.action_name)
                for current_node in current_nodes:
                    _add_edge(G, current_node, next_node_idx, message_key,
                              message_label)

                current_nodes = [next_node_idx]
                message = None
        if not step.end_checkpoint:
            for current_node in current_nodes:
                G.add_edge(current_node, -1, key=EDGE_NONE_LABEL)
        else:
            checkpoint_indices[step.end_checkpoint].extend(current_nodes)

    _merge_equivalent_nodes(G, max_history)
    _replace_edge_labels_with_nodes(
            G, next_node_idx, interpreter, training_data)

    if output_file:
        _persist_graph(G, output_file)
    return G
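
A hedged usage sketch of the visualizer; story_steps stands in for whatever list of story steps the training utilities parse from a story file (that loading step isn't shown here), and the output file name is a placeholder:

G = visualize_stories(story_steps,
                      output_file="story_graph.png",
                      max_history=2,
                      interpreter=RegexInterpreter())
print("graph has {} nodes and {} edges".format(G.number_of_nodes(),
                                               G.number_of_edges()))
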
import logging

from examples.concerts.policy import ConcertPolicy
from conversationinsights.agent import Agent
from conversationinsights.channels.console import ConsoleInputChannel
from conversationinsights.interpreter import RegexInterpreter
from conversationinsights.policies.memoization import MemoizationPolicy

logger = logging.getLogger(__name__)


def run_concertbot_online(input_channel, interpreter):
    training_data_file = 'examples/concerts/data/stories.md'

    agent = Agent("examples/concerts/concert_domain.yml",
                  policies=[MemoizationPolicy(),
                            ConcertPolicy()],
                  interpreter=interpreter)

    agent.train_online(training_data_file,
                       input_channel=input_channel,
                       max_history=2,
                       batch_size=50,
                       epochs=200,
                       max_training_samples=300)

    return agent


if __name__ == '__main__':
    logging.basicConfig(level="INFO")
    run_concertbot_online(ConsoleInputChannel(), RegexInterpreter())