Пример #1
0
def _write_domain_to_file(domain_path: Text, evts: List[Dict[Text, Any]],
                          endpoint: EndpointConfig) -> None:
    """Write an updated domain file to the file path.

    The domain currently served at ``endpoint`` is merged with a domain built
    from the intents, entities and (non-default) actions found in ``evts``.
    The merged result is written to ``domain_path`` via ``persist_clean``.

    Args:
        domain_path: path of the domain file to write.
        evts: serialized conversation events to harvest domain items from.
        endpoint: endpoint the current domain is retrieved from.
    """

    domain = retrieve_domain(endpoint)
    old_domain = Domain.from_dict(domain)

    messages = _collect_messages(evts)
    actions = _collect_actions(evts)

    # Give every key its own fresh, empty mapping. The previous
    # `dict.fromkeys(domain.keys(), [])` stored ONE shared list under every
    # key, and a list is the wrong type for the keys that are not overwritten
    # below (NOTE(review): presumably "templates", "slots", "config", which
    # Domain.from_dict is assumed to read as mappings -- confirm).
    domain_dict = {key: {} for key in domain}  # type: Dict[Text, Any]

    # TODO for now there is no way to distinguish between action and form
    domain_dict["forms"] = []
    domain_dict["intents"] = _intents_from_messages(messages)
    domain_dict["entities"] = _entities_from_messages(messages)
    # do not automatically add default actions to the domain dict; the
    # default names are computed once rather than once per collected action
    default_names = set(default_action_names())
    domain_dict["actions"] = list({
        e["name"]
        for e in actions if e["name"] not in default_names
    })

    new_domain = Domain.from_dict(domain_dict)

    old_domain.merge(new_domain).persist_clean(domain_path)
Пример #2
0
def _write_domain_to_file(
    domain_path: Text,
    evts: List[Dict[Text, Any]],
    endpoint: EndpointConfig
) -> None:
    """Write an updated domain file to the file path.

    A ``Domain`` is assembled from ``evts``: the intents and entities of the
    collected user messages, plus the names of the collected actions that
    are not default actions. It is merged into the domain served at
    ``endpoint``, and the merge is persisted to ``domain_path``.
    """

    current = Domain.from_dict(retrieve_domain(endpoint))

    messages = _collect_messages(evts)
    executed = _collect_actions(evts)

    intent_properties = Domain.collect_intent_properties(
        _intents_from_messages(messages))

    # Only custom actions are added; default action names are filtered out.
    action_names = list({
        name
        for name in (evt["name"] for evt in executed)
        if name not in default_action_names()
    })

    # TODO for now there is no way to distinguish between action and form,
    # so every collected name is registered as an action and no forms.
    harvested = Domain(
        intent_properties=intent_properties,
        entities=_entities_from_messages(messages),
        slots=[],
        templates={},
        action_names=action_names,
        form_names=[],
    )

    current.merge(harvested).persist_clean(domain_path)
Пример #3
0
def _write_domain_to_file(domain_path, evts, endpoint):
    # type: (Text, List[Dict[Text, Any]], EndpointConfig) -> None
    """Write an updated domain file to the file path.

    The domain currently served at ``endpoint`` is merged with a domain built
    from the intents, entities and action names found in ``evts``. The
    result is written to ``domain_path`` via ``persist_clean``.

    Args:
        domain_path: path of the domain file to write.
        evts: serialized conversation events to harvest domain items from.
        endpoint: endpoint the current domain is retrieved from.
    """

    domain = retrieve_domain(endpoint)
    old_domain = Domain.from_dict(domain)

    messages = _collect_messages(evts)
    actions = _collect_actions(evts)

    # One fresh empty dict per key: `dict.fromkeys(keys, {})` stores a single
    # dict object under every key, so an in-place change to one entry would
    # silently show up under all the others.
    domain_dict = {key: {} for key in domain}  # type: Dict[Text, Any]

    domain_dict["intents"] = _intents_from_messages(messages)
    domain_dict["entities"] = _entities_from_messages(messages)
    domain_dict["actions"] = list({e["name"] for e in actions})

    new_domain = Domain.from_dict(domain_dict)

    old_domain.merge(new_domain).persist_clean(domain_path)
Пример #4
0
def record_messages(
        endpoint,  # type: EndpointConfig
        sender_id=UserMessage.DEFAULT_SENDER_ID,  # type: Text
        max_message_limit=None,  # type: Optional[int]
        on_finish=None,  # type: Optional[Callable[[], None]]
        finetune=False,  # type: bool
        stories=None,  # type: Optional[Text]
        skip_visualization=False  # type: bool
):
    """Read messages from the command line and print bot responses.

    Each loop round calls ``_enter_user_message`` and ``_validate_nlu`` when
    the bot is listening. It then calls ``_predict_till_next_listen``. Only
    rounds that complete without one of the handled control-flow exceptions
    count toward ``max_message_limit``.

    Handled inside the loop:
      * ``RestartConversation`` - sends a restart event followed by an
        ``ACTION_LISTEN_NAME`` action event.
      * ``UndoLastStep`` - undoes the latest step and prints the history.
      * ``ForkTracker`` - asks the user for a fork. If one is returned, its
        events are copied into a new conversation (fresh uuid sender id),
        and that conversation becomes the active one.

    Args:
        endpoint: server the conversation is driven through.
        sender_id: id of the conversation to start with.
        max_message_limit: round limit checked via ``utils.is_limit_reached``
            (NOTE(review): ``None`` presumably means unlimited - confirm).
        on_finish: optional callback, always invoked when recording ends.
        finetune: forwarded to ``_predict_till_next_listen``.
        stories: stories loaded with ``training.load_data``. Their trackers'
            events are included in ``sender_ids`` for plotting.
        skip_visualization: if true, no ``story_graph.dot`` plot is produced.

    Returns ``None`` early (after logging) if the domain cannot be fetched
    due to a connection error. Any other exception is logged and re-raised.
    """

    from rasa_core import training

    try:
        _print_help(skip_visualization)

        try:
            domain = retrieve_domain(endpoint)
        except requests.exceptions.ConnectionError:
            logger.exception("Failed to connect to rasa core server at '{}'. "
                             "Is the server running?".format(endpoint.url))
            return

        trackers = training.load_data(
            stories,
            Domain.from_dict(domain),
            augmentation_factor=0,
            use_story_concatenation=False,
        )

        # first key of each intent entry (presumably `{name: properties}`)
        intents = [next(iter(i)) for i in (domain.get("intents") or [])]

        num_messages = 0
        # story trackers contribute their event lists, the live conversation
        # its sender id; the mixed list is passed to the plotting helpers
        sender_ids = [t.events for t in trackers] + [sender_id]

        if not skip_visualization:
            plot_file = "story_graph.dot"
            _plot_trackers(sender_ids, plot_file, endpoint)
        else:
            plot_file = None

        while not utils.is_limit_reached(num_messages, max_message_limit):
            try:
                if is_listening_for_message(sender_id, endpoint):
                    _enter_user_message(sender_id, endpoint)
                    _validate_nlu(intents, endpoint, sender_id)
                _predict_till_next_listen(endpoint, sender_id, finetune,
                                          sender_ids, plot_file)

                num_messages += 1
            except RestartConversation:
                send_event(endpoint, sender_id, {"event": "restart"})
                send_event(endpoint, sender_id, {
                    "event": "action",
                    "name": ACTION_LISTEN_NAME
                })

                logger.info("Restarted conversation, starting a new one.")
            except UndoLastStep:
                _undo_latest(sender_id, endpoint)
                _print_history(sender_id, endpoint)
            except ForkTracker:
                _print_history(sender_id, endpoint)

                evts = _request_fork_from_user(sender_id, endpoint)

                # Switch to a new conversation only when a fork was actually
                # produced. Replacing `sender_id` unconditionally (even for
                # `None`) made the live conversation one that never appears
                # in `sender_ids`, which is what prediction/plotting receive.
                if evts is not None:
                    sender_id = uuid.uuid4().hex
                    replace_events(endpoint, sender_id, evts)
                    sender_ids.append(sender_id)
                    _print_history(sender_id, endpoint)
                    _plot_trackers(sender_ids, plot_file, endpoint)

    except Exception:
        logger.exception("An exception occurred while recording messages.")
        raise
    finally:
        if on_finish:
            on_finish()
Пример #5
0
async def record_messages(endpoint: EndpointConfig,
                          sender_id: Text = UserMessage.DEFAULT_SENDER_ID,
                          max_message_limit: Optional[int] = None,
                          finetune: bool = False,
                          stories: Optional[Text] = None,
                          skip_visualization: bool = False):
    """Read messages from the command line and print bot responses.

    Each loop round awaits ``_enter_user_message`` and ``_validate_nlu`` when
    the bot is listening, then awaits ``_predict_till_next_listen``. Only
    rounds that finish without a handled control-flow exception count toward
    ``max_message_limit``.

    Handled inside the loop:
      * ``RestartConversation`` - sends ``Restarted`` then an
        ``ActionExecuted(ACTION_LISTEN_NAME)`` event.
      * ``UndoLastStep`` - undoes the latest step and prints the history.
      * ``ForkTracker`` - sends ``Restarted`` on the same sender id and
        re-sends every event returned by ``_request_fork_from_user``.

    If fetching the domain raises ``ClientError``, the error is logged and the
    function returns. ``Abort`` ends the session quietly; any other exception
    is logged and re-raised.
    """

    from rasa_core import training

    try:
        _print_help(skip_visualization)

        try:
            domain = await retrieve_domain(endpoint)
        except ClientError:
            logger.exception("Failed to connect to Rasa Core server at '{}'. "
                             "Is the server running?".format(endpoint.url))
            return

        story_trackers = await training.load_data(
            stories,
            Domain.from_dict(domain),
            augmentation_factor=0,
            use_story_concatenation=False,
        )

        # first key of each intent entry (presumably `{name: properties}`)
        intents = [next(iter(intent))
                   for intent in domain.get("intents") or []]

        message_count = 0
        # event lists of the story trackers, followed by the live sender id
        sender_ids = [tracker.events for tracker in story_trackers]
        sender_ids.append(sender_id)

        plot_file = None if skip_visualization else "story_graph.dot"
        if plot_file is not None:
            await _plot_trackers(sender_ids, plot_file, endpoint)

        while not utils.is_limit_reached(message_count, max_message_limit):
            try:
                if await is_listening_for_message(sender_id, endpoint):
                    await _enter_user_message(sender_id, endpoint)
                    await _validate_nlu(intents, endpoint, sender_id)
                await _predict_till_next_listen(endpoint, sender_id, finetune,
                                                sender_ids, plot_file)
                message_count += 1
            except RestartConversation:
                restart = Restarted().as_dict()
                await send_event(endpoint, sender_id, restart)

                listen = ActionExecuted(ACTION_LISTEN_NAME).as_dict()
                await send_event(endpoint, sender_id, listen)

                logger.info("Restarted conversation, starting a new one.")
            except UndoLastStep:
                await _undo_latest(sender_id, endpoint)
                await _print_history(sender_id, endpoint)
            except ForkTracker:
                await _print_history(sender_id, endpoint)

                fork_evts = await _request_fork_from_user(sender_id, endpoint)

                await send_event(endpoint, sender_id, Restarted().as_dict())
                # replay the kept events (nothing to replay if none came back)
                for kept in fork_evts or []:
                    await send_event(endpoint, sender_id, kept)
                logger.info("Restarted conversation at fork.")

                await _print_history(sender_id, endpoint)
                await _plot_trackers(sender_ids, plot_file, endpoint)

    except Abort:
        return
    except Exception:
        logger.exception("An exception occurred while recording messages.")
        raise